aurora_engine_modexp/
arith.rs

1use crate::{
2    maybe_std::vec,
3    mpnat::{DoubleWord, MPNat, Word, BASE, WORD_BITS},
4};
5
6// Computes the "Montgomery Product" of two numbers.
7// See Coarsely Integrated Operand Scanning (CIOS) Method in
8// https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/j37acmon.pdf
9// In short, computes `xy (r^-1) mod n`, where `r = 2^(8*4*s)` and `s` is the number of
10// digits needs to represent `n`. `n_prime` has the property that `r(r^(-1)) - nn' = 1`.
11// Note: This algorithm only works if `xy < rn` (generally we will either have both `x < n`, `y < n`
12// or we will have `x < r`, `y < n`).
13pub fn monpro(x: &MPNat, y: &MPNat, n: &MPNat, n_prime: Word, out: &mut [Word]) {
14    debug_assert!(
15        n.is_odd(),
16        "Montgomery multiplication only makes sense with odd modulus"
17    );
18    debug_assert!(
19        out.len() >= n.digits.len() + 2,
20        "Output needs 2 extra words over the size needed to represent n"
21    );
22    let s = out.len() - 2;
23    // Using a range loop as opposed to `out.iter_mut().enumerate().take(s)`
24    // does make a meaningful performance difference in this case.
25    #[allow(clippy::needless_range_loop)]
26    for i in 0..s {
27        let mut c = 0;
28        for j in 0..s {
29            let (prod, carry) = shifted_carrying_mul(
30                out[j],
31                x.digits.get(j).copied().unwrap_or(0),
32                y.digits.get(i).copied().unwrap_or(0),
33                c,
34            );
35            out[j] = prod;
36            c = carry;
37        }
38        let (sum, carry) = carrying_add(out[s], c, false);
39        out[s] = sum;
40        out[s + 1] = carry as Word;
41        let m = out[0].wrapping_mul(n_prime);
42        let (_, carry) = shifted_carrying_mul(out[0], m, n.digits.first().copied().unwrap_or(0), 0);
43        c = carry;
44        for j in 1..s {
45            let (prod, carry) =
46                shifted_carrying_mul(out[j], m, n.digits.get(j).copied().unwrap_or(0), c);
47            out[j - 1] = prod;
48            c = carry;
49        }
50        let (sum, carry) = carrying_add(out[s], c, false);
51        out[s - 1] = sum;
52        out[s] = out[s + 1] + (carry as Word); // overflow impossible at this stage
53    }
54    // Result is only in the first s + 1 words of the output.
55    out[s + 1] = 0;
56
57    // Check if we need to do the final subtraction
58    for i in (0..=s).rev() {
59        match out[i].cmp(n.digits.get(i).unwrap_or(&0)) {
60            core::cmp::Ordering::Less => return, // No subtraction needed
61            core::cmp::Ordering::Greater => break,
62            core::cmp::Ordering::Equal => (),
63        }
64    }
65
66    let mut b = false;
67    for (i, out_digit) in out.iter_mut().enumerate().take(s) {
68        let (diff, borrow) = borrowing_sub(*out_digit, n.digits.get(i).copied().unwrap_or(0), b);
69        *out_digit = diff;
70        b = borrow;
71    }
72    let (diff, borrow) = borrowing_sub(out[s], 0, b);
73    out[s] = diff;
74
75    debug_assert!(!borrow, "No borrow needed since out < n");
76}
77
78// Equivalent to `monpro(x, x, n, n_prime, out)`, but more efficient.
79pub fn monsq(x: &MPNat, n: &MPNat, n_prime: Word, out: &mut [Word]) {
80    debug_assert!(
81        n.is_odd(),
82        "Montgomery multiplication only makes sense with odd modulus"
83    );
84    debug_assert!(
85        x.digits.len() <= n.digits.len(),
86        "x cannot be larger than n"
87    );
88    debug_assert!(
89        out.len() > 2 * n.digits.len(),
90        "Output needs double the digits to hold the value of x^2 plus an extra word"
91    );
92    let s = n.digits.len();
93
94    big_sq(x, out);
95    for i in 0..s {
96        let mut c: Word = 0;
97        let m = out[i].wrapping_mul(n_prime);
98        for j in 0..s {
99            let (prod, carry) =
100                shifted_carrying_mul(out[i + j], m, n.digits.get(j).copied().unwrap_or(0), c);
101            out[i + j] = prod;
102            c = carry;
103        }
104        let mut j = i + s;
105        while c > 0 {
106            let (sum, carry) = carrying_add(out[j], c, false);
107            out[j] = sum;
108            c = carry as Word;
109            j += 1;
110        }
111    }
112    // Only keep the last `s + 1` digits in `out`.
113    for i in 0..=s {
114        out[i] = out[i + s];
115    }
116    out[(s + 1)..].fill(0);
117
118    // Check if we need to do the final subtraction
119    for i in (0..=s).rev() {
120        match out[i].cmp(n.digits.get(i).unwrap_or(&0)) {
121            core::cmp::Ordering::Less => return,
122            core::cmp::Ordering::Greater => break,
123            core::cmp::Ordering::Equal => (),
124        }
125    }
126
127    let mut b = false;
128    for (i, out_digit) in out.iter_mut().enumerate().take(s) {
129        let (diff, borrow) = borrowing_sub(*out_digit, n.digits.get(i).copied().unwrap_or(0), b);
130        *out_digit = diff;
131        b = borrow;
132    }
133    let (diff, borrow) = borrowing_sub(out[s], 0, b);
134    out[s] = diff;
135
136    debug_assert!(!borrow, "No borrow needed since out < n");
137}
138
139// Given x odd, computes `x^(-1) mod 2^32`.
140// See `MODULAR-INVERSE` in https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf
141pub fn mod_inv(x: Word) -> Word {
142    debug_assert_eq!(x & 1, 1, "Algorithm only valid for odd n");
143
144    let mut y = 1;
145    for i in 2..WORD_BITS {
146        let mask = (1 << i) - 1;
147        let xy = x.wrapping_mul(y) & mask;
148        let q = 1 << (i - 1);
149        if xy >= q {
150            y += q;
151        }
152    }
153    let xy = x.wrapping_mul(y);
154    let q = 1 << (WORD_BITS - 1);
155    if xy >= q {
156        y += q;
157    }
158    y
159}
160
161/// Computes R mod n, where R = `2^(WORD_BITS*k)` and k = `n.digits.len()`
162/// Note that if R = qn + r, q must be smaller than `2^WORD_BITS` since `2^(WORD_BITS) * n > R`
163/// (adding a whole additional word to n is too much).
164/// Uses the two most significant digits of n to approximate the quotient,
165/// then computes the difference to get the remainder. It is possible that this
166/// quotient is too big by 1; we can catch that case by looking for overflow
167/// in the subtraction.
168pub fn compute_r_mod_n(n: &MPNat, out: &mut [Word]) {
169    let k = n.digits.len();
170
171    if k == 1 {
172        let r = BASE;
173        let result = r % (n.digits[0] as DoubleWord);
174        out[0] = result as Word;
175        return;
176    }
177
178    debug_assert!(n.is_odd(), "This algorithm only works for odd numbers");
179    debug_assert!(
180        out.len() >= k,
181        "Output must be able to hold numbers of the same size as n"
182    );
183
184    let approx_n = join_as_double(n.digits[k - 1], n.digits[k - 2]);
185    let approx_q = DoubleWord::MAX / approx_n;
186    debug_assert!(
187        approx_q <= (Word::MAX as DoubleWord),
188        "quotient must fit in a single digit"
189    );
190    let mut approx_q = approx_q as Word;
191
192    loop {
193        let mut c = 0;
194        let mut b = false;
195        for (n_digit, out_digit) in n.digits.iter().zip(out.iter_mut()) {
196            let (prod, carry) = carrying_mul(approx_q, *n_digit, c);
197            c = carry;
198            let (diff, borrow) = borrowing_sub(0, prod, b);
199            b = borrow;
200            *out_digit = diff;
201        }
202        let (diff, borrow) = borrowing_sub(1, c, b);
203        if borrow {
204            // approx_q was too large so `R - approx_q*n` overflowed.
205            // try again with approx_q -= 1
206            approx_q -= 1;
207        } else {
208            debug_assert_eq!(
209                diff, 0,
210                "R - qn must be smaller than n, hence fit in k digits"
211            );
212            break;
213        }
214    }
215}
216
217/// Computes `base ^ exp`, ignoring overflow.
218pub fn big_wrapping_pow(base: &MPNat, exp: &[u8], scratch_space: &mut [Word]) -> MPNat {
219    // Compute result via the "binary method", see Knuth The Art of Computer Programming
220    let mut result = MPNat {
221        digits: vec![0; scratch_space.len()],
222    };
223    result.digits[0] = 1;
224    for &b in exp {
225        let mut mask: u8 = 1 << 7;
226        while mask > 0 {
227            big_wrapping_mul(&result, &result, scratch_space);
228            result.digits.copy_from_slice(scratch_space);
229            scratch_space.fill(0); // zero-out the scratch space
230            if b & mask != 0 {
231                big_wrapping_mul(&result, base, scratch_space);
232                result.digits.copy_from_slice(scratch_space);
233                scratch_space.fill(0); // zero-out the scratch space
234            }
235            mask >>= 1;
236        }
237    }
238    result
239}
240
241/// Computes `(x * y) mod 2^(WORD_BITS*out.len())`.
242pub fn big_wrapping_mul(x: &MPNat, y: &MPNat, out: &mut [Word]) {
243    let s = out.len();
244    for i in 0..s {
245        let mut c: Word = 0;
246        for j in 0..(s - i) {
247            let (prod, carry) = shifted_carrying_mul(
248                out[i + j],
249                x.digits.get(j).copied().unwrap_or(0),
250                y.digits.get(i).copied().unwrap_or(0),
251                c,
252            );
253            c = carry;
254            out[i + j] = prod;
255        }
256    }
257}
258
259/// Computes `x^2`, storing the result in `out`.
260pub fn big_sq(x: &MPNat, out: &mut [Word]) {
261    debug_assert!(
262        out.len() > 2 * x.digits.len(),
263        "Output needs double the digits to hold the value of x^2"
264    );
265    let s = x.digits.len();
266    for i in 0..s {
267        let (product, carry) = shifted_carrying_mul(out[i + i], x.digits[i], x.digits[i], 0);
268        out[i + i] = product;
269        let mut c = carry as DoubleWord;
270        for j in (i + 1)..s {
271            let mut new_c: DoubleWord = 0;
272            let res = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord);
273            let (res, overflow) = res.overflowing_add(res);
274            if overflow {
275                new_c += BASE;
276            }
277            let (res, overflow) = (out[i + j] as DoubleWord).overflowing_add(res);
278            if overflow {
279                new_c += BASE;
280            }
281            let (res, overflow) = res.overflowing_add(c);
282            if overflow {
283                new_c += BASE;
284            }
285            out[i + j] = res as Word;
286            c = new_c + ((res >> WORD_BITS) as DoubleWord);
287        }
288        let (sum, carry) = carrying_add(out[i + s], c as Word, false);
289        out[i + s] = sum;
290        out[i + s + 1] = ((c >> WORD_BITS) as Word) + (carry as Word);
291    }
292}
293
294// Performs `a <<= shift`, returning the overflow
295pub fn in_place_shl(a: &mut [Word], shift: u32) -> Word {
296    let mut c: Word = 0;
297    let carry_shift = (WORD_BITS as u32) - shift;
298    for a_digit in a.iter_mut() {
299        let carry = a_digit.overflowing_shr(carry_shift).0;
300        *a_digit = a_digit.overflowing_shl(shift).0 | c;
301        c = carry;
302    }
303    c
304}
305
306// Performs `a >>= shift`, returning the overflow
307pub fn in_place_shr(a: &mut [Word], shift: u32) -> Word {
308    let mut b: Word = 0;
309    let borrow_shift = (WORD_BITS as u32) - shift;
310    for a_digit in a.iter_mut().rev() {
311        let borrow = a_digit.overflowing_shl(borrow_shift).0;
312        *a_digit = a_digit.overflowing_shr(shift).0 | b;
313        b = borrow;
314    }
315    b
316}
317
318// Performs a += b, returning if there was overflow
319pub fn in_place_add(a: &mut [Word], b: &[Word]) -> bool {
320    debug_assert!(a.len() == b.len());
321
322    let mut c = false;
323    for (a_digit, b_digit) in a.iter_mut().zip(b) {
324        let (sum, carry) = carrying_add(*a_digit, *b_digit, c);
325        *a_digit = sum;
326        c = carry;
327    }
328
329    c
330}
331
332// Performs `a -= xy`, returning the "borrow".
333pub fn in_place_mul_sub(a: &mut [Word], x: &[Word], y: Word) -> Word {
334    debug_assert!(a.len() == x.len());
335
336    // a -= x*0 leaves a unchanged, so return early
337    if y == 0 {
338        return 0;
339    }
340
341    // carry is between -big_digit::MAX and 0, so to avoid overflow we store
342    // offset_carry = carry + big_digit::MAX
343    let mut offset_carry = Word::MAX;
344
345    for (a_digit, x_digit) in a.iter_mut().zip(x) {
346        // We want to calculate sum = x - y * c + carry.
347        // sum >= -(big_digit::MAX * big_digit::MAX) - big_digit::MAX
348        // sum <= big_digit::MAX
349        // Offsetting sum by (big_digit::MAX << big_digit::BITS) puts it in DoubleBigDigit range.
350        let offset_sum = join_as_double(Word::MAX, *a_digit) - Word::MAX as DoubleWord
351            + offset_carry as DoubleWord
352            - ((*x_digit as DoubleWord) * (y as DoubleWord));
353
354        let new_offset_carry = (offset_sum >> WORD_BITS) as Word;
355        let new_x = offset_sum as Word;
356        offset_carry = new_offset_carry;
357        *a_digit = new_x;
358    }
359
360    // Return the borrow.
361    Word::MAX - offset_carry
362}
363
364/// Computes `a + xy + c` where any overflow is captured as the "carry",
365/// the second part of the output. The arithmetic in this function is
366/// guaranteed to never overflow because even when all 4 variables are
367/// equal to `Word::MAX` the output is smaller than `DoubleWord::MAX`.
368pub const fn shifted_carrying_mul(a: Word, x: Word, y: Word, c: Word) -> (Word, Word) {
369    let wide = { (a as DoubleWord) + ((x as DoubleWord) * (y as DoubleWord)) + (c as DoubleWord) };
370    (wide as Word, (wide >> WORD_BITS) as Word)
371}
372
373/// Computes `xy + c` where any overflow is captured as the "carry",
374/// the second part of the output. The arithmetic in this function is
375/// guaranteed to never overflow because even when all 3 variables are
376/// equal to `Word::MAX` the output is smaller than `DoubleWord::MAX`.
377pub const fn carrying_mul(x: Word, y: Word, c: Word) -> (Word, Word) {
378    let wide = { ((x as DoubleWord) * (y as DoubleWord)) + (c as DoubleWord) };
379    (wide as Word, (wide >> WORD_BITS) as Word)
380}
381
382// Computes `x + y` with "carry the 1" semantics
383pub const fn carrying_add(x: Word, y: Word, carry: bool) -> (Word, bool) {
384    let (a, b) = x.overflowing_add(y);
385    let (c, d) = a.overflowing_add(carry as Word);
386    (c, b | d)
387}
388
389// Computes `x - y` with "borrow from your neighbour" semantics
390pub const fn borrowing_sub(x: Word, y: Word, borrow: bool) -> (Word, bool) {
391    let (a, b) = x.overflowing_sub(y);
392    let (c, d) = a.overflowing_sub(borrow as Word);
393    (c, b | d)
394}
395
396pub fn join_as_double(hi: Word, lo: Word) -> DoubleWord {
397    DoubleWord::from(lo) | (DoubleWord::from(hi) << WORD_BITS)
398}
399
#[test]
fn test_monsq() {
    // Squares `x` modulo `n` with `monsq` and cross-checks against `monpro(x, x)`.
    fn check_monsq(x: u128, n: u128) {
        let base = MPNat::from_big_endian(&x.to_be_bytes());
        let modulus = MPNat::from_big_endian(&n.to_be_bytes());
        // n' = -(n^-1) modulo the word size.
        let n_prime = mod_inv(modulus.digits[0]).wrapping_neg();

        let mut sq_out = vec![0; 2 * modulus.digits.len() + 1];
        monsq(&base, &modulus, n_prime, &mut sq_out);
        let squared = MPNat { digits: sq_out };

        let mut mul_out = vec![0; modulus.digits.len() + 2];
        monpro(&base, &base, &modulus, n_prime, &mut mul_out);
        let multiplied = MPNat { digits: mul_out };

        assert_eq!(
            num::BigUint::from_bytes_be(&squared.to_big_endian()),
            num::BigUint::from_bytes_be(&multiplied.to_big_endian()),
            "{x}^2 failed monsq check"
        );
    }

    let cases: [(u128, u128); 7] = [
        (1, 31),
        (6, 31),
        // This example is intentionally chosen because 5 * 5 = 25 = 0 mod 25,
        // therefore it requires the final subtraction step in the algorithm.
        (5, 25),
        (0x1FFF_FFFF_FFFF_FFF0, 0x1FFF_FFFF_FFFF_FFF1),
        (0x16FF_221F_CB7D, 0x011E_842B_6BAA_5017_EBF2_8293),
        (0x0A2D_63F5_CFF9, 0x1F3B_3BD9_43EF),
        (
            0xa6b0ce71a380dea7c83435bc,
            0xc4550871a1cfc67af3e77eceb2ecfce5,
        ),
    ];
    for &(x, n) in &cases {
        check_monsq(x, n);
    }
}
435
#[test]
fn test_monpro() {
    use num::Integer;

    // Checks `monpro(x, y, n)` against `x * y * r^(-1) mod n` computed with big integers.
    fn check_monpro(x: u128, y: u128, n: u128) {
        let a = MPNat::from_big_endian(&x.to_be_bytes());
        let b = MPNat::from_big_endian(&y.to_be_bytes());
        let m = MPNat::from_big_endian(&n.to_be_bytes());
        let n_prime = Word::MAX - mod_inv(m.digits[0]) + 1;

        let mut output = vec![0; m.digits.len() + 2];
        monpro(&a, &b, &m, n_prime, &mut output);
        let result = MPNat { digits: output };
        let actual = num::BigInt::from(mp_nat_to_u128(&result));

        // The reference value is computed entirely in arbitrary precision: the
        // intermediate product can exceed 128 bits, and the Bezout coefficient
        // returned by `extended_gcd` may be negative, so it is folded into [0, n).
        let modulus = num::BigInt::from(n);
        let r = num::BigInt::from(2).pow((WORD_BITS * m.digits.len()) as u32);
        let r_inv = r.extended_gcd(&modulus).x.mod_floor(&modulus);
        let expected =
            (num::BigInt::from(x) * num::BigInt::from(y) * r_inv).mod_floor(&modulus);

        assert_eq!(actual, expected, "{x}*{y} failed monpro check");
    }

    check_monpro(1, 1, 31);
    check_monpro(6, 7, 31);
    // This example is intentionally chosen because 5 * 7 = 35 = 0 mod 35,
    // therefore it requires the final subtraction step in the algorithm.
    check_monpro(5, 7, 35);
    check_monpro(0x1FFF_FFFF_FFFF_FFF0, 0x1234, 0x1FFF_FFFF_FFFF_FFF1);
    check_monpro(
        0x16FF_221F_CB7D,
        0x0C75_8535_434F,
        0x011E_842B_6BAA_5017_EBF2_8293,
    );
    check_monpro(0x0A2D_63F5_CFF9, 0x1B21_FF3C_FA8E, 0x1F3B_3BD9_43EF);
}
472
#[test]
fn test_r_mod_n() {
    // Compares `compute_r_mod_n` with `2^(WORD_BITS * digits) % n` computed by `num`.
    fn check_r_mod_n(n: u128) {
        let modulus = MPNat::from_big_endian(&n.to_be_bytes());
        let digit_count = modulus.digits.len();

        let mut remainder = vec![0; digit_count];
        compute_r_mod_n(&modulus, &mut remainder);
        let actual = mp_nat_to_u128(&MPNat { digits: remainder });

        let r = num::BigUint::from(2_u32).pow((WORD_BITS * digit_count) as u32);
        let expected = r % num::BigUint::from(n);
        assert_eq!(num::BigUint::from(actual), expected);
    }

    let moduli: [u128; 7] = [
        0x01_00_00_00_01,
        0x80_00_00_00_01,
        0xFFFF_FFFF_FFFF_FFFF,
        0x0001_0000_0000_0000_0001,
        0x8000_0000_0000_0000_0001,
        0xbf2d_c9a3_82c5_6e85_b033_7651,
        0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
    ];
    for &n in &moduli {
        check_r_mod_n(n);
    }
}
493
#[test]
fn test_in_place_shl() {
    // Shifts `n` left inside an `MPNat` and compares with the native `u128` shift,
    // truncated to the number of digits the `MPNat` holds.
    fn check_in_place_shl(n: u128, shift: u32) {
        let mut value = MPNat::from_big_endian(&n.to_be_bytes());
        in_place_shl(&mut value.digits, shift);
        let shifted = mp_nat_to_u128(&value);

        // BASE^digits - 1; when the power wraps to zero the mask becomes all ones.
        let (digit_modulus, _) = BASE.overflowing_pow(value.digits.len() as u32);
        let mask = digit_modulus.wrapping_sub(1);
        let (reference, _) = n.overflowing_shl(shift);
        assert_eq!(shifted, reference & mask);
    }

    let cases: [(u128, u32); 4] = [
        (0, 0),
        (1, 10),
        (u128::from(Word::MAX), 5),
        (u128::MAX, 16),
    ];
    for &(n, shift) in &cases {
        check_in_place_shl(n, shift);
    }
}
512
#[test]
fn test_in_place_shr() {
    // Shifts `n` right inside an `MPNat` and compares with the native `u128` shift.
    fn check_in_place_shr(n: u128, shift: u32) {
        let mut value = MPNat::from_big_endian(&n.to_be_bytes());
        in_place_shr(&mut value.digits, shift);
        let shifted = mp_nat_to_u128(&value);

        let (reference, _) = n.overflowing_shr(shift);
        assert_eq!(shifted, reference);
    }

    let cases: [(u128, u32); 5] = [
        (0, 0),
        (1, 10),
        (0x1234_5678, 10),
        (u128::from(Word::MAX), 5),
        (u128::MAX, 16),
    ];
    for &(n, shift) in &cases {
        check_in_place_shr(n, shift);
    }
}
528
#[test]
fn test_mod_inv() {
    // An inverse is correct exactly when the wrapping product with the input is 1.
    fn check_mod_inv(n: Word) {
        let n_inv = mod_inv(n);
        assert_eq!(n.wrapping_mul(n_inv), 1, "{n} failed mod_inv check");
    }

    // Small odd numbers.
    for i in 1..1025 {
        check_mod_inv(2 * i - 1);
    }
    // Odd numbers just below 2^32.
    for i in 0..1025 {
        check_mod_inv(0xFF_FF_FF_FF - 2 * i);
    }
    // Odd numbers at the very top of the word range, so the most significant
    // bits are exercised whatever the width of `Word` is.
    for i in 0..1025 {
        check_mod_inv(Word::MAX - 2 * i);
    }
}
543
#[test]
fn test_big_wrapping_pow() {
    // Compares `big_wrapping_pow` with `num`'s exponentiation, using a scratch space
    // large enough that no wrapping occurs.
    fn check_big_wrapping_pow(a: u128, b: u32) {
        let expected = num::BigUint::from(a).pow(b);

        let base = MPNat::from_big_endian(&a.to_be_bytes());
        let exponent = b.to_be_bytes();
        let needed_words = 1 + (expected.to_bytes_be().len() / crate::mpnat::WORD_BYTES);
        let mut scratch = vec![0; needed_words];

        let power = big_wrapping_pow(&base, &exponent, &mut scratch);
        let result = num::BigUint::from_bytes_be(&power.to_big_endian());
        assert_eq!(result, expected, "{a} ^ {b} != {expected}");
    }

    let cases: [(u128, u32); 5] = [(1, 1), (10, 2), (2, 32), (2, 64), (2766, 844)];
    for &(a, b) in &cases {
        check_big_wrapping_pow(a, b);
    }
}
565
#[test]
fn test_big_wrapping_mul() {
    // Compares `big_wrapping_mul` with `num`'s product reduced modulo
    // `2^(output_digits * WORD_BITS)`.
    fn check_big_wrapping_mul(a: u128, b: u128, output_digits: usize) {
        let wrap_modulus =
            num::BigUint::from(2_u32).pow(u32::try_from(output_digits * WORD_BITS).unwrap());
        let expected = (num::BigUint::from(a) * num::BigUint::from(b)) % wrap_modulus;

        let lhs = MPNat::from_big_endian(&a.to_be_bytes());
        let rhs = MPNat::from_big_endian(&b.to_be_bytes());
        let mut product = vec![0; output_digits];
        big_wrapping_mul(&lhs, &rhs, &mut product);

        let result = num::BigUint::from_bytes_be(&MPNat { digits: product }.to_big_endian());
        assert_eq!(result, expected, "{a}*{b} != {expected}");
    }

    let cases: [(u128, u128, usize); 8] = [
        (0, 0, 1),
        (1, 1, 1),
        (7, 6, 1),
        (Word::MAX.into(), Word::MAX.into(), 2),
        (Word::MAX.into(), Word::MAX.into(), 1),
        (DoubleWord::MAX - 5, DoubleWord::MAX - 6, 2),
        (0xa945_aa5e_429a_6d1a, 0x4072_d45d_3355_237b, 3),
        (
            0x8ae1_5515_fc92_b1c0_b473_8ce8_6bbf_7218,
            0x43e9_8b77_1f7c_aa93_6c4c_85e9_7fd0_504f,
            3,
        ),
    ];
    for &(a, b, output_digits) in &cases {
        check_big_wrapping_mul(a, b, output_digits);
    }
}
595
#[test]
fn test_big_sq() {
    // Compares `big_sq` with `num`'s squaring.
    fn check_big_sq(a: u128) {
        let expected = num::BigUint::from(a).pow(2_u32);

        let input = MPNat::from_big_endian(&a.to_be_bytes());
        let mut square = vec![0; 2 * input.digits.len() + 1];
        big_sq(&input, &mut square);

        let result = num::BigUint::from_bytes_be(&MPNat { digits: square }.to_big_endian());
        assert_eq!(result, expected, "{a}^2 != {expected}");
    }

    let inputs: [u128; 10] = [
        0,
        1,
        Word::MAX.into(),
        2 * (Word::MAX as u128),
        0x8e67904953db9a2bf6da64bf8bda866d,
        0x9f8dc1c3fc0bf50fe75ac3bbc03124c9,
        0x9c9a17378f3d064e5eaa80eeb3850cd7,
        0xc7f03fbb1c186c05e54b3ee19106baa4,
        0xcf2025cee03025d247ad190e9366d926,
        u128::MAX,
    ];
    for &a in &inputs {
        check_big_sq(a);
    }

    /* Test for addition overflows in the big_sq inner loop */
    {
        let input = MPNat::from_big_endian(&[
            0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x40, 0x00,
            0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00,
        ]);
        let mut square = vec![0; 2 * input.digits.len() + 1];
        big_sq(&input, &mut square);
        let actual_bytes = MPNat { digits: square }.to_big_endian();
        let expected_bytes = vec![
            0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x01, 0xff, 0xff, 0xff, 0xfe, 0x40, 0x00, 0x00, 0x01, 0x90, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0xbf, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        assert_eq!(actual_bytes, expected_bytes);
    }
}
639
#[test]
fn test_borrowing_sub() {
    // Each entry is ((x, y, borrow_in), (difference, borrow_out)).
    let cases: [((Word, Word, bool), (Word, bool)); 6] = [
        ((0, 0, false), (0, false)),
        ((1, 0, false), (1, false)),
        ((47, 5, false), (42, false)),
        ((101, 7, true), (93, false)),
        ((0x00_00_01_00, 0x00_00_02_00, false), (Word::MAX - 0xFF, true)),
        ((0x00_00_01_00, 0x00_00_10_00, true), (Word::MAX - 0x0F_00, true)),
    ];
    for &((x, y, borrow), expected) in &cases {
        assert_eq!(borrowing_sub(x, y, borrow), expected);
    }
}
655
// The literals below are written as intended; the lint misreads their digit grouping.
#[allow(clippy::mistyped_literal_suffixes)]
#[test]
fn test_shifted_carrying_mul() {
    // Each entry is ((a, x, y, c), (low word, carry word)) for `a + x*y + c`.
    let cases: [((Word, Word, Word, Word), (Word, Word)); 6] = [
        ((0, 0, 0, 0), (0, 0)),
        ((0, 6, 7, 0), (42, 0)),
        ((0, 6, 7, 8), (50, 0)),
        ((5, 6, 7, 8), (55, 0)),
        (
            (
                Word::MAX - 0x11,
                Word::MAX - 0x1234,
                Word::MAX - 0xABCD,
                Word::MAX - 0xFF,
            ),
            (0x0C_38_0C_94, Word::MAX - 0xBE00),
        ),
        (
            (Word::MAX, Word::MAX, Word::MAX, Word::MAX),
            (Word::MAX, Word::MAX),
        ),
    ];
    for &((a, x, y, c), expected) in &cases {
        assert_eq!(shifted_carrying_mul(a, x, y, c), expected);
    }
}
678
/// Test helper: converts an `MPNat` to a `u128` via its big-endian byte encoding.
///
/// Panics if that encoding is longer than 16 bytes.
#[cfg(test)]
pub fn mp_nat_to_u128(x: &MPNat) -> u128 {
    let bytes = x.to_big_endian();
    // Right-align the bytes in a 16-byte buffer; the leading bytes stay zero.
    let mut padded = [0u8; 16];
    let offset = 16 - bytes.len();
    padded[offset..].copy_from_slice(&bytes);
    u128::from_be_bytes(padded)
}