ruint/
mul.rs

1use crate::{algorithms, nlimbs, Uint};
2use core::{
3    iter::Product,
4    num::Wrapping,
5    ops::{Mul, MulAssign},
6};
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9    /// Computes `self * rhs`, returning [`None`] if overflow occurred.
10    #[inline(always)]
11    #[must_use]
12    pub fn checked_mul(self, rhs: Self) -> Option<Self> {
13        match self.overflowing_mul(rhs) {
14            (value, false) => Some(value),
15            _ => None,
16        }
17    }
18
19    /// Calculates the multiplication of self and rhs.
20    ///
21    /// Returns a tuple of the multiplication along with a boolean indicating
22    /// whether an arithmetic overflow would occur. If an overflow would have
23    /// occurred then the wrapped value is returned.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// # use ruint::{Uint, uint};
29    /// # uint!{
30    /// assert_eq!(1_U1.overflowing_mul(1_U1), (1_U1, false));
31    /// assert_eq!(
32    ///     0x010000000000000000_U65.overflowing_mul(0x010000000000000000_U65),
33    ///     (0x000000000000000000_U65, true)
34    /// );
35    /// # }
36    /// ```
37    #[inline]
38    #[must_use]
39    pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
40        let mut result = Self::ZERO;
41        let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
42        if BITS > 0 {
43            overflow |= result.limbs[LIMBS - 1] > Self::MASK;
44            result.limbs[LIMBS - 1] &= Self::MASK;
45        }
46        (result, overflow)
47    }
48
49    /// Computes `self * rhs`, saturating at the numeric bounds instead of
50    /// overflowing.
51    #[inline(always)]
52    #[must_use]
53    pub fn saturating_mul(self, rhs: Self) -> Self {
54        match self.overflowing_mul(rhs) {
55            (value, false) => value,
56            _ => Self::MAX,
57        }
58    }
59
60    /// Computes `self * rhs`, wrapping around at the boundary of the type.
61    #[inline(always)]
62    #[must_use]
63    pub fn wrapping_mul(self, rhs: Self) -> Self {
64        let mut result = Self::ZERO;
65        algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
66        if BITS > 0 {
67            result.limbs[LIMBS - 1] &= Self::MASK;
68        }
69        result
70    }
71
72    /// Computes the inverse modulo $2^{\mathtt{BITS}}$ of `self`, returning
73    /// [`None`] if the inverse does not exist.
74    #[inline]
75    #[must_use]
76    pub fn inv_ring(self) -> Option<Self> {
77        if BITS == 0 || self.limbs[0] & 1 == 0 {
78            return None;
79        }
80
81        // Compute inverse of first limb
82        let mut result = Self::ZERO;
83        result.limbs[0] = {
84            const W2: Wrapping<u64> = Wrapping(2);
85            const W3: Wrapping<u64> = Wrapping(3);
86            let n = Wrapping(self.limbs[0]);
87            let mut inv = (n * W3) ^ W2; // Correct on 4 bits.
88            inv *= W2 - n * inv; // Correct on 8 bits.
89            inv *= W2 - n * inv; // Correct on 16 bits.
90            inv *= W2 - n * inv; // Correct on 32 bits.
91            inv *= W2 - n * inv; // Correct on 64 bits.
92            debug_assert_eq!(n.0.wrapping_mul(inv.0), 1);
93            inv.0
94        };
95
96        // Continue with rest of limbs
97        let mut correct_limbs = 1;
98        while correct_limbs < LIMBS {
99            result *= Self::from(2) - self * result;
100            correct_limbs *= 2;
101        }
102        result.limbs[LIMBS - 1] &= Self::MASK;
103
104        Some(result)
105    }
106
107    /// Calculates the complete product `self * rhs` without the possibility to
108    /// overflow.
109    ///
110    /// The argument `rhs` can be any size [`Uint`], the result size is the sum
111    /// of the bit-sizes of `self` and `rhs`.
112    ///
113    /// # Panics
114    ///
115    /// This function will runtime panic of the const generic arguments are
116    /// incorrect.
117    ///
118    /// # Examples
119    ///
120    /// ```
121    /// # use ruint::{Uint, uint};
122    /// # uint!{
123    /// assert_eq!(0_U0.widening_mul(0_U0), 0_U0);
124    /// assert_eq!(1_U1.widening_mul(1_U1), 1_U2);
125    /// assert_eq!(3_U2.widening_mul(7_U3), 21_U5);
126    /// # }
127    /// ```
128    #[inline]
129    #[must_use]
130    #[allow(clippy::similar_names)] // Don't confuse `res` and `rhs`.
131    pub fn widening_mul<
132        const BITS_RHS: usize,
133        const LIMBS_RHS: usize,
134        const BITS_RES: usize,
135        const LIMBS_RES: usize,
136    >(
137        self,
138        rhs: Uint<BITS_RHS, LIMBS_RHS>,
139    ) -> Uint<BITS_RES, LIMBS_RES> {
140        assert_eq!(BITS_RES, BITS + BITS_RHS);
141        assert_eq!(LIMBS_RES, nlimbs(BITS_RES));
142        let mut result = Uint::<BITS_RES, LIMBS_RES>::ZERO;
143        algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
144        if LIMBS_RES > 0 {
145            debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::<BITS_RES, LIMBS_RES>::MASK);
146        }
147
148        result
149    }
150}
151
152impl<const BITS: usize, const LIMBS: usize> Product<Self> for Uint<BITS, LIMBS> {
153    #[inline]
154    fn product<I>(iter: I) -> Self
155    where
156        I: Iterator<Item = Self>,
157    {
158        if BITS == 0 {
159            return Self::ZERO;
160        }
161        iter.fold(Self::from(1), Self::wrapping_mul)
162    }
163}
164
165impl<'a, const BITS: usize, const LIMBS: usize> Product<&'a Self> for Uint<BITS, LIMBS> {
166    #[inline]
167    fn product<I>(iter: I) -> Self
168    where
169        I: Iterator<Item = &'a Self>,
170    {
171        if BITS == 0 {
172            return Self::ZERO;
173        }
174        iter.copied().fold(Self::from(1), Self::wrapping_mul)
175    }
176}
177
// Operator impls for `*` / `*=`. `impl_bin_op!` is defined elsewhere in the
// crate. Since it is handed `wrapping_mul`, the generated `Mul`/`MulAssign`
// impls presumably delegate to it and wrap on overflow. Confirm against the
// macro definition.
impl_bin_op!(Mul, mul, MulAssign, mul_assign, wrapping_mul);
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::const_for;
    use proptest::proptest;

    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                assert_eq!(a * b, b * a);
            });
        });
    }

    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b * c), (a * b) * c);
            });
        });
    }

    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b + c), (a * b) + (a * c));
            });
        });
    }

    #[test]
    fn test_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                assert_eq!(value * U::from(0), U::ZERO);
                assert_eq!(value * U::from(1), value);
            });
        });
    }

    #[test]
    fn test_inverse() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U)| {
                a |= U::from(1); // Make sure a is invertible
                assert_eq!(a * a.inv_ring().unwrap(), U::from(1));
                assert_eq!(a.inv_ring().unwrap().inv_ring().unwrap(), a);
            });
        });
    }

    #[test]
    fn test_overflowing_mul() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                // `overflowing_mul` is built on `algorithms::addmul` while
                // `wrapping_mul` uses `algorithms::addmul_n`. The wrapped
                // values of the two code paths must agree.
                let (value, overflow) = a.overflowing_mul(b);
                assert_eq!(value, a.wrapping_mul(b));
                // The checked and saturating variants must follow the
                // overflow flag.
                assert_eq!(a.checked_mul(b), if overflow { None } else { Some(value) });
                assert_eq!(a.saturating_mul(b), if overflow { U::MAX } else { value });
            });
        });
    }

    #[test]
    fn test_widening_mul() {
        // Left hand side
        const_for!(BITS_LHS in BENCH {
            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
            type Lhs = Uint<BITS_LHS, LIMBS_LHS>;

            // Right hand side
            const_for!(BITS_RHS in BENCH {
                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
                type Rhs = Uint<BITS_RHS, LIMBS_RHS>;

                // Result
                const BITS_RES: usize = BITS_LHS + BITS_RHS;
                const LIMBS_RES: usize = nlimbs(BITS_RES);
                type Res = Uint<BITS_RES, LIMBS_RES>;

                proptest!(|(lhs: Lhs, rhs: Rhs)| {
                    // Compute the result using the target size
                    let expected = Res::from(lhs) * Res::from(rhs);
                    assert_eq!(lhs.widening_mul(rhs), expected);
                });
            });
        });
    }
}