1use super::{borrowing_sub, carrying_add, cmp};
4use core::{cmp::Ordering, iter::zip};
5
6#[inline]
11#[must_use]
12pub fn mul_redc<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
13 debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
14 debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);
15 debug_assert_eq!(cmp(&b, &modulus), Ordering::Less);
16
17 let mut result = [0; N];
22 let mut carry = false;
23 for b in b {
24 let mut m = 0;
25 let mut carry_1 = 0;
26 let mut carry_2 = 0;
27 for i in 0..N {
28 let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1);
30 carry_1 = next_carry;
31
32 if i == 0 {
33 m = value.wrapping_mul(inv);
35 }
36
37 let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2);
39 carry_2 = next_carry;
40
41 if i > 0 {
43 result[i - 1] = value;
44 } else {
45 debug_assert_eq!(value, 0);
46 }
47 }
48
49 let (value, next_carry) = carrying_add(carry_1, carry_2, carry);
51 result[N - 1] = value;
52 if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff {
53 carry = next_carry;
54 } else {
55 debug_assert!(!next_carry);
56 }
57 }
58
59 reduce1_carry(result, modulus, carry)
61}
62
63#[inline]
68#[must_use]
69#[allow(clippy::cast_possible_truncation)]
70pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
71 debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
72 debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);
73
74 let mut result = [0; N];
75 let mut carry_outer = 0;
76 for i in 0..N {
77 let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0);
79 let mut carry_hi = false;
80 result[i] = value;
81 for j in (i + 1)..N {
82 let (value, next_carry_lo, next_carry_hi) =
83 carrying_double_mul_add(a[i], a[j], result[j], carry_lo, carry_hi);
84 result[j] = value;
85 carry_lo = next_carry_lo;
86 carry_hi = next_carry_hi;
87 }
88
89 let m = result[0].wrapping_mul(inv);
91 let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0);
92 debug_assert_eq!(value, 0);
93 for j in 1..N {
94 let (value, next_carry) = carrying_mul_add(modulus[j], m, result[j], carry);
95 result[j - 1] = value;
96 carry = next_carry;
97 }
98
99 if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff {
101 let wide = (carry_outer as u128)
102 .wrapping_add(carry_lo as u128)
103 .wrapping_add((carry_hi as u128) << 64)
104 .wrapping_add(carry as u128);
105 result[N - 1] = wide as u64;
106
107 carry_outer = (wide >> 64) as u64;
109 debug_assert!(carry_outer <= 2);
110 } else {
111 debug_assert!(!carry_hi);
113 debug_assert_eq!(carry_outer, 0);
114 let (value, carry) = carry_lo.overflowing_add(carry);
115 debug_assert!(!carry);
116 result[N - 1] = value;
117 }
118 }
119
120 debug_assert!(carry_outer <= 1);
122 reduce1_carry(result, modulus, carry_outer > 0)
123}
124
125#[inline]
126#[must_use]
127#[allow(clippy::needless_bitwise_bool)]
128fn reduce1_carry<const N: usize>(value: [u64; N], modulus: [u64; N], carry: bool) -> [u64; N] {
129 let (reduced, borrow) = sub(value, modulus);
130 if carry | !borrow {
133 reduced
134 } else {
135 value
136 }
137}
138
139#[inline]
140#[must_use]
141fn sub<const N: usize>(lhs: [u64; N], rhs: [u64; N]) -> ([u64; N], bool) {
142 let mut result = [0; N];
143 let mut borrow = false;
144 for (result, (lhs, rhs)) in zip(&mut result, zip(lhs, rhs)) {
145 let (value, next_borrow) = borrowing_sub(lhs, rhs, borrow);
146 *result = value;
147 borrow = next_borrow;
148 }
149 (result, borrow)
150}
151
/// Computes `lhs * rhs + add + carry` as a 128-bit value.
///
/// Returns `(low, high)`: the low and high 64 bits of the result. The sum
/// always fits in 128 bits, since
/// `(2^64 - 1)^2 + 2 * (2^64 - 1) == 2^128 - 1`.
#[inline]
#[must_use]
#[allow(clippy::cast_possible_truncation)]
const fn carrying_mul_add(lhs: u64, rhs: u64, add: u64, carry: u64) -> (u64, u64) {
    let product = (lhs as u128).wrapping_mul(rhs as u128);
    let addend = (add as u128).wrapping_add(carry as u128);
    let total = product.wrapping_add(addend);

    // Truncating casts split the 128-bit total into its two halves.
    let low = total as u64;
    let high = (total >> 64) as u64;
    (low, high)
}
164
/// Computes `2 * lhs * rhs + add + carry_lo + 2^64 * carry_hi` as a 129-bit
/// value.
///
/// Returns `(low, high, top)`: the low 64 bits, the next 64 bits, and the
/// single bit above those. The result always fits in 129 bits.
#[inline]
#[must_use]
#[allow(clippy::cast_possible_truncation)]
const fn carrying_double_mul_add(
    lhs: u64,
    rhs: u64,
    add: u64,
    carry_lo: u64,
    carry_hi: bool,
) -> (u64, u64, bool) {
    let product = (lhs as u128).wrapping_mul(rhs as u128);

    // Doubling is a left shift by one; the bit shifted out is the overflow.
    let doubling_overflow = (product >> 127) != 0;
    let doubled = product << 1;

    let incoming = (add as u128)
        .wrapping_add(carry_lo as u128)
        .wrapping_add((carry_hi as u128) << 64);
    let (total, sum_overflow) = doubled.overflowing_add(incoming);

    // At most one of the two overflows can occur, so OR combines them.
    let top = doubling_overflow | sum_overflow;
    (total as u64, (total >> 64) as u64, top)
}
185
#[cfg(test)]
mod test {
    use core::ops::Neg;

    use super::{
        super::{addmul, div},
        *,
    };
    use crate::{aliases::U64, const_for, nlimbs, Uint};
    use proptest::{prop_assert_eq, proptest};

    /// Reduces the limb buffer `wide` modulo `modulus`.
    ///
    /// NOTE(review): relies on `div` leaving the remainder in its second
    /// argument — confirm against the `div` documentation.
    fn rem_wide<const N: usize>(wide: &mut [u64], modulus: [u64; N]) -> [u64; N] {
        let mut remainder = modulus;
        div(wide, &mut remainder);
        remainder
    }

    /// Reference implementation of `a * b mod modulus` using a double-width
    /// product followed by a division.
    fn modmul<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N]) -> [u64; N] {
        let mut wide = vec![0_u64; 2 * N];
        addmul(&mut wide, &a, &b);
        rem_wide(&mut wide, modulus)
    }

    /// Reference implementation of `a * 2^(64 * N) mod modulus`: places `a`
    /// in the upper half of a double-width buffer and reduces it.
    fn mul_base<const N: usize>(a: [u64; N], modulus: [u64; N]) -> [u64; N] {
        let mut wide = vec![0_u64; 2 * N];
        wide[N..].copy_from_slice(&a);
        rem_wide(&mut wide, modulus)
    }

    /// Computes `-(m0^-1) mod 2^64`, the `inv` argument expected by
    /// `mul_redc` and `square_redc`. Panics if `m0` has no inverse.
    fn neg_inv(m0: u64) -> u64 {
        U64::from(m0).inv_ring().unwrap().neg().as_limbs()[0]
    }

    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut b: U, mut m: U)| {
                // Force an odd modulus so the inverse of its lowest limb exists.
                m |= U::from(1_u64);
                // Bring both operands below the modulus.
                a %= m;
                b %= m;

                let (a, b, m) = (*a.as_limbs(), *b.as_limbs(), *m.as_limbs());
                let inv = neg_inv(m[0]);

                // Multiplying the result by 2^(64 * N) undoes the Montgomery
                // factor, so it must equal the plain modular product.
                let result = mul_base(mul_redc(a, b, m, inv), m);
                let expected = modmul(a, b, m);

                prop_assert_eq!(result, expected);
            });
        });
    }

    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut m: U)| {
                // Force an odd modulus so the inverse of its lowest limb exists.
                m |= U::from(1_u64);
                // Bring the operand below the modulus.
                a %= m;

                let (a, m) = (*a.as_limbs(), *m.as_limbs());
                let inv = neg_inv(m[0]);

                // Multiplying the result by 2^(64 * N) undoes the Montgomery
                // factor, so it must equal the plain modular square.
                let result = mul_base(square_redc(a, m, inv), m);
                let expected = modmul(a, a, m);

                prop_assert_eq!(result, expected);
            });
        });
    }
}