ruint/
modular.rs

1use crate::{algorithms, Uint};
2
// FEATURE: sub_mod, neg_mod, div_mod, root_mod
// For root_mod (square roots modulo a prime), see <https://en.wikipedia.org/wiki/Cipolla's_algorithm>
// FEATURE: mul_mod_redc, and maybe Barrett reduction.
// See also <https://static1.squarespace.com/static/61f7cacf2d7af938cad5b81c/t/62deb4e0c434f7134c2730ee/1658762465114/modular_multiplication.pdf>
// FEATURE: Modular wrapper class, like Wrapping.
9
10impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11    /// ⚠️ Compute $\mod{\mathtt{self}}_{\mathtt{modulus}}$.
12    ///
13    /// **Warning.** This function is not part of the stable API.
14    ///
15    /// Returns zero if the modulus is zero.
16    // FEATURE: Reduce larger bit-sizes to smaller ones.
17    #[inline]
18    #[must_use]
19    pub fn reduce_mod(mut self, modulus: Self) -> Self {
20        if modulus.is_zero() {
21            return Self::ZERO;
22        }
23        if self >= modulus {
24            self %= modulus;
25        }
26        self
27    }
28
29    /// Compute $\mod{\mathtt{self} + \mathtt{rhs}}_{\mathtt{modulus}}$.
30    ///
31    /// Returns zero if the modulus is zero.
32    #[inline]
33    #[must_use]
34    pub fn add_mod(self, rhs: Self, modulus: Self) -> Self {
35        // Reduce inputs
36        let lhs = self.reduce_mod(modulus);
37        let rhs = rhs.reduce_mod(modulus);
38
39        // Compute the sum and conditionally subtract modulus once.
40        let (mut result, overflow) = lhs.overflowing_add(rhs);
41        if overflow || result >= modulus {
42            result -= modulus;
43        }
44        result
45    }
46
47    /// Compute $\mod{\mathtt{self} ⋅ \mathtt{rhs}}_{\mathtt{modulus}}$.
48    ///
49    /// Returns zero if the modulus is zero.
50    ///
51    /// See [`mul_redc`](Self::mul_redc) for a faster variant at the cost of
52    /// some pre-computation.
53    #[inline]
54    #[must_use]
55    pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
56        if modulus.is_zero() {
57            return Self::ZERO;
58        }
59
60        // Allocate at least `nlimbs(2 * BITS)` limbs to store the product. This array
61        // casting is a workaround for `generic_const_exprs` not being stable.
62        let mut product = [[0u64; 2]; LIMBS];
63        let product_len = crate::nlimbs(2 * BITS);
64        debug_assert!(2 * LIMBS >= product_len);
65        // SAFETY: `[[u64; 2]; LIMBS] == [u64; 2 * LIMBS] >= [u64; nlimbs(2 * BITS)]`.
66        let product = unsafe {
67            core::slice::from_raw_parts_mut(product.as_mut_ptr().cast::<u64>(), product_len)
68        };
69
70        // Compute full product.
71        let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
72        debug_assert!(!overflow);
73
74        // Compute modulus using `div_rem`.
75        // This stores the remainder in the divisor, `modulus`.
76        algorithms::div(product, &mut modulus.limbs);
77
78        modulus
79    }
80
81    /// Compute $\mod{\mathtt{self}^{\mathtt{rhs}}}_{\mathtt{modulus}}$.
82    ///
83    /// Returns zero if the modulus is zero.
84    #[inline]
85    #[must_use]
86    pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
87        if modulus.is_zero() || modulus <= Self::from(1) {
88            // Also covers Self::BITS == 0
89            return Self::ZERO;
90        }
91
92        // Exponentiation by squaring
93        let mut result = Self::from(1);
94        while exp > Self::ZERO {
95            // Multiply by base
96            if exp.limbs[0] & 1 == 1 {
97                result = result.mul_mod(self, modulus);
98            }
99
100            // Square base
101            self = self.mul_mod(self, modulus);
102            exp >>= 1;
103        }
104        result
105    }
106
107    /// Compute $\mod{\mathtt{self}^{-1}}_{\mathtt{modulus}}$.
108    ///
109    /// Returns `None` if the inverse does not exist.
110    #[inline]
111    #[must_use]
112    pub fn inv_mod(self, modulus: Self) -> Option<Self> {
113        algorithms::inv_mod(self, modulus)
114    }
115
116    /// Montgomery multiplication.
117    ///
118    /// Requires `self` and `other` to be less than `modulus`.
119    ///
120    /// Computes
121    ///
122    /// $$
123    /// \mod{\frac{\mathtt{self} ⋅ \mathtt{other}}{ 2^{64 ·
124    /// \mathtt{LIMBS}}}}_{\mathtt{modulus}} $$
125    ///
126    /// This is useful because it can be computed notably faster than
127    /// [`mul_mod`](Self::mul_mod). Many computations can be done by
128    /// pre-multiplying values with $R = 2^{64 · \mathtt{LIMBS}}$
129    /// and then using [`mul_redc`](Self::mul_redc) instead of
130    /// [`mul_mod`](Self::mul_mod).
131    ///
132    /// For this algorithm to work, it needs an extra parameter `inv` which must
133    /// be set to
134    ///
135    /// $$
136    /// \mathtt{inv} = \mod{\frac{-1}{\mathtt{modulus}} }_{2^{64}}
137    /// $$
138    ///
139    /// The `inv` value only exists for odd values of `modulus`. It can be
140    /// computed using [`inv_ring`](Self::inv_ring) from `U64`.
141    ///
142    /// ```
143    /// # use ruint::{uint, Uint, aliases::*};
144    /// # uint!{
145    /// # let modulus = 21888242871839275222246405745257275088548364400416034343698204186575808495617_U256;
146    /// let inv = U64::wrapping_from(modulus).inv_ring().unwrap().wrapping_neg().to();
147    /// let prod = 5_U256.mul_redc(6_U256, modulus, inv);
148    /// # assert_eq!(inv.wrapping_mul(modulus.wrapping_to()), u64::MAX);
149    /// # assert_eq!(inv, 0xc2e1f593efffffff);
150    /// # }
151    /// ```
152    ///
153    /// # Panics
154    ///
155    /// Panics if `inv` is not correct in debug mode.
156    #[inline]
157    #[must_use]
158    pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
159        if BITS == 0 {
160            return Self::ZERO;
161        }
162        let result = algorithms::mul_redc(self.limbs, other.limbs, modulus.limbs, inv);
163        let result = Self::from_limbs(result);
164        debug_assert!(result < modulus);
165        result
166    }
167
168    /// Montgomery squaring.
169    ///
170    /// See [Self::mul_redc].
171    #[inline]
172    #[must_use]
173    pub fn square_redc(self, modulus: Self, inv: u64) -> Self {
174        if BITS == 0 {
175            return Self::ZERO;
176        }
177        let result = algorithms::square_redc(self.limbs, modulus.limbs, inv);
178        let result = Self::from_limbs(result);
179        debug_assert!(result < modulus);
180        result
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aliases::U64, const_for, nlimbs};
    use proptest::{prop_assume, proptest, test_runner::Config};

    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                assert_eq!(a.mul_mod(b, m), b.mul_mod(a, m));
            });
        });
    }

    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.mul_mod(c, m), m), a.mul_mod(b, m).mul_mod(c, m));
            });
        });
    }

    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.add_mod(c, m), m), a.mul_mod(b, m).add_mod(a.mul_mod(c, m), m));
            });
        });
    }

    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.add_mod(U::from(0), m), value.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.mul_mod(U::from(0), m), U::ZERO);
                assert_eq!(value.mul_mod(U::from(1), m), value.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                assert_eq!(a.pow_mod(U::from(0), m), U::from(1).reduce_mod(m));
                assert_eq!(a.pow_mod(U::from(1), m), a.reduce_mod(m));
            });
        });
    }

    /// `pow_mod` must agree with multiplying the base in `exp` times.
    ///
    /// Exponents are kept below 16 so the reference loop stays cheap; sizes
    /// below 8 bits are skipped because they cannot represent every such
    /// exponent.
    #[test]
    fn test_pow_matches_repeated_mul() {
        const_for!(BITS in NON_ZERO if (BITS >= 8) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U, e: u8)| {
                let e = e % 16;
                let mut expected = U::from(1).reduce_mod(m);
                for _ in 0..e {
                    expected = expected.mul_mod(a, m);
                }
                assert_eq!(a.pow_mod(U::from(u64::from(e)), m), expected);
            });
        });
    }

    #[test]
    fn test_pow_rules() {
        // Too slow above 8 limbs. The size filter is part of the `const_for!`
        // condition instead of an early `return` in the body: depending on how
        // the macro expands the body, a `return` could exit the whole test and
        // silently skip every size listed after the first large one.
        const_for!(BITS in NON_ZERO if (nlimbs(BITS) <= 8) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;

            let config = Config { cases: 5, ..Default::default() };
            proptest!(config, |(a: U, b: U, c: U, m: U)| {
                // TODO: a^(b+c) = a^b * a^c. This requires the Carmichael function.
                // TODO: (a^b)^c = a^(b * c). This requires the Carmichael function.
                assert_eq!(a.mul_mod(b, m).pow_mod(c, m), a.pow_mod(c, m).mul_mod(b.pow_mod(c, m), m));
            });
        });
    }

    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                if let Some(inv) = a.inv_mod(m) {
                    assert_eq!(a.mul_mod(inv, m), U::from(1));
                }
            });
        });
    }

    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // `inv_ring` only succeeds for odd moduli, which is exactly when
                // Montgomery reduction is defined.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    // R = 2^(64 * LIMBS) mod m, used to move into Montgomery form.
                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let br = b.mul_mod(r, m);
                    // TODO: Test for larger (>= m) values of a, b.

                    let expected = a.mul_mod(b, m).mul_mod(r, m);

                    assert_eq!(ar.mul_redc(br, m, inv), expected);
                }
            });
        });
    }

    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // `inv_ring` only succeeds for odd moduli, which is exactly when
                // Montgomery reduction is defined.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    // R = 2^(64 * LIMBS) mod m, used to move into Montgomery form.
                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    // TODO: Test for larger (>= m) values of a.

                    let expected = a.mul_mod(a, m).mul_mod(r, m);

                    assert_eq!(ar.square_redc(m, inv), expected);
                }
            });
        });
    }
}