ruint/
mul.rs

1use crate::{algorithms, nlimbs, Uint};
2use core::{
3    iter::Product,
4    num::Wrapping,
5    ops::{Mul, MulAssign},
6};
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9    /// Computes `self * rhs`, returning [`None`] if overflow occurred.
10    #[inline(always)]
11    #[must_use]
12    pub fn checked_mul(self, rhs: Self) -> Option<Self> {
13        match self.overflowing_mul(rhs) {
14            (value, false) => Some(value),
15            _ => None,
16        }
17    }
18
19    /// Calculates the multiplication of self and rhs.
20    ///
21    /// Returns a tuple of the multiplication along with a boolean indicating
22    /// whether an arithmetic overflow would occur. If an overflow would have
23    /// occurred then the wrapped value is returned.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// # use ruint::{Uint, uint};
29    /// # uint!{
30    /// assert_eq!(1_U1.overflowing_mul(1_U1), (1_U1, false));
31    /// assert_eq!(
32    ///     0x010000000000000000_U65.overflowing_mul(0x010000000000000000_U65),
33    ///     (0x000000000000000000_U65, true)
34    /// );
35    /// # }
36    /// ```
37    #[inline]
38    #[must_use]
39    pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
40        let mut result = Self::ZERO;
41        let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
42        if BITS > 0 {
43            overflow |= result.limbs[LIMBS - 1] > Self::MASK;
44            result.apply_mask();
45        }
46        (result, overflow)
47    }
48
49    /// Computes `self * rhs`, saturating at the numeric bounds instead of
50    /// overflowing.
51    #[inline(always)]
52    #[must_use]
53    pub fn saturating_mul(self, rhs: Self) -> Self {
54        match self.overflowing_mul(rhs) {
55            (value, false) => value,
56            _ => Self::MAX,
57        }
58    }
59
60    /// Computes `self * rhs`, wrapping around at the boundary of the type.
61    #[cfg(not(target_os = "zkvm"))]
62    #[inline(always)]
63    #[must_use]
64    pub fn wrapping_mul(self, rhs: Self) -> Self {
65        let mut result = Self::ZERO;
66        algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
67        if BITS > 0 {
68            result.apply_mask();
69        }
70        result
71    }
72
73    /// Computes `self * rhs`, wrapping around at the boundary of the type.
74    #[cfg(target_os = "zkvm")]
75    #[inline(always)]
76    #[must_use]
77    pub fn wrapping_mul(mut self, rhs: Self) -> Self {
78        use crate::support::zkvm::zkvm_u256_wrapping_mul_impl;
79        if BITS == 256 {
80            unsafe {
81                zkvm_u256_wrapping_mul_impl(
82                    self.limbs.as_mut_ptr() as *mut u8,
83                    self.limbs.as_ptr() as *const u8,
84                    rhs.limbs.as_ptr() as *const u8,
85                );
86            }
87            return self;
88        }
89        self.overflowing_mul(rhs).0
90    }
91
92    /// Computes the inverse modulo $2^{\mathtt{BITS}}$ of `self`, returning
93    /// [`None`] if the inverse does not exist.
94    #[inline]
95    #[must_use]
96    pub fn inv_ring(self) -> Option<Self> {
97        if BITS == 0 || self.limbs[0] & 1 == 0 {
98            return None;
99        }
100
101        // Compute inverse of first limb
102        let mut result = Self::ZERO;
103        result.limbs[0] = {
104            const W2: Wrapping<u64> = Wrapping(2);
105            const W3: Wrapping<u64> = Wrapping(3);
106            let n = Wrapping(self.limbs[0]);
107            let mut inv = (n * W3) ^ W2; // Correct on 4 bits.
108            inv *= W2 - n * inv; // Correct on 8 bits.
109            inv *= W2 - n * inv; // Correct on 16 bits.
110            inv *= W2 - n * inv; // Correct on 32 bits.
111            inv *= W2 - n * inv; // Correct on 64 bits.
112            debug_assert_eq!(n.0.wrapping_mul(inv.0), 1);
113            inv.0
114        };
115
116        // Continue with rest of limbs
117        let mut correct_limbs = 1;
118        while correct_limbs < LIMBS {
119            result *= Self::from(2) - self * result;
120            correct_limbs *= 2;
121        }
122        result.apply_mask();
123
124        Some(result)
125    }
126
127    /// Calculates the complete product `self * rhs` without the possibility to
128    /// overflow.
129    ///
130    /// The argument `rhs` can be any size [`Uint`], the result size is the sum
131    /// of the bit-sizes of `self` and `rhs`.
132    ///
133    /// # Panics
134    ///
135    /// This function will runtime panic of the const generic arguments are
136    /// incorrect.
137    ///
138    /// # Examples
139    ///
140    /// ```
141    /// # use ruint::{Uint, uint};
142    /// # uint!{
143    /// assert_eq!(0_U0.widening_mul(0_U0), 0_U0);
144    /// assert_eq!(1_U1.widening_mul(1_U1), 1_U2);
145    /// assert_eq!(3_U2.widening_mul(7_U3), 21_U5);
146    /// # }
147    /// ```
148    #[inline]
149    #[must_use]
150    #[allow(clippy::similar_names)] // Don't confuse `res` and `rhs`.
151    pub fn widening_mul<
152        const BITS_RHS: usize,
153        const LIMBS_RHS: usize,
154        const BITS_RES: usize,
155        const LIMBS_RES: usize,
156    >(
157        self,
158        rhs: Uint<BITS_RHS, LIMBS_RHS>,
159    ) -> Uint<BITS_RES, LIMBS_RES> {
160        assert_eq!(BITS_RES, BITS + BITS_RHS);
161        assert_eq!(LIMBS_RES, nlimbs(BITS_RES));
162        let mut result = Uint::<BITS_RES, LIMBS_RES>::ZERO;
163        algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
164        if LIMBS_RES > 0 {
165            debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::<BITS_RES, LIMBS_RES>::MASK);
166        }
167
168        result
169    }
170}
171
172impl<const BITS: usize, const LIMBS: usize> Product<Self> for Uint<BITS, LIMBS> {
173    #[inline]
174    fn product<I>(iter: I) -> Self
175    where
176        I: Iterator<Item = Self>,
177    {
178        if BITS == 0 {
179            return Self::ZERO;
180        }
181        iter.fold(Self::ONE, Self::wrapping_mul)
182    }
183}
184
185impl<'a, const BITS: usize, const LIMBS: usize> Product<&'a Self> for Uint<BITS, LIMBS> {
186    #[inline]
187    fn product<I>(iter: I) -> Self
188    where
189        I: Iterator<Item = &'a Self>,
190    {
191        if BITS == 0 {
192            return Self::ZERO;
193        }
194        iter.copied().fold(Self::ONE, Self::wrapping_mul)
195    }
196}
197
198impl_bin_op!(Mul, mul, MulAssign, mul_assign, wrapping_mul);
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::const_for;
    use proptest::proptest;

    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                assert_eq!(a * b, b * a);
            });
        });
    }

    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b * c), (a * b) * c);
            });
        });
    }

    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b + c), (a * b) + (a * c));
            });
        });
    }

    #[test]
    fn test_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                assert_eq!(value * U::from(0), U::ZERO);
                assert_eq!(value * U::from(1), value);
            });
        });
    }

    /// The overflow-aware variants must agree with each other: the wrapped
    /// value equals `wrapping_mul`, and `checked_mul` / `saturating_mul`
    /// follow the overflow flag.
    #[test]
    fn test_overflowing_consistency() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                let (wrapped, overflow) = a.overflowing_mul(b);
                assert_eq!(wrapped, a.wrapping_mul(b));
                assert_eq!(a.checked_mul(b), if overflow { None } else { Some(wrapped) });
                assert_eq!(a.saturating_mul(b), if overflow { U::MAX } else { wrapped });
            });
        });
    }

    #[test]
    fn test_inverse() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U)| {
                a |= U::from(1); // Make sure a is invertible
                assert_eq!(a * a.inv_ring().unwrap(), U::from(1));
                assert_eq!(a.inv_ring().unwrap().inv_ring().unwrap(), a);
            });
        });
    }

    #[test]
    fn test_widening_mul() {
        // Left hand side
        const_for!(BITS_LHS in BENCH {
            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
            type Lhs = Uint<BITS_LHS, LIMBS_LHS>;

            // Right hand side
            const_for!(BITS_RHS in BENCH {
                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
                type Rhs = Uint<BITS_RHS, LIMBS_RHS>;

                // Result
                const BITS_RES: usize = BITS_LHS + BITS_RHS;
                const LIMBS_RES: usize = nlimbs(BITS_RES);
                type Res = Uint<BITS_RES, LIMBS_RES>;

                proptest!(|(lhs: Lhs, rhs: Rhs)| {
                    // Compute the result using the target size
                    let expected = Res::from(lhs) * Res::from(rhs);
                    assert_eq!(lhs.widening_mul(rhs), expected);
                });
            });
        });
    }
}