1use crate::{algorithms, nlimbs, Uint};
2use core::{
3 iter::Product,
4 num::Wrapping,
5 ops::{Mul, MulAssign},
6};
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9 #[inline(always)]
11 #[must_use]
12 pub fn checked_mul(self, rhs: Self) -> Option<Self> {
13 match self.overflowing_mul(rhs) {
14 (value, false) => Some(value),
15 _ => None,
16 }
17 }
18
19 #[inline]
38 #[must_use]
39 pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
40 let mut result = Self::ZERO;
41 let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
42 if BITS > 0 {
43 overflow |= result.limbs[LIMBS - 1] > Self::MASK;
44 result.limbs[LIMBS - 1] &= Self::MASK;
45 }
46 (result, overflow)
47 }
48
49 #[inline(always)]
52 #[must_use]
53 pub fn saturating_mul(self, rhs: Self) -> Self {
54 match self.overflowing_mul(rhs) {
55 (value, false) => value,
56 _ => Self::MAX,
57 }
58 }
59
60 #[inline(always)]
62 #[must_use]
63 pub fn wrapping_mul(self, rhs: Self) -> Self {
64 let mut result = Self::ZERO;
65 algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
66 if BITS > 0 {
67 result.limbs[LIMBS - 1] &= Self::MASK;
68 }
69 result
70 }
71
72 #[inline]
75 #[must_use]
76 pub fn inv_ring(self) -> Option<Self> {
77 if BITS == 0 || self.limbs[0] & 1 == 0 {
78 return None;
79 }
80
81 let mut result = Self::ZERO;
83 result.limbs[0] = {
84 const W2: Wrapping<u64> = Wrapping(2);
85 const W3: Wrapping<u64> = Wrapping(3);
86 let n = Wrapping(self.limbs[0]);
87 let mut inv = (n * W3) ^ W2; inv *= W2 - n * inv; inv *= W2 - n * inv; inv *= W2 - n * inv; inv *= W2 - n * inv; debug_assert_eq!(n.0.wrapping_mul(inv.0), 1);
93 inv.0
94 };
95
96 let mut correct_limbs = 1;
98 while correct_limbs < LIMBS {
99 result *= Self::from(2) - self * result;
100 correct_limbs *= 2;
101 }
102 result.limbs[LIMBS - 1] &= Self::MASK;
103
104 Some(result)
105 }
106
107 #[inline]
129 #[must_use]
130 #[allow(clippy::similar_names)] pub fn widening_mul<
132 const BITS_RHS: usize,
133 const LIMBS_RHS: usize,
134 const BITS_RES: usize,
135 const LIMBS_RES: usize,
136 >(
137 self,
138 rhs: Uint<BITS_RHS, LIMBS_RHS>,
139 ) -> Uint<BITS_RES, LIMBS_RES> {
140 assert_eq!(BITS_RES, BITS + BITS_RHS);
141 assert_eq!(LIMBS_RES, nlimbs(BITS_RES));
142 let mut result = Uint::<BITS_RES, LIMBS_RES>::ZERO;
143 algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
144 if LIMBS_RES > 0 {
145 debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::<BITS_RES, LIMBS_RES>::MASK);
146 }
147
148 result
149 }
150}
151
152impl<const BITS: usize, const LIMBS: usize> Product<Self> for Uint<BITS, LIMBS> {
153 #[inline]
154 fn product<I>(iter: I) -> Self
155 where
156 I: Iterator<Item = Self>,
157 {
158 if BITS == 0 {
159 return Self::ZERO;
160 }
161 iter.fold(Self::from(1), Self::wrapping_mul)
162 }
163}
164
165impl<'a, const BITS: usize, const LIMBS: usize> Product<&'a Self> for Uint<BITS, LIMBS> {
166 #[inline]
167 fn product<I>(iter: I) -> Self
168 where
169 I: Iterator<Item = &'a Self>,
170 {
171 if BITS == 0 {
172 return Self::ZERO;
173 }
174 iter.copied().fold(Self::from(1), Self::wrapping_mul)
175 }
176}
177
// Wire the `*` / `*=` operators (`Mul` / `MulAssign`) to `wrapping_mul`, so
// operator multiplication silently discards bits above `BITS`. The macro is
// defined elsewhere in the crate; whether it also covers reference operands
// is not visible here — verify against its definition.
impl_bin_op!(Mul, mul, MulAssign, mul_assign, wrapping_mul);
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::const_for;
    use proptest::proptest;

    /// Operand order does not affect the (wrapping) product.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U)| {
                let forward = x * y;
                let backward = y * x;
                assert_eq!(forward, backward);
            });
        });
    }

    /// Grouping does not affect the (wrapping) product.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U, z: U)| {
                let right_grouped = x * (y * z);
                let left_grouped = (x * y) * z;
                assert_eq!(right_grouped, left_grouped);
            });
        });
    }

    /// Multiplication distributes over (wrapping) addition.
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(x: U, y: U, z: U)| {
                let factored = x * (y + z);
                let expanded = (x * y) + (x * z);
                assert_eq!(factored, expanded);
            });
        });
    }

    /// Zero annihilates and one is the multiplicative identity.
    /// Zero-bit widths are excluded because 1 is not representable there.
    #[test]
    fn test_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                let zero = U::from(0);
                assert_eq!(value * zero, U::ZERO);
                let one = U::from(1);
                assert_eq!(value * one, value);
            });
        });
    }

    /// `inv_ring` yields a true inverse modulo 2^BITS and is an involution.
    #[test]
    fn test_inverse() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut odd: U)| {
                // Only odd values are invertible modulo 2^BITS; force the low bit on.
                odd |= U::from(1);
                let inverse = odd.inv_ring().unwrap();
                assert_eq!(odd * inverse, U::from(1));
                assert_eq!(inverse.inv_ring().unwrap(), odd);
            });
        });
    }

    /// `widening_mul` agrees with a multiplication performed at the wide width.
    #[test]
    fn test_widening_mul() {
        const_for!(BITS_LHS in BENCH {
            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
            type Lhs = Uint<BITS_LHS, LIMBS_LHS>;

            const_for!(BITS_RHS in BENCH {
                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
                type Rhs = Uint<BITS_RHS, LIMBS_RHS>;

                const BITS_RES: usize = BITS_LHS + BITS_RHS;
                const LIMBS_RES: usize = nlimbs(BITS_RES);
                type Res = Uint<BITS_RES, LIMBS_RES>;

                proptest!(|(lhs: Lhs, rhs: Rhs)| {
                    let wide_lhs = Res::from(lhs);
                    let wide_rhs = Res::from(rhs);
                    let expected: Res = wide_lhs * wide_rhs;
                    let actual: Res = lhs.widening_mul(rhs);
                    assert_eq!(actual, expected);
                });
            });
        });
    }
}