1use crate::{algorithms, nlimbs, Uint};
2use core::{
3 iter::Product,
4 num::Wrapping,
5 ops::{Mul, MulAssign},
6};
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9 #[inline(always)]
11 #[must_use]
12 pub fn checked_mul(self, rhs: Self) -> Option<Self> {
13 match self.overflowing_mul(rhs) {
14 (value, false) => Some(value),
15 _ => None,
16 }
17 }
18
19 #[inline]
38 #[must_use]
39 pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
40 let mut result = Self::ZERO;
41 let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
42 if BITS > 0 {
43 overflow |= result.limbs[LIMBS - 1] > Self::MASK;
44 result.apply_mask();
45 }
46 (result, overflow)
47 }
48
49 #[inline(always)]
52 #[must_use]
53 pub fn saturating_mul(self, rhs: Self) -> Self {
54 match self.overflowing_mul(rhs) {
55 (value, false) => value,
56 _ => Self::MAX,
57 }
58 }
59
60 #[cfg(not(target_os = "zkvm"))]
62 #[inline(always)]
63 #[must_use]
64 pub fn wrapping_mul(self, rhs: Self) -> Self {
65 let mut result = Self::ZERO;
66 algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
67 if BITS > 0 {
68 result.apply_mask();
69 }
70 result
71 }
72
73 #[cfg(target_os = "zkvm")]
75 #[inline(always)]
76 #[must_use]
77 pub fn wrapping_mul(mut self, rhs: Self) -> Self {
78 use crate::support::zkvm::zkvm_u256_wrapping_mul_impl;
79 if BITS == 256 {
80 unsafe {
81 zkvm_u256_wrapping_mul_impl(
82 self.limbs.as_mut_ptr() as *mut u8,
83 self.limbs.as_ptr() as *const u8,
84 rhs.limbs.as_ptr() as *const u8,
85 );
86 }
87 return self;
88 }
89 self.overflowing_mul(rhs).0
90 }
91
92 #[inline]
95 #[must_use]
96 pub fn inv_ring(self) -> Option<Self> {
97 if BITS == 0 || self.limbs[0] & 1 == 0 {
98 return None;
99 }
100
101 let mut result = Self::ZERO;
103 result.limbs[0] = {
104 const W2: Wrapping<u64> = Wrapping(2);
105 const W3: Wrapping<u64> = Wrapping(3);
106 let n = Wrapping(self.limbs[0]);
107 let mut inv = (n * W3) ^ W2; inv *= W2 - n * inv; inv *= W2 - n * inv; inv *= W2 - n * inv; inv *= W2 - n * inv; debug_assert_eq!(n.0.wrapping_mul(inv.0), 1);
113 inv.0
114 };
115
116 let mut correct_limbs = 1;
118 while correct_limbs < LIMBS {
119 result *= Self::from(2) - self * result;
120 correct_limbs *= 2;
121 }
122 result.apply_mask();
123
124 Some(result)
125 }
126
127 #[inline]
149 #[must_use]
150 #[allow(clippy::similar_names)] pub fn widening_mul<
152 const BITS_RHS: usize,
153 const LIMBS_RHS: usize,
154 const BITS_RES: usize,
155 const LIMBS_RES: usize,
156 >(
157 self,
158 rhs: Uint<BITS_RHS, LIMBS_RHS>,
159 ) -> Uint<BITS_RES, LIMBS_RES> {
160 assert_eq!(BITS_RES, BITS + BITS_RHS);
161 assert_eq!(LIMBS_RES, nlimbs(BITS_RES));
162 let mut result = Uint::<BITS_RES, LIMBS_RES>::ZERO;
163 algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
164 if LIMBS_RES > 0 {
165 debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::<BITS_RES, LIMBS_RES>::MASK);
166 }
167
168 result
169 }
170}
171
172impl<const BITS: usize, const LIMBS: usize> Product<Self> for Uint<BITS, LIMBS> {
173 #[inline]
174 fn product<I>(iter: I) -> Self
175 where
176 I: Iterator<Item = Self>,
177 {
178 if BITS == 0 {
179 return Self::ZERO;
180 }
181 iter.fold(Self::ONE, Self::wrapping_mul)
182 }
183}
184
185impl<'a, const BITS: usize, const LIMBS: usize> Product<&'a Self> for Uint<BITS, LIMBS> {
186 #[inline]
187 fn product<I>(iter: I) -> Self
188 where
189 I: Iterator<Item = &'a Self>,
190 {
191 if BITS == 0 {
192 return Self::ZERO;
193 }
194 iter.copied().fold(Self::ONE, Self::wrapping_mul)
195 }
196}
197
// Wires the `Mul`/`MulAssign` operators (`*` and `*=`) for `Uint` to
// `wrapping_mul` through the crate's `impl_bin_op!` macro, which is defined
// elsewhere. NOTE(review): the exact set of impls generated (owned/borrowed
// operand combinations) is decided by that macro -- confirm there.
impl_bin_op!(Mul, mul, MulAssign, mul_assign, wrapping_mul);
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::const_for;
    use proptest::proptest;

    // Multiplication commutes: a * b == b * a.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                let forward = a * b;
                let backward = b * a;
                assert_eq!(forward, backward);
            });
        });
    }

    // Multiplication associates: a * (b * c) == (a * b) * c.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                let grouped_right = a * (b * c);
                let grouped_left = (a * b) * c;
                assert_eq!(grouped_right, grouped_left);
            });
        });
    }

    // Multiplication distributes over addition: a * (b + c) == a * b + a * c.
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                let factored = a * (b + c);
                let expanded = (a * b) + (a * c);
                assert_eq!(factored, expanded);
            });
        });
    }

    // Zero annihilates and one is the multiplicative identity.
    #[test]
    fn test_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                let zero = U::from(0);
                let one = U::from(1);
                assert_eq!(value * zero, U::ZERO);
                assert_eq!(value * one, value);
            });
        });
    }

    // Every odd value has a ring inverse, and inverting twice round-trips.
    #[test]
    fn test_inverse() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U)| {
                // Force `a` to be odd so that an inverse exists.
                a |= U::from(1);
                let inverse = a.inv_ring().unwrap();
                assert_eq!(a * inverse, U::from(1));
                assert_eq!(inverse.inv_ring().unwrap(), a);
            });
        });
    }

    // The widening product matches multiplication performed at the full
    // result width.
    #[test]
    fn test_widening_mul() {
        const_for!(BITS_LHS in BENCH {
            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
            type Lhs = Uint<BITS_LHS, LIMBS_LHS>;

            const_for!(BITS_RHS in BENCH {
                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
                type Rhs = Uint<BITS_RHS, LIMBS_RHS>;

                const BITS_RES: usize = BITS_LHS + BITS_RHS;
                const LIMBS_RES: usize = nlimbs(BITS_RES);
                type Res = Uint<BITS_RES, LIMBS_RES>;

                proptest!(|(lhs: Lhs, rhs: Rhs)| {
                    let expected = Res::from(lhs) * Res::from(rhs);
                    let actual: Res = lhs.widening_mul(rhs);
                    assert_eq!(actual, expected);
                });
            });
        });
    }
}