ruint/algorithms/
mul.rs

1#![allow(clippy::module_name_repetitions)]
2
3use crate::algorithms::{ops::sbb, DoubleWord};
4
5/// ⚠️ Computes `result += a * b` and checks for overflow.
6///
7/// **Warning.** This function is not part of the stable API.
8///
9/// Arrays are in little-endian order. All arrays can be arbitrary sized.
10///
11/// # Algorithm
12///
13/// Trims zeros from inputs, then uses the schoolbook multiplication algorithm.
14/// It takes the shortest input as the outer loop.
15///
16/// # Examples
17///
18/// ```
19/// # use ruint::algorithms::addmul;
20/// let mut result = [0];
21/// let overflow = addmul(&mut result, &[3], &[4]);
22/// assert_eq!(overflow, false);
23/// assert_eq!(result, [12]);
24/// ```
25#[inline(always)]
26pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
27    // Trim zeros from `a`
28    while let [0, rest @ ..] = a {
29        a = rest;
30        if let [_, rest @ ..] = lhs {
31            lhs = rest;
32        }
33    }
34    while let [rest @ .., 0] = a {
35        a = rest;
36    }
37
38    // Trim zeros from `b`
39    while let [0, rest @ ..] = b {
40        b = rest;
41        if let [_, rest @ ..] = lhs {
42            lhs = rest;
43        }
44    }
45    while let [rest @ .., 0] = b {
46        b = rest;
47    }
48
49    if a.is_empty() || b.is_empty() {
50        return false;
51    }
52    if lhs.is_empty() {
53        return true;
54    }
55
56    let (a, b) = if b.len() > a.len() { (b, a) } else { (a, b) };
57
58    // Iterate over limbs of `b` and add partial products to `lhs`.
59    let mut overflow = false;
60    for &b in b {
61        if lhs.len() >= a.len() {
62            let (target, rest) = lhs.split_at_mut(a.len());
63            let carry = addmul_nx1(target, a, b);
64            let carry = add_nx1(rest, carry);
65            overflow |= carry != 0;
66        } else {
67            overflow = true;
68            if lhs.is_empty() {
69                break;
70            }
71            addmul_nx1(lhs, &a[..lhs.len()], b);
72        }
73        lhs = &mut lhs[1..];
74    }
75    overflow
76}
77
78/// Computes `lhs += a` and returns the carry.
79#[inline(always)]
80pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
81    if a == 0 {
82        return 0;
83    }
84    for lhs in lhs {
85        (*lhs, a) = u128::add(*lhs, a).split();
86        if a == 0 {
87            return 0;
88        }
89    }
90    a
91}
92
93/// Computes wrapping `lhs += a * b` when all arguments are the same length.
94///
95/// # Panics
96///
97/// Panics if the lengts are not the same.
98#[inline(always)]
99pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
100    assert_eq!(lhs.len(), a.len());
101    assert_eq!(lhs.len(), b.len());
102    match lhs.len() {
103        0 => {}
104        1 => addmul_1(lhs, a, b),
105        2 => addmul_2(lhs, a, b),
106        3 => addmul_3(lhs, a, b),
107        4 => addmul_4(lhs, a, b),
108        _ => _ = addmul(lhs, a, b),
109    }
110}
111
112/// Computes `lhs += a * b` for 1 limb.
113#[inline(always)]
114fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {
115    assume!(lhs.len() == 1);
116    assume!(a.len() == 1);
117    assume!(b.len() == 1);
118
119    mac(&mut lhs[0], a[0], b[0], 0);
120}
121
122/// Computes `lhs += a * b` for 2 limbs.
123#[inline(always)]
124fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {
125    assume!(lhs.len() == 2);
126    assume!(a.len() == 2);
127    assume!(b.len() == 2);
128
129    let carry = mac(&mut lhs[0], a[0], b[0], 0);
130    mac(&mut lhs[1], a[0], b[1], carry);
131
132    mac(&mut lhs[1], a[1], b[0], 0);
133}
134
135/// Computes `lhs += a * b` for 3 limbs.
136#[inline(always)]
137fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {
138    assume!(lhs.len() == 3);
139    assume!(a.len() == 3);
140    assume!(b.len() == 3);
141
142    let carry = mac(&mut lhs[0], a[0], b[0], 0);
143    let carry = mac(&mut lhs[1], a[0], b[1], carry);
144    mac(&mut lhs[2], a[0], b[2], carry);
145
146    let carry = mac(&mut lhs[1], a[1], b[0], 0);
147    mac(&mut lhs[2], a[1], b[1], carry);
148
149    mac(&mut lhs[2], a[2], b[0], 0);
150}
151
152/// Computes `lhs += a * b` for 4 limbs.
153#[inline(always)]
154fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
155    assume!(lhs.len() == 4);
156    assume!(a.len() == 4);
157    assume!(b.len() == 4);
158
159    let carry = mac(&mut lhs[0], a[0], b[0], 0);
160    let carry = mac(&mut lhs[1], a[0], b[1], carry);
161    let carry = mac(&mut lhs[2], a[0], b[2], carry);
162    mac(&mut lhs[3], a[0], b[3], carry);
163
164    let carry = mac(&mut lhs[1], a[1], b[0], 0);
165    let carry = mac(&mut lhs[2], a[1], b[1], carry);
166    mac(&mut lhs[3], a[1], b[2], carry);
167
168    let carry = mac(&mut lhs[2], a[2], b[0], 0);
169    mac(&mut lhs[3], a[2], b[1], carry);
170
171    mac(&mut lhs[3], a[3], b[0], 0);
172}
173
174#[inline(always)]
175fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
176    let prod = u128::muladd2(a, b, c, *lhs);
177    *lhs = prod.low();
178    prod.high()
179}
180
181/// Computes `lhs *= a` and returns the carry.
182#[inline(always)]
183pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
184    let mut carry = 0;
185    for lhs in lhs {
186        (*lhs, carry) = u128::muladd(*lhs, a, carry).split();
187    }
188    carry
189}
190
191/// Computes `lhs += a * b` and returns the carry.
192///
193/// Requires `lhs.len() == a.len()`.
194///
195/// $$
196/// \begin{aligned}
197/// \mathsf{lhs'} &= \mod{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}}_{2^{64⋅N}}
198/// \\\\ \mathsf{carry} &= \floor{\frac{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}
199/// }{2^{64⋅N}}} \end{aligned}
200/// $$
201#[inline(always)]
202pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
203    assume!(lhs.len() == a.len());
204    let mut carry = 0;
205    for i in 0..a.len() {
206        (lhs[i], carry) = u128::muladd2(a[i], b, carry, lhs[i]).split();
207    }
208    carry
209}
210
211/// Computes `lhs -= a * b` and returns the borrow.
212///
213/// Requires `lhs.len() == a.len()`.
214///
215/// $$
216/// \begin{aligned}
217/// \mathsf{lhs'} &= \mod{\mathsf{lhs} - \mathsf{a} ⋅ \mathsf{b}}_{2^{64⋅N}}
218/// \\\\ \mathsf{borrow} &= \floor{\frac{\mathsf{a} ⋅ \mathsf{b} -
219/// \mathsf{lhs}}{2^{64⋅N}}} \end{aligned}
220/// $$
221// OPT: `carry` and `borrow` can probably be merged into a single var.
222#[inline(always)]
223pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
224    assume!(lhs.len() == a.len());
225    let mut carry = 0;
226    let mut borrow = 0;
227    for i in 0..a.len() {
228        // Compute product limbs
229        let limb;
230        (limb, carry) = u128::muladd(a[i], b, carry).split();
231
232        // Subtract
233        (lhs[i], borrow) = sbb(lhs[i], limb, borrow);
234    }
235    borrow + carry
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{collection, num::u64, proptest};

    /// Straightforward model of [`addmul`], used as the oracle for the
    /// property test: one row per limb of `a`, accumulated in a `u128`.
    ///
    /// Returns `true` if any non-zero bits fall beyond the end of `result`.
    #[allow(clippy::cast_possible_truncation)] // Intentional truncation.
    fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
        // OR of every limb that could not be stored.
        let mut dropped = 0;
        for (offset, &a_limb) in a.iter().enumerate() {
            // Limbs of `result` at or above this row's offset (may be none).
            let start = offset.min(result.len());
            let window = &mut result[start..];

            // Split both sides at the shorter length; at most one of
            // `beyond` and `b_excess` is non-empty.
            let overlap = window.len().min(b.len());
            let (paired, beyond) = window.split_at_mut(overlap);
            let (b_paired, b_excess) = b.split_at(overlap);

            let mut carry = 0_u128;

            // Partial product.
            for (limb, &b_limb) in paired.iter_mut().zip(b_paired) {
                carry += u128::from(*limb) + u128::from(a_limb) * u128::from(b_limb);
                *limb = carry as u64;
                carry >>= 64;
            }

            // Carry propagation.
            for limb in beyond {
                carry += u128::from(*limb);
                *limb = carry as u64;
                carry >>= 64;
            }

            // Excess product.
            for &b_limb in b_excess {
                carry += u128::from(a_limb) * u128::from(b_limb);
                dropped |= carry as u64;
                carry >>= 64;
            }

            // Whatever is still carried has nowhere to go.
            dropped |= carry as u64;
        }
        dropped != 0
    }

    /// Compares [`addmul`] against the oracle on random inputs of 0 to 9 limbs.
    #[test]
    fn test_addmul() {
        let limbs = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &limbs, a in &limbs, b in &limbs)| {
            // Oracle, run on a private copy of the accumulator.
            let mut expected = lhs.clone();
            let expected_overflow = addmul_ref(&mut expected, &a, &b);

            // Implementation under test.
            let overflow = addmul(&mut lhs, &a, &b);
            assert_eq!(lhs, expected);
            assert_eq!(overflow, expected_overflow);
        });
    }

    /// Runs `lhs * rhs` into a zeroed accumulator of `expected.len()` limbs
    /// and checks both the overflow flag and the limbs.
    fn test_vals(lhs: &[u64], rhs: &[u64], expected: &[u64], expected_overflow: bool) {
        let mut accumulator = vec![0; expected.len()];
        let overflow = addmul(&mut accumulator, lhs, rhs);
        assert_eq!(overflow, expected_overflow);
        assert_eq!(accumulator, expected);
    }

    /// Empty and single-limb corner cases.
    #[test]
    fn test_empty() {
        // (lhs, rhs, expected limbs, expected overflow)
        let cases: [(&[u64], &[u64], &[u64], bool); 8] = [
            (&[], &[], &[], false),
            (&[], &[1], &[], false),
            (&[1], &[], &[], false),
            (&[1], &[1], &[], true),
            (&[], &[], &[0], false),
            (&[], &[1], &[0], false),
            (&[1], &[], &[0], false),
            (&[1], &[1], &[1], false),
        ];
        for &(lhs, rhs, expected, expected_overflow) in &cases {
            test_vals(lhs, rhs, expected, expected_overflow);
        }
    }

    /// Fixed regression vector for [`submul_nx1`] on four limbs.
    #[test]
    fn test_submul_nx1() {
        let mut lhs = [
            15520854688669198950,
            13760048731709406392,
            14363314282014368551,
            13263184899940581802,
        ];
        let a = [
            7955980792890017645,
            6297379555503105007,
            2473663400150304794,
            18362433840513668572,
        ];
        let b = 17275533833223164845;
        let expected = [
            2427453526388035261,
            7389014268281543265,
            6670181329660292018,
            8411211985208067428,
        ];

        let borrow = submul_nx1(&mut lhs, &a, b);
        assert_eq!(lhs, expected);
        assert_eq!(borrow, 17196576577663999042);
    }
}