// aurora_engine_modexp/mpnat.rs

1use crate::{
2    arith::{
3        big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, compute_r_mod_n,
4        in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, join_as_double, mod_inv,
5        monpro, monsq,
6    },
7    maybe_std::{vec, Vec},
8};
9
/// A single digit of a multi-precision number.
pub type Word = u64;
/// Integer type twice as wide as [`Word`], used for intermediate results.
pub type DoubleWord = u128;
/// Number of bytes in one [`Word`].
pub const WORD_BYTES: usize = size_of::<Word>();
/// Number of bits in one [`Word`].
pub const WORD_BITS: usize = 8 * WORD_BYTES;
/// Radix of the digit representation: `Word::MAX + 1 = 2^WORD_BITS`.
pub const BASE: DoubleWord = 1 << WORD_BITS;
15
/// Multi-precision natural number, represented in base `Word::MAX + 1 = 2^WORD_BITS`.
///
/// The digits are stored in little-endian order, i.e. `digits[0]` is the least
/// significant digit. The methods in this file index `digits[0]` and unwrap
/// `digits.last()`, so `digits` is expected to be non-empty.
#[derive(Debug)]
pub struct MPNat {
    /// Base-`2^WORD_BITS` digits, least significant first.
    pub digits: Vec<Word>,
}
23
24impl MPNat {
25    fn strip_leading_zeroes(a: &[u8]) -> (&[u8], bool) {
26        let len = a.len();
27        let end = a.iter().position(|&x| x != 0).unwrap_or(len);
28
29        if end == len {
30            (&[], true)
31        } else {
32            (&a[end..], false)
33        }
34    }
35
36    // Koç's algorithm for inversion mod 2^k
37    // https://eprint.iacr.org/2017/411.pdf
38    fn koc_2017_inverse(aa: &Self, k: usize) -> Self {
39        debug_assert!(aa.is_odd());
40
41        let length = k / WORD_BITS;
42        let mut b = Self {
43            digits: vec![0; length + 1],
44        };
45        b.digits[0] = 1;
46
47        let mut a = Self {
48            digits: aa.digits.clone(),
49        };
50        a.digits.resize(length + 1, 0);
51
52        let mut neg: bool = false;
53
54        let mut res = Self {
55            digits: vec![0; length + 1],
56        };
57
58        let (mut wordpos, mut bitpos) = (0, 0);
59
60        for _ in 0..k {
61            let x = b.digits[0] & 1;
62            if x != 0 {
63                if neg {
64                    // b = b - a
65                    in_place_add(&mut b.digits, &a.digits);
66                } else {
67                    // b = a - b
68                    let mut tmp = Self {
69                        digits: a.digits.clone(),
70                    };
71                    in_place_mul_sub(&mut tmp.digits, &b.digits, 1);
72                    b = tmp;
73                    neg = true;
74                }
75            }
76
77            in_place_shr(&mut b.digits, 1);
78
79            res.digits[wordpos] |= x << bitpos;
80
81            bitpos += 1;
82            if bitpos == WORD_BITS {
83                bitpos = 0;
84                wordpos += 1;
85            }
86        }
87
88        res
89    }
90
91    pub fn from_big_endian(bytes: &[u8]) -> Self {
92        if bytes.is_empty() {
93            return Self { digits: vec![0] };
94        }
95        // Remainder on division by WORD_BYTES
96        let r = bytes.len() & (WORD_BYTES - 1);
97        let n_digits = if r == 0 {
98            bytes.len() / WORD_BYTES
99        } else {
100            // Need an extra digit for the remainder
101            (bytes.len() / WORD_BYTES) + 1
102        };
103        let mut digits = vec![0; n_digits];
104        // buffer to hold Word-sized slices of the input bytes
105        let mut buf = [0u8; WORD_BYTES];
106        let mut i = n_digits - 1;
107        if r != 0 {
108            buf[(WORD_BYTES - r)..].copy_from_slice(&bytes[0..r]);
109            digits[i] = Word::from_be_bytes(buf);
110            if i == 0 {
111                // Special case where there is just one digit
112                return Self { digits };
113            }
114            i -= 1;
115        }
116        let mut j = r;
117        loop {
118            let next_j = j + WORD_BYTES;
119            buf.copy_from_slice(&bytes[j..next_j]);
120            digits[i] = Word::from_be_bytes(buf);
121            if i == 0 {
122                break;
123            }
124
125            i -= 1;
126            j = next_j;
127        }
128        // throw away leading zeros
129        while digits.len() > 1 && digits[digits.len() - 1] == 0 {
130            digits.pop();
131        }
132        Self { digits }
133    }
134
135    pub fn is_power_of_two(&self) -> bool {
136        // A multi-precision number is a power of 2 iff exactly one digit
137        // is a power of 2 and all others are zero.
138        let mut found_power_of_two = false;
139        for &d in &self.digits {
140            let is_p2 = d.is_power_of_two();
141            if (!is_p2 && d != 0) || (is_p2 && found_power_of_two) {
142                return false;
143            } else if is_p2 {
144                found_power_of_two = true;
145            }
146        }
147        found_power_of_two
148    }
149
150    pub fn is_odd(&self) -> bool {
151        // A binary number is odd iff its lowest order bit is set.
152        self.digits[0] & 1 == 1
153    }
154
155    /// Computes `self ^ exp mod modulus`. `exp` must be given as big-endian bytes.
156    #[allow(clippy::too_many_lines, clippy::debug_assert_with_mut_call)]
157    pub fn modpow(&mut self, exp: &[u8], modulus: &Self) -> Self {
158        // exp must be stripped because it is iterated over in
159        // `big_wrapping_pow` and `modpow_montgomery`, and a large
160        // zero-padded exp leads to performance issues.
161        let (exp, exp_is_zero) = Self::strip_leading_zeroes(exp);
162
163        // base^0 is always 1, regardless of base.
164        // Hence, the result is 0 for (base^0) % 1, and 1
165        // for every modulus larger than 1.
166        //
167        // The case of modulus being 0 should have already been
168        // handled in modexp().
169        debug_assert!(!(modulus.digits.len() == 1 && modulus.digits[0] == 0));
170        if exp_is_zero {
171            if modulus.digits.len() == 1 && modulus.digits[0] == 1 {
172                return Self { digits: vec![0] };
173            }
174
175            return Self { digits: vec![1] };
176        }
177
178        if exp.len() <= size_of::<usize>() {
179            let exp_as_number = {
180                let mut tmp: usize = 0;
181                for d in exp {
182                    tmp *= 256;
183                    tmp += (*d) as usize;
184                }
185                tmp
186            };
187
188            if let Some(max_output_digits) = self.digits.len().checked_mul(exp_as_number) {
189                if modulus.digits.len() > max_output_digits {
190                    // Special case: modulus is larger than `base ^ exp`, so division is not relevant
191                    let mut scratch_space = vec![0; max_output_digits];
192                    return big_wrapping_pow(self, exp, &mut scratch_space);
193                }
194            }
195        }
196
197        if modulus.is_power_of_two() {
198            return self.modpow_with_power_of_two(exp, modulus);
199        } else if modulus.is_odd() {
200            return self.modpow_montgomery(exp, modulus);
201        }
202
203        // If the modulus is not a power of two and not an odd number then
204        // it is a product of some power of two with an odd number. In this
205        // case we will use the Chinese remainder theorem to get the result.
206        // See http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
207
208        let trailing_zeros = modulus.digits.iter().take_while(|x| x == &&0).count();
209        let additional_zero_bits = modulus.digits[trailing_zeros].trailing_zeros() as usize;
210        let power_of_two = {
211            let mut tmp = Self {
212                digits: vec![0; trailing_zeros + 1],
213            };
214            tmp.digits[trailing_zeros] = 1 << additional_zero_bits;
215            tmp
216        };
217        let power_of_two_mask = *power_of_two.digits.last().unwrap() - 1;
218        let odd = {
219            let num_digits = modulus.digits.len() - trailing_zeros;
220            let mut tmp = Self {
221                digits: vec![0; num_digits],
222            };
223            if additional_zero_bits > 0 {
224                tmp.digits[0] = modulus.digits[trailing_zeros] >> additional_zero_bits;
225                for i in 1..num_digits {
226                    let d = modulus.digits[trailing_zeros + i];
227                    tmp.digits[i - 1] +=
228                        (d & power_of_two_mask) << (WORD_BITS - additional_zero_bits);
229                    tmp.digits[i] = d >> additional_zero_bits;
230                }
231            } else {
232                tmp.digits
233                    .copy_from_slice(&modulus.digits[trailing_zeros..]);
234            }
235            while tmp.digits.last() == Some(&0) {
236                tmp.digits.pop();
237            }
238            tmp
239        };
240        debug_assert!(power_of_two.is_power_of_two(), "Factored out power of two");
241        debug_assert!(
242            odd.is_odd(),
243            "Remaining number is odd after factoring out powers of two"
244        );
245        debug_assert!(
246            {
247                let mut tmp = vec![0; modulus.digits.len()];
248                big_wrapping_mul(&power_of_two, &odd, &mut tmp);
249                tmp == modulus.digits
250            },
251            "modulus is factored"
252        );
253
254        let mut base_copy = Self {
255            digits: self.digits.clone(),
256        };
257        let x1 = base_copy.modpow_montgomery(exp, &odd);
258        let x2 = self.modpow_with_power_of_two(exp, &power_of_two);
259
260        let odd_inv =
261            Self::koc_2017_inverse(&odd, trailing_zeros * WORD_BITS + additional_zero_bits);
262
263        let s = power_of_two.digits.len();
264        let mut scratch = vec![0; s];
265        let diff = {
266            scratch.fill(0);
267            let mut b = false;
268            for (i, scratch_digit) in scratch.iter_mut().enumerate().take(s) {
269                let (diff, borrow) = borrowing_sub(
270                    x2.digits.get(i).copied().unwrap_or(0),
271                    x1.digits.get(i).copied().unwrap_or(0),
272                    b,
273                );
274                *scratch_digit = diff;
275                b = borrow;
276            }
277            Self { digits: scratch }
278        };
279        let y = {
280            let mut out = vec![0; s];
281            big_wrapping_mul(&diff, &odd_inv, &mut out);
282            *out.last_mut().unwrap() &= power_of_two_mask;
283            Self { digits: out }
284        };
285
286        // Re-use allocation for efficiency
287        let mut digits = diff.digits;
288        let s = modulus.digits.len();
289        digits.fill(0);
290        digits.resize(s, 0);
291        big_wrapping_mul(&odd, &y, &mut digits);
292        let mut c = false;
293        for (i, out_digit) in digits.iter_mut().enumerate() {
294            let (sum, carry) = carrying_add(x1.digits.get(i).copied().unwrap_or(0), *out_digit, c);
295            c = carry;
296            *out_digit = sum;
297        }
298        Self { digits }
299    }
300
301    // Computes `self ^ exp mod modulus` using Montgomery multiplication.
302    // See https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/j37acmon.pdf
303    fn modpow_montgomery(&mut self, exp: &[u8], modulus: &Self) -> Self {
304        // The montgomery method only works with odd modulus.
305        debug_assert!(modulus.is_odd());
306
307        // n_prime satisfies `r * (r^(-1)) - modulus * n' = 1`, where
308        // `r = 2^(WORD_BITS*modulus.digits.len())`.
309        let n_prime = Word::MAX - mod_inv(modulus.digits[0]) + 1;
310        let s = modulus.digits.len();
311
312        let mut x_bar = Self { digits: vec![0; s] };
313        // Initialize result as `r mod modulus` (Montgomery form of 1)
314        compute_r_mod_n(modulus, &mut x_bar.digits);
315
316        // Reduce base mod modulus
317        self.sub_to_same_size(modulus);
318
319        // Need to compute a_bar = base * r mod modulus;
320        // First directly multiply base * r to get a 2s-digit number,
321        // then reduce mod modulus.
322        let a_bar = {
323            let mut tmp = Self {
324                digits: vec![0; 2 * s],
325            };
326            big_wrapping_mul(self, &x_bar, &mut tmp.digits);
327            tmp.sub_to_same_size(modulus);
328            tmp
329        };
330
331        // scratch space for monpro algorithm
332        let mut scratch = vec![0; 2 * s + 1];
333        let monpro_len = s + 2;
334
335        // Use binary method for computing exp, but with monpro as the multiplication
336        for &b in exp {
337            let mut mask: u8 = 1 << 7;
338            while mask > 0 {
339                monsq(&x_bar, modulus, n_prime, &mut scratch);
340                x_bar.digits.copy_from_slice(&scratch[0..s]);
341                scratch.fill(0);
342                if b & mask != 0 {
343                    monpro(
344                        &x_bar,
345                        &a_bar,
346                        modulus,
347                        n_prime,
348                        &mut scratch[0..monpro_len],
349                    );
350                    x_bar.digits.copy_from_slice(&scratch[0..s]);
351                    scratch.fill(0);
352                }
353                mask >>= 1;
354            }
355        }
356
357        // Convert out of Montgomery form by computing monpro with 1
358        let one = {
359            // We'll reuse the memory space from a_bar for efficiency.
360            let mut digits = a_bar.digits;
361            digits.fill(0);
362            digits[0] = 1;
363            Self { digits }
364        };
365        monpro(&x_bar, &one, modulus, n_prime, &mut scratch[0..monpro_len]);
366        scratch.resize(s, 0);
367        Self { digits: scratch }
368    }
369
370    fn modpow_with_power_of_two(&mut self, exp: &[u8], modulus: &Self) -> Self {
371        debug_assert!(modulus.is_power_of_two());
372        // We know `modulus` is a power of 2. So reducing is as easy as bit shifting.
373        // We also know the modulus is non-zero because 0 is not a power of 2.
374
375        // First reduce self to be the same size as the modulus
376        self.force_same_size(modulus);
377        // The modulus is a power of 2 but that power may not be a multiple of a whole word.
378        // We can clear out any higher order bits to fix this.
379        let modulus_mask = *modulus.digits.last().unwrap() - 1;
380        *self.digits.last_mut().unwrap() &= modulus_mask;
381
382        // We know that `totient(2^k) = 2^(k-1)`, therefore by Euler's theorem
383        // we can also reduce the exponent mod `2^(k-1)`. Effectively this means
384        // throwing away bytes to make `exp` shorter. Note: Euler's theorem only applies
385        // if the base and modulus are coprime (which in this case means the base is odd).
386        let exp = if self.is_odd() && (exp.len() > WORD_BYTES * modulus.digits.len()) {
387            &exp[(exp.len() - WORD_BYTES * modulus.digits.len())..]
388        } else {
389            exp
390        };
391
392        let mut scratch_space = vec![0; modulus.digits.len()];
393        let mut result = big_wrapping_pow(self, exp, &mut scratch_space);
394
395        // The modulus is a power of 2 but that power may not be a multiple of a whole word.
396        // We can clear out any higher order bits to fix this.
397        *result.digits.last_mut().unwrap() &= modulus_mask;
398
399        result
400    }
401
402    /// Makes `self` have the same number of digits as `other` by
403    /// pushing 0s or dropping higher order digits as needed.
404    /// This is equivalent to reducing `self` modulo `2^(WORD_BITS*k)` where
405    /// `k` is the number of digits in `other`.
406    fn force_same_size(&mut self, other: &Self) {
407        self.digits.resize(other.digits.len(), 0);
408
409        // This is here to just drive the point home about what the
410        // invariant is after calling this function.
411        debug_assert_eq!(self.digits.len(), other.digits.len());
412    }
413
414    /// Assumes `self` has more digits than `other`.
415    /// Makes `self` have the same number of digits as `other` by subtracting off multiples
416    /// of `other`. This is a partial reduction of `self` modulo `other`, but rather
417    /// than doing the full division, the goal is simply to make the two numbers have the
418    /// same number of digits.
419    fn sub_to_same_size(&mut self, other: &Self) {
420        // Remove leading zeros before starting
421        while self.digits.len() > 1 && self.digits.last() == Some(&0) {
422            self.digits.pop();
423        }
424
425        let n = other.digits.len();
426        let m = self.digits.len().saturating_sub(n);
427        if m == 0 {
428            return;
429        }
430
431        let other_most_sig = *other.digits.last().unwrap() as DoubleWord;
432
433        if self.digits.len() == 2 {
434            // This is the smallest case since `n >= 1` and `m > 0`
435            // implies that `self.digits.len() >= 2`.
436            // In this case we can use DoubleWord-sized arithmetic
437            // to get the answer directly.
438            let self_most_sig = self.digits.pop().unwrap();
439            let a = join_as_double(self_most_sig, self.digits[0]);
440            let b = other_most_sig;
441            self.digits[0] = (a % b) as Word;
442            return;
443        }
444
445        if n == 1 {
446            // The divisor is only 1 digit, so the long-division
447            // algorithm is easy.
448            let k = self.digits.len() - 1;
449            for j in (0..k).rev() {
450                let self_most_sig = self.digits.pop().unwrap();
451                let self_second_sig = self.digits[j];
452                let r = join_as_double(self_most_sig, self_second_sig) % other_most_sig;
453                self.digits[j] = r as Word;
454            }
455            return;
456        }
457
458        // At this stage we know that `n >= 2` and `self.digits.len() >= 3`.
459        // The smaller cases are covered in the if-statements above.
460
461        // The algorithm below only works well when the divisor's
462        // most significant digit is at least `BASE / 2`.
463        // If it is too small then we "normalize" by multiplying
464        // both numerator and denominator by a common factor
465        // and run the algorithm on those numbers.
466        // See Knuth The Art of Computer Programming vol. 2 section 4.3 for details.
467        let shift = (other_most_sig as Word).leading_zeros();
468        if shift > 0 {
469            // Normalize self
470            let overflow = in_place_shl(&mut self.digits, shift);
471            self.digits.push(overflow);
472
473            // Normalize other
474            let mut normalized = other.digits.clone();
475            let overflow = in_place_shl(&mut normalized, shift);
476            debug_assert_eq!(overflow, 0, "Normalizing modulus cannot overflow");
477            debug_assert_eq!(
478                normalized[n - 1].leading_zeros(),
479                0,
480                "Most significant bit is set"
481            );
482
483            // Run algorithm on normalized values
484            self.sub_to_same_size(&Self { digits: normalized });
485
486            // need to de-normalize to get the correct result
487            in_place_shr(&mut self.digits, shift);
488
489            return;
490        }
491
492        let other_second_sig = other.digits[n - 2] as DoubleWord;
493        let mut self_most_sig: Word = 0;
494        for j in (0..=m).rev() {
495            let self_second_sig = *self.digits.last().unwrap();
496            let self_third_sig = self.digits[self.digits.len() - 2];
497
498            let a = join_as_double(self_most_sig, self_second_sig);
499            let mut q_hat = a / other_most_sig;
500            let mut r_hat = a % other_most_sig;
501
502            loop {
503                let a = q_hat * other_second_sig;
504                let b = join_as_double(r_hat as Word, self_third_sig);
505                if q_hat >= BASE || a > b {
506                    q_hat -= 1;
507                    r_hat += other_most_sig;
508                    if BASE <= r_hat {
509                        break;
510                    }
511                } else {
512                    break;
513                }
514            }
515
516            let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat as Word);
517            if borrow > self_most_sig {
518                // q_hat was too large, add back one multiple of the modulus
519                let carry = in_place_add(&mut self.digits[j..], &other.digits);
520                debug_assert!(
521                    carry,
522                    "Adding back should cause overflow to cancel the borrow"
523                );
524                borrow -= 1;
525            }
526            // Most significant digit of self has been cancelled out
527            debug_assert_eq!(borrow, self_most_sig);
528            self_most_sig = self.digits.pop().unwrap();
529        }
530
531        self.digits.push(self_most_sig);
532        debug_assert!(self.digits.len() <= n);
533    }
534
535    pub fn to_big_endian(&self) -> Vec<u8> {
536        if self.digits.iter().all(|x| x == &0) {
537            return vec![0];
538        }
539
540        // Safety: unwrap is safe since `self.digits` is always non-empty.
541        let most_sig_bytes: [u8; WORD_BYTES] = self.digits.last().unwrap().to_be_bytes();
542        // The most significant digit may not need 4 bytes.
543        // Only include as many bytes as needed in the output.
544        let be_initial_bytes = {
545            let mut tmp: &[u8] = &most_sig_bytes;
546            while !tmp.is_empty() && tmp[0] == 0 {
547                tmp = &tmp[1..];
548            }
549            tmp
550        };
551
552        let mut result = vec![0u8; (self.digits.len() - 1) * WORD_BYTES + be_initial_bytes.len()];
553        result[0..be_initial_bytes.len()].copy_from_slice(be_initial_bytes);
554        for (i, d) in self.digits.iter().take(self.digits.len() - 1).enumerate() {
555            let bytes = d.to_be_bytes();
556            let j = result.len() - WORD_BYTES * i;
557            result[(j - WORD_BYTES)..j].copy_from_slice(&bytes);
558        }
559        result
560    }
561}
562
/// Exercises `MPNat::modpow` (and `crate::modexp`) with even moduli, plus the
/// empty and zero exponent edge cases.
#[test]
fn test_modpow_even() {
    // Asserts that `base ^ exp mod modulus == expected` for u128-sized inputs.
    fn check_modpow_even(base: u128, exp: u128, modulus: u128, expected: u128) {
        let mut base_nat = MPNat::from_big_endian(&base.to_be_bytes());
        let modulus_nat = MPNat::from_big_endian(&modulus.to_be_bytes());
        let power = base_nat.modpow(&exp.to_be_bytes(), &modulus_nat);
        assert_eq!(crate::arith::mp_nat_to_u128(&power), expected);
    }

    // Runs `crate::modexp` on hex-encoded inputs and returns the hex-encoded output.
    fn modexp_hex(base: &str, exponent: &str, modulus: &str) -> String {
        let base = hex::decode(base).unwrap();
        let exponent = hex::decode(exponent).unwrap();
        let modulus = hex::decode(modulus).unwrap();
        hex::encode(crate::modexp(&base, &exponent, &modulus))
    }

    // (base, exp, modulus, expected)
    let cases: &[(u128, u128, u128, u128)] = &[
        (3, 5, 500, 243),
        (3, 5, 20, 3),
        (
            0x2ff4f4df4c518867207c84b57a77aa50,
            0xca83c2925d17c577c9a03598b6f360,
            0xf863d4f17a5405d84814f54c92f803c8,
            0x8d216c9a1fb275ed18eb340ed43cacc0,
        ),
        (
            0x13881e1614244c56d15ac01096b070e7,
            0x336df5b4567cbe4c093271dc151e6c72,
            0x7540f399a0b6c220f1fc60d2451a1ff0,
            0x1251d64c552e8f831f5b841d2811f9c1,
        ),
        (
            0x774d5b2494a449d8f22b22ea542d4ddf,
            0xd2f602e1688f271853e7794503c2837e,
            0xa80d20ebf75f92192159197b60f36e8e,
            0x3fbbba42489b27fc271fb39f54aae2e1,
        ),
        (
            0x756e409cc3583a6b68ae27ccd9eb3d50,
            0x16dafb38a334288954d038bedbddc970,
            0x1f9b2237f09413d1fc44edf9bd02b8bc,
            0x9347445ac61536a402723cd07a3f5a4,
        ),
        (
            0x6dcb8405e2cc4dcebee3e2b14861b47d,
            0xe6c1e5251d6d5deb8dddd0198481d671,
            0xe34a31d814536e8b9ff6cc5300000000,
            0xaa86af638386880334694967564d0c3d,
        ),
        (
            0x9c12fe4a1a97d17c1e4573247a43b0e5,
            0x466f3e0a2e8846b8c48ecbf612b96412,
            0x710d7b9d5718acff0000000000000000,
            0x569bf65929e71cd10a553a8623bdfc99,
        ),
        (
            0x6d018fdeaa408222cb10ff2c36124dcf,
            0x8e35fc05d490bb138f73c2bc284a67a7,
            0x6c237160750d78400000000000000000,
            0x3fe14e11392c6c6be8efe956c965d5af,
        ),
    ];
    for &(base, exp, modulus, expected) in cases {
        check_modpow_even(base, exp, modulus, expected);
    }

    // Inputs wider than 128 bits, given as raw bytes.
    let base: Vec<u8> = vec![
        0x36, 0xAB, 0xD4, 0x52, 0x4E, 0x89, 0xA3, 0x4C, 0x89, 0xC4, 0x20, 0x94, 0x25, 0x47, 0xE1,
        0x2C, 0x7B, 0xE1,
    ];
    let exponent: Vec<u8> = vec![0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0x17, 0xEA, 0x78];
    let modulus: Vec<u8> = vec![
        0x02, 0xF0, 0x75, 0x8C, 0x6A, 0x04, 0x20, 0x09, 0x55, 0xB6, 0x49, 0xC3, 0x57, 0x22, 0xB8,
        0x00, 0x00, 0x00, 0x00,
    ];
    let result = crate::modexp(&base, &exponent, &modulus);
    assert_eq!(
        result,
        vec![2, 63, 79, 118, 41, 54, 235, 9, 115, 212, 107, 110, 173, 181, 157, 104, 208, 97, 1]
    );

    // The same computation, given as hex strings.
    assert_eq!(
        modexp_hex(
            "36abd4524e89a34c89c420942547e12c7be1",
            "01000000000517ea78",
            "02f0758c6a04200955b649c35722b800000000",
        ),
        "023f4f762936eb0973d46b6eadb59d68d06101"
    );

    // Test empty exp
    assert_eq!(modexp_hex("00", "", "02"), "01");

    // Test zero exp
    assert_eq!(modexp_hex("00", "00", "02"), "01");
}
657
/// Exercises `MPNat::modpow_montgomery` directly with odd moduli.
#[test]
fn test_modpow_montgomery() {
    // Asserts that `base ^ exp mod modulus == expected` for u128-sized inputs.
    fn check_modpow_montgomery(base: u128, exp: u128, modulus: u128, expected: u128) {
        let mut base_nat = MPNat::from_big_endian(&base.to_be_bytes());
        let modulus_nat = MPNat::from_big_endian(&modulus.to_be_bytes());
        let power = base_nat.modpow_montgomery(&exp.to_be_bytes(), &modulus_nat);
        let actual = crate::arith::mp_nat_to_u128(&power);
        assert_eq!(
            actual, expected,
            "({base} ^ {exp}) % {modulus} failed check_modpow_montgomery"
        );
    }

    // (base, exp, modulus, expected)
    let cases: &[(u128, u128, u128, u128)] = &[
        (3, 5, 0x9346_9d50_1f74_d1c1, 243),
        (3, 5, 19, 15),
        (
            0x5c4b74ec760dfb021499f5c5e3c69222,
            0x62b2a34b21cf4cc036e880b3fb59fe09,
            0x7b799c4502cd69bde8bb12601ce3ff15,
            0x10c9d9071d0b86d6a59264d2f461200,
        ),
        (
            0xadb5ce8589030e3a9112123f4558f69c,
            0xb002827068f05b84a87431a70fb763ab,
            0xc4550871a1cfc67af3e77eceb2ecfce5,
            0x7cb78c0e1c1b43f6412e9d1155ea96d2,
        ),
        (
            0x26eb51a5d9bf15a536b6e3c67867b492,
            0xddf007944a79bf55806003220a58cc6,
            0xc96275a80c694a62330872b2690f8773,
            0x23b75090ead913def3a1e0bde863eda7,
        ),
        (
            0xb93fa81979e597f548c78f2ecb6800f3,
            0x5fad650044963a271898d644984cb9f0,
            0xbeb60d6bd0439ea39d447214a4f8d3ab,
            0x354e63e6a5e007014acd3e5ea88dc3ad,
        ),
        (
            0x1993163e4f578869d04949bc005c878f,
            0x8cb960f846475690259514af46868cf5,
            0x52e104dc72423b534d8e49d878f29e3b,
            0x2aa756846258d5cfa6a3f8b9b181a11c,
        ),
    ];
    for &(base, exp, modulus, expected) in cases {
        check_modpow_montgomery(base, exp, modulus, expected);
    }
}
704
/// Exercises `MPNat::modpow_with_power_of_two` directly.
#[test]
fn test_modpow_with_power_of_two() {
    // Asserts that `base ^ exp mod modulus == expected` for u128-sized inputs.
    fn check_modpow_with_power_of_two(base: u128, exp: u128, modulus: u128, expected: u128) {
        let mut base_nat = MPNat::from_big_endian(&base.to_be_bytes());
        let modulus_nat = MPNat::from_big_endian(&modulus.to_be_bytes());
        let power = base_nat.modpow_with_power_of_two(&exp.to_be_bytes(), &modulus_nat);
        assert_eq!(crate::arith::mp_nat_to_u128(&power), expected);
    }

    // (base, exp, modulus, expected)
    let cases: &[(u128, u128, u128, u128)] = &[
        (3, 2, 1 << 30, 9),
        (3, 5, 1 << 30, 243),
        (3, 1_000_000, 1 << 30, 641836289),
        (3, 1_000_000, 1 << 31, 1715578113),
        (3, 1_000_000, 1 << 32, 3863061761),
        (
            0xabcd_ef01_2345_6789_1111,
            0x1234_5678_90ab_cdef,
            1 << 5,
            17,
        ),
        (
            0x3f47_9dc0_d5b9_6003,
            0xa180_e045_e314_8581,
            1 << 118,
            0x0028_3d19_e6cc_b8a0_e050_6abb_b9b1_1a03,
        ),
    ];
    for &(base, exp, modulus, expected) in cases {
        check_modpow_with_power_of_two(base, exp, modulus, expected);
    }
}
733
/// Exercises `MPNat::sub_to_same_size`: the result must not have more digits
/// than the divisor and must stay congruent to the input modulo the divisor.
#[test]
fn test_sub_to_same_size() {
    // Checks the digit count and the congruence for u128-sized inputs.
    fn check_sub_to_same_size(a: u128, n: u128) {
        let mut reduced = MPNat::from_big_endian(&a.to_be_bytes());
        let divisor = MPNat::from_big_endian(&n.to_be_bytes());
        reduced.sub_to_same_size(&divisor);
        assert!(reduced.digits.len() <= divisor.digits.len());
        let result = crate::arith::mp_nat_to_u128(&reduced);
        assert_eq!(result % n, a % n, "{a} % {n} failed sub_to_same_size check");
    }

    // Only runs the reduction; the internal debug assertions are the check.
    fn reduce_bytes(x_bytes: &[u8], y_bytes: &[u8]) {
        let mut x = MPNat::from_big_endian(x_bytes);
        let y = MPNat::from_big_endian(y_bytes);
        x.sub_to_same_size(&y);
    }

    let simple_cases: &[(u128, u128)] = &[
        (0x10_00_00_00_00, 0xFF_00_00_00),
        (0x10_00_00_00_00, 0x01_00_00_00),
        (0x35_00_00_00_00, 0x01_00_00_00),
        (0xEF_00_00_00_00_00_00, 0x02_FF_FF_FF),
    ];
    for &(a, n) in simple_cases {
        check_sub_to_same_size(a, n);
    }

    let n = 10;
    let a = 57 + 2 * n + 0x1234_0000_0000 * n + 0x000b_0000_0000_0000_0000 * n;
    check_sub_to_same_size(a, n);

    /* Test that borrow equals self_most_sig at end of sub_to_same_size */
    reduce_bytes(
        &[
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02,
            0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x02, 0xc9, 0xd9, 0x12, 0x61, 0x8e, 0xf5, 0x02, 0x2c,
            0xa0, 0x00, 0x00, 0x00,
        ],
        &[
            0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02, 0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x0f, 0x70, 0x00,
            0x00, 0x00,
        ],
    );

    /* Additional test for sub_to_same_size q_hat/r_hat adjustment logic */
    reduce_bytes(
        &[
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff,
            0xff, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ],
        &[
            0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00,
            0x00, 0x00,
        ],
    );
}
782
/// Compares `MPNat::is_odd` against `n % 2` over two ranges of values.
#[test]
fn test_mp_nat_is_odd() {
    fn check_is_odd(n: u128) {
        let expected = n % 2 == 1;
        let actual = MPNat::from_big_endian(&n.to_be_bytes()).is_odd();
        assert_eq!(actual, expected, "{n} failed is_odd test");
    }

    let small_values = 0..1025;
    let large_values = 0xFF_FF_FF_FF_00_00_00_00..0xFF_FF_FF_FF_00_00_04_01;
    for n in small_values.chain(large_values) {
        check_is_odd(n);
    }
}
797
/// Checks `MPNat::is_power_of_two` on zero, on non-powers, and on single-bit
/// values, including ones whose set bit lies in the second digit.
#[test]
fn test_mp_nat_is_power_of_two() {
    fn check_is_p2(n: u128, expected_result: bool) {
        let actual = MPNat::from_big_endian(&n.to_be_bytes()).is_power_of_two();
        assert_eq!(actual, expected_result, "{n} failed is_power_of_two test");
    }

    check_is_p2(0, false);
    check_is_p2(1, true);
    check_is_p2(1327, false);
    check_is_p2((1 << 1) + (1 << 35), false);

    for shift in [1, 2, 3, 4, 5, 31, 32, 64, 65, 127] {
        check_is_p2(1 << shift, true);
    }
}
824
/// Round-trips hex inputs through `from_big_endian` and `to_big_endian`; the
/// output must equal the input without its leading zeros (or "00" for zero).
#[test]
fn test_mp_nat_be() {
    fn be_round_trip(hex_input: &str) {
        let decoded = hex::decode(hex_input).unwrap();
        let hex_output = hex::encode(MPNat::from_big_endian(&decoded).to_big_endian());

        let stripped = hex_input.trim_start_matches('0');
        let expected = if stripped.is_empty() { "00" } else { stripped };
        assert_eq!(hex_output, expected);
    }

    let inputs = [
        "",
        "00",
        "77",
        "abcd",
        "00000000abcd",
        "abcdef",
        "abcdef00",
        "abcdef0011",
        "abcdef001122",
        "abcdef00112233",
        "abcdef0011223344",
        "abcdef001122334455",
        "abcdef01234567891011121314151617181920",
    ];
    for hex_input in inputs {
        be_round_trip(hex_input);
    }
}