ruint/algorithms/
mul_redc.rs

1// TODO: https://baincapitalcrypto.com/optimizing-montgomery-multiplication-in-webassembly/
2
3use super::{borrowing_sub, carrying_add, cmp};
4use core::{cmp::Ordering, iter::zip};
5
6/// Computes a * b * 2^(-BITS) mod modulus
7///
8/// Requires that `inv` is the inverse of `-modulus[0]` modulo `2^64`.
9/// Requires that `a` and `b` are less than `modulus`.
10#[inline]
11#[must_use]
12pub fn mul_redc<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
13    debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
14    debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);
15    debug_assert_eq!(cmp(&b, &modulus), Ordering::Less);
16
17    // Coarsely Integrated Operand Scanning (CIOS)
18    // See <https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf>
19    // See <https://hackmd.io/@gnark/modular_multiplication#fn1>
20    // See <https://tches.iacr.org/index.php/TCHES/article/view/10972>
21    let mut result = [0; N];
22    let mut carry = false;
23    for b in b {
24        let mut m = 0;
25        let mut carry_1 = 0;
26        let mut carry_2 = 0;
27        for i in 0..N {
28            // Add limb product
29            let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1);
30            carry_1 = next_carry;
31
32            if i == 0 {
33                // Compute reduction factor
34                m = value.wrapping_mul(inv);
35            }
36
37            // Add m * modulus to acc to clear next_result[0]
38            let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2);
39            carry_2 = next_carry;
40
41            // Shift result
42            if i > 0 {
43                result[i - 1] = value;
44            } else {
45                debug_assert_eq!(value, 0);
46            }
47        }
48
49        // Add carries
50        let (value, next_carry) = carrying_add(carry_1, carry_2, carry);
51        result[N - 1] = value;
52        if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff {
53            carry = next_carry;
54        } else {
55            debug_assert!(!next_carry);
56        }
57    }
58
59    // Compute reduced product.
60    reduce1_carry(result, modulus, carry)
61}
62
63/// Computes a^2 * 2^(-BITS) mod modulus
64///
65/// Requires that `inv` is the inverse of `-modulus[0]` modulo `2^64`.
66/// Requires that `a` is less than `modulus`.
67#[inline]
68#[must_use]
69#[allow(clippy::cast_possible_truncation)]
70pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
71    debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
72    debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);
73
74    let mut result = [0; N];
75    let mut carry_outer = 0;
76    for i in 0..N {
77        // Add limb product
78        let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0);
79        let mut carry_hi = false;
80        result[i] = value;
81        for j in (i + 1)..N {
82            let (value, next_carry_lo, next_carry_hi) =
83                carrying_double_mul_add(a[i], a[j], result[j], carry_lo, carry_hi);
84            result[j] = value;
85            carry_lo = next_carry_lo;
86            carry_hi = next_carry_hi;
87        }
88
89        // Add m times modulus to result and shift one limb
90        let m = result[0].wrapping_mul(inv);
91        let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0);
92        debug_assert_eq!(value, 0);
93        for j in 1..N {
94            let (value, next_carry) = carrying_mul_add(modulus[j], m, result[j], carry);
95            result[j - 1] = value;
96            carry = next_carry;
97        }
98
99        // Add carries
100        if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff {
101            let wide = (carry_outer as u128)
102                .wrapping_add(carry_lo as u128)
103                .wrapping_add((carry_hi as u128) << 64)
104                .wrapping_add(carry as u128);
105            result[N - 1] = wide as u64;
106
107            // Note carry_outer can be {0, 1, 2}.
108            carry_outer = (wide >> 64) as u64;
109            debug_assert!(carry_outer <= 2);
110        } else {
111            // `carry_outer` and `carry_hi` are always zero.
112            debug_assert!(!carry_hi);
113            debug_assert_eq!(carry_outer, 0);
114            let (value, carry) = carry_lo.overflowing_add(carry);
115            debug_assert!(!carry);
116            result[N - 1] = value;
117        }
118    }
119
120    // Compute reduced product.
121    debug_assert!(carry_outer <= 1);
122    reduce1_carry(result, modulus, carry_outer > 0)
123}
124
125#[inline]
126#[must_use]
127#[allow(clippy::needless_bitwise_bool)]
128fn reduce1_carry<const N: usize>(value: [u64; N], modulus: [u64; N], carry: bool) -> [u64; N] {
129    let (reduced, borrow) = sub(value, modulus);
130    // TODO: Ideally this turns into a cmov, which makes the whole mul_redc constant
131    // time.
132    if carry | !borrow {
133        reduced
134    } else {
135        value
136    }
137}
138
139#[inline]
140#[must_use]
141fn sub<const N: usize>(lhs: [u64; N], rhs: [u64; N]) -> ([u64; N], bool) {
142    let mut result = [0; N];
143    let mut borrow = false;
144    for (result, (lhs, rhs)) in zip(&mut result, zip(lhs, rhs)) {
145        let (value, next_borrow) = borrowing_sub(lhs, rhs, borrow);
146        *result = value;
147        borrow = next_borrow;
148    }
149    (result, borrow)
150}
151
/// Compute `lhs * rhs + add + carry`.
///
/// Returns the result as `(low, high)` 64-bit halves. The 128-bit
/// intermediate can not overflow for any input values: the maximum is
/// `(2^64 - 1)^2 + 2 * (2^64 - 1) = 2^128 - 1`.
#[inline]
#[must_use]
#[allow(clippy::cast_possible_truncation)]
const fn carrying_mul_add(lhs: u64, rhs: u64, add: u64, carry: u64) -> (u64, u64) {
    let product = (lhs as u128).wrapping_mul(rhs as u128);
    let addends = (add as u128).wrapping_add(carry as u128);
    let total = product.wrapping_add(addends);

    // Truncation is intentional: split into low and high limbs.
    let low = total as u64;
    let high = (total >> 64) as u64;
    (low, high)
}
164
/// Compute `2 * lhs * rhs + add + carry_lo + 2^64 * carry_hi`.
///
/// Returns `(low, high, top)`: the two 64-bit limbs of the result and one
/// extra carry bit. The output can not overflow for any input values.
#[inline]
#[must_use]
#[allow(clippy::cast_possible_truncation)]
const fn carrying_double_mul_add(
    lhs: u64,
    rhs: u64,
    add: u64,
    carry_lo: u64,
    carry_hi: bool,
) -> (u64, u64, bool) {
    let product = (lhs as u128).wrapping_mul(rhs as u128);

    // Doubling is a left shift by one; the bit shifted out of position 127
    // is the carry of that doubling.
    let doubling_carry = (product >> 127) != 0;
    let doubled = product << 1;

    let addends = ((carry_hi as u128) << 64)
        .wrapping_add(carry_lo as u128)
        .wrapping_add(add as u128);
    let (total, addition_carry) = doubled.overflowing_add(addends);

    // Truncation is intentional: split into low and high limbs.
    let low = total as u64;
    let high = (total >> 64) as u64;
    (low, high, doubling_carry | addition_carry)
}
185
#[cfg(test)]
mod test {
    use core::ops::Neg;

    use super::{
        super::{addmul, div},
        *,
    };
    use crate::{aliases::U64, const_for, nlimbs, Uint};
    use proptest::{prop_assert_eq, proptest};

    /// Reference `a * b mod modulus`, built from the generic `addmul` and
    /// `div` routines instead of Montgomery arithmetic.
    fn modmul<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N]) -> [u64; N] {
        // Full double-width product a * b.
        let mut wide = vec![0; 2 * N];
        addmul(&mut wide, &a, &b);

        // Reduce modulo `modulus`; the result is read back from the buffer
        // that initially held the modulus.
        let mut remainder = modulus;
        div(&mut wide, &mut remainder);
        remainder
    }

    /// Reference `a * 2^(N * 64) mod modulus`, i.e. the inverse of the
    /// scaling that Montgomery reduction applies.
    fn mul_base<const N: usize>(a: [u64; N], modulus: [u64; N]) -> [u64; N] {
        // Placing `a` in the upper half multiplies it by 2^(N * 64).
        let mut wide = vec![0; 2 * N];
        wide[N..].copy_from_slice(&a);

        // Reduce modulo `modulus`; the result is read back from the buffer
        // that initially held the modulus.
        let mut remainder = modulus;
        div(&mut wide, &mut remainder);
        remainder
    }

    /// Negated ring inverse of `limb`, taken from the low limb of a `U64`.
    /// Panics if `inv_ring` yields no inverse.
    fn neg_inv(limb: u64) -> u64 {
        U64::from(limb).inv_ring().unwrap().neg().as_limbs()[0]
    }

    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut b: U, mut m: U)| {
                // Force an odd modulus and operands below it.
                m |= U::from(1_u64);
                a %= m;
                b %= m;
                let (a, b, m) = (*a.as_limbs(), *b.as_limbs(), *m.as_limbs());
                let inv = neg_inv(m[0]);

                // Undo the Montgomery scaling before comparing.
                let actual = mul_base(mul_redc(a, b, m, inv), m);
                let expected = modmul(a, b, m);

                prop_assert_eq!(actual, expected);
            });
        });
    }

    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut m: U)| {
                // Force an odd modulus and an operand below it.
                m |= U::from(1_u64);
                a %= m;
                let (a, m) = (*a.as_limbs(), *m.as_limbs());
                let inv = neg_inv(m[0]);

                // Undo the Montgomery scaling before comparing.
                let actual = mul_base(square_redc(a, m, inv), m);
                let expected = modmul(a, a, m);

                prop_assert_eq!(actual, expected);
            });
        });
    }
}