1use crate::{algorithms, Uint};
2
3impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11 #[inline]
18 #[must_use]
19 pub fn reduce_mod(mut self, modulus: Self) -> Self {
20 if modulus.is_zero() {
21 return Self::ZERO;
22 }
23 if self >= modulus {
24 self %= modulus;
25 }
26 self
27 }
28
29 #[inline]
33 #[must_use]
34 pub fn add_mod(self, rhs: Self, modulus: Self) -> Self {
35 let lhs = self.reduce_mod(modulus);
37 let rhs = rhs.reduce_mod(modulus);
38
39 let (mut result, overflow) = lhs.overflowing_add(rhs);
41 if overflow || result >= modulus {
42 result -= modulus;
43 }
44 result
45 }
46
47 #[inline]
54 #[must_use]
55 pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
56 if modulus.is_zero() {
57 return Self::ZERO;
58 }
59
60 let mut product = [[0u64; 2]; LIMBS];
63 let product_len = crate::nlimbs(2 * BITS);
64 debug_assert!(2 * LIMBS >= product_len);
65 let product = unsafe {
67 core::slice::from_raw_parts_mut(product.as_mut_ptr().cast::<u64>(), product_len)
68 };
69
70 let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
72 debug_assert!(!overflow);
73
74 algorithms::div(product, &mut modulus.limbs);
77
78 modulus
79 }
80
81 #[inline]
85 #[must_use]
86 pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
87 if modulus.is_zero() || modulus <= Self::from(1) {
88 return Self::ZERO;
90 }
91
92 let mut result = Self::from(1);
94 while exp > Self::ZERO {
95 if exp.limbs[0] & 1 == 1 {
97 result = result.mul_mod(self, modulus);
98 }
99
100 self = self.mul_mod(self, modulus);
102 exp >>= 1;
103 }
104 result
105 }
106
107 #[inline]
111 #[must_use]
112 pub fn inv_mod(self, modulus: Self) -> Option<Self> {
113 algorithms::inv_mod(self, modulus)
114 }
115
116 #[inline]
157 #[must_use]
158 pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
159 if BITS == 0 {
160 return Self::ZERO;
161 }
162 let result = algorithms::mul_redc(self.limbs, other.limbs, modulus.limbs, inv);
163 let result = Self::from_limbs(result);
164 debug_assert!(result < modulus);
165 result
166 }
167
168 #[inline]
172 #[must_use]
173 pub fn square_redc(self, modulus: Self, inv: u64) -> Self {
174 if BITS == 0 {
175 return Self::ZERO;
176 }
177 let result = algorithms::square_redc(self.limbs, modulus.limbs, inv);
178 let result = Self::from_limbs(result);
179 debug_assert!(result < modulus);
180 result
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aliases::U64, const_for, nlimbs};
    use proptest::{prop_assume, proptest, test_runner::Config};

    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                assert_eq!(a.mul_mod(b, m), b.mul_mod(a, m));
            });
        });
    }

    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.mul_mod(c, m), m), a.mul_mod(b, m).mul_mod(c, m));
            });
        });
    }

    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.add_mod(c, m), m), a.mul_mod(b, m).add_mod(a.mul_mod(c, m), m));
            });
        });
    }

    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.add_mod(U::from(0), m), value.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.mul_mod(U::from(0), m), U::ZERO);
                assert_eq!(value.mul_mod(U::from(1), m), value.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                assert_eq!(a.pow_mod(U::from(0), m), U::from(1).reduce_mod(m));
                assert_eq!(a.pow_mod(U::from(1), m), a.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_pow_rules() {
        // Restrict to sizes of at most 8 limbs (presumably because `pow_mod` on
        // random full-width exponents is too slow beyond that). The skip is
        // expressed as a per-size `const_for!` filter rather than an early
        // `return`. A `return` inside the macro body may leave the whole test
        // function, silently skipping every later size if the size list is
        // ever reordered.
        const_for!(BITS in NON_ZERO if (nlimbs(BITS) <= 8) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;

            let config = Config { cases: 5, ..Default::default() };
            proptest!(config, |(a: U, b: U, c: U, m: U)| {
                // (a * b)^c == a^c * b^c  (mod m)
                assert_eq!(a.mul_mod(b, m).pow_mod(c, m), a.pow_mod(c, m).mul_mod(b.pow_mod(c, m), m));
            });
        });
    }

    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                if let Some(inv) = a.inv_mod(m) {
                    assert_eq!(a.mul_mod(inv, m), U::from(1));
                }
            });
        });
    }

    // The REDC tests build `U::from(64 * LIMBS)`; the `BITS >= 16` filter
    // presumably exists so that this value fits in `U`.

    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // The Montgomery constant only exists for odd moduli.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    // r = R mod m, with R = 2^(64 * LIMBS).
                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let br = b.mul_mod(r, m);
                    // (a·R)(b·R)·R^-1 = (a·b)·R  (mod m)
                    let expected = a.mul_mod(b, m).mul_mod(r, m);

                    assert_eq!(ar.mul_redc(br, m, inv), expected);
                }
            });
        });
    }

    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // The Montgomery constant only exists for odd moduli.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    // r = R mod m, with R = 2^(64 * LIMBS).
                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    // (a·R)^2·R^-1 = a^2·R  (mod m)
                    let expected = a.mul_mod(a, m).mul_mod(r, m);

                    assert_eq!(ar.square_redc(m, inv), expected);
                }
            });
        });
    }
}