1#![allow(clippy::module_name_repetitions)]
2
3use crate::algorithms::{ops::sbb, DoubleWord};
4
5#[inline(always)]
26pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
27 while let [0, rest @ ..] = a {
29 a = rest;
30 if let [_, rest @ ..] = lhs {
31 lhs = rest;
32 }
33 }
34 while let [rest @ .., 0] = a {
35 a = rest;
36 }
37
38 while let [0, rest @ ..] = b {
40 b = rest;
41 if let [_, rest @ ..] = lhs {
42 lhs = rest;
43 }
44 }
45 while let [rest @ .., 0] = b {
46 b = rest;
47 }
48
49 if a.is_empty() || b.is_empty() {
50 return false;
51 }
52 if lhs.is_empty() {
53 return true;
54 }
55
56 let (a, b) = if b.len() > a.len() { (b, a) } else { (a, b) };
57
58 let mut overflow = false;
60 for &b in b {
61 if lhs.len() >= a.len() {
62 let (target, rest) = lhs.split_at_mut(a.len());
63 let carry = addmul_nx1(target, a, b);
64 let carry = add_nx1(rest, carry);
65 overflow |= carry != 0;
66 } else {
67 overflow = true;
68 if lhs.is_empty() {
69 break;
70 }
71 addmul_nx1(lhs, &a[..lhs.len()], b);
72 }
73 lhs = &mut lhs[1..];
74 }
75 overflow
76}
77
78#[inline(always)]
80pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
81 if a == 0 {
82 return 0;
83 }
84 for lhs in lhs {
85 (*lhs, a) = u128::add(*lhs, a).split();
86 if a == 0 {
87 return 0;
88 }
89 }
90 a
91}
92
93#[inline(always)]
99pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
100 assert_eq!(lhs.len(), a.len());
101 assert_eq!(lhs.len(), b.len());
102 match lhs.len() {
103 0 => {}
104 1 => addmul_1(lhs, a, b),
105 2 => addmul_2(lhs, a, b),
106 3 => addmul_3(lhs, a, b),
107 4 => addmul_4(lhs, a, b),
108 _ => _ = addmul(lhs, a, b),
109 }
110}
111
112#[inline(always)]
114fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {
115 assume!(lhs.len() == 1);
116 assume!(a.len() == 1);
117 assume!(b.len() == 1);
118
119 mac(&mut lhs[0], a[0], b[0], 0);
120}
121
122#[inline(always)]
124fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {
125 assume!(lhs.len() == 2);
126 assume!(a.len() == 2);
127 assume!(b.len() == 2);
128
129 let carry = mac(&mut lhs[0], a[0], b[0], 0);
130 mac(&mut lhs[1], a[0], b[1], carry);
131
132 mac(&mut lhs[1], a[1], b[0], 0);
133}
134
135#[inline(always)]
137fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {
138 assume!(lhs.len() == 3);
139 assume!(a.len() == 3);
140 assume!(b.len() == 3);
141
142 let carry = mac(&mut lhs[0], a[0], b[0], 0);
143 let carry = mac(&mut lhs[1], a[0], b[1], carry);
144 mac(&mut lhs[2], a[0], b[2], carry);
145
146 let carry = mac(&mut lhs[1], a[1], b[0], 0);
147 mac(&mut lhs[2], a[1], b[1], carry);
148
149 mac(&mut lhs[2], a[2], b[0], 0);
150}
151
152#[inline(always)]
154fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
155 assume!(lhs.len() == 4);
156 assume!(a.len() == 4);
157 assume!(b.len() == 4);
158
159 let carry = mac(&mut lhs[0], a[0], b[0], 0);
160 let carry = mac(&mut lhs[1], a[0], b[1], carry);
161 let carry = mac(&mut lhs[2], a[0], b[2], carry);
162 mac(&mut lhs[3], a[0], b[3], carry);
163
164 let carry = mac(&mut lhs[1], a[1], b[0], 0);
165 let carry = mac(&mut lhs[2], a[1], b[1], carry);
166 mac(&mut lhs[3], a[1], b[2], carry);
167
168 let carry = mac(&mut lhs[2], a[2], b[0], 0);
169 mac(&mut lhs[3], a[2], b[1], carry);
170
171 mac(&mut lhs[3], a[3], b[0], 0);
172}
173
174#[inline(always)]
175fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
176 let prod = u128::muladd2(a, b, c, *lhs);
177 *lhs = prod.low();
178 prod.high()
179}
180
181#[inline(always)]
183pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
184 let mut carry = 0;
185 for lhs in lhs {
186 (*lhs, carry) = u128::muladd(*lhs, a, carry).split();
187 }
188 carry
189}
190
191#[inline(always)]
202pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
203 assume!(lhs.len() == a.len());
204 let mut carry = 0;
205 for i in 0..a.len() {
206 (lhs[i], carry) = u128::muladd2(a[i], b, carry, lhs[i]).split();
207 }
208 carry
209}
210
211#[inline(always)]
223pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
224 assume!(lhs.len() == a.len());
225 let mut carry = 0;
226 let mut borrow = 0;
227 for i in 0..a.len() {
228 let limb;
230 (limb, carry) = u128::muladd(a[i], b, carry).split();
231
232 (lhs[i], borrow) = sbb(lhs[i], limb, borrow);
234 }
235 borrow + carry
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{collection, num::u64, proptest};

    /// Straightforward reference for `addmul`: accumulates `a * b` into
    /// `result` one limb of `a` at a time using `u128` arithmetic, and
    /// reports whether anything non-zero fell outside `result`.
    #[allow(clippy::cast_possible_truncation)]
    fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
        let mut overflow = 0_u64;
        for (i, &a_limb) in a.iter().enumerate() {
            let mut carry = 0_u128;
            // Walk until both the remaining result limbs and `b` are used up.
            let span = result.len().saturating_sub(i).max(b.len());
            for j in 0..span {
                let product = b
                    .get(j)
                    .map_or(0, |&b_limb| u128::from(a_limb) * u128::from(b_limb));
                match result.get_mut(i + j) {
                    Some(slot) => {
                        carry += u128::from(*slot) + product;
                        *slot = carry as u64;
                    }
                    None => {
                        // Past the end of `result`: any set bit is overflow.
                        carry += product;
                        overflow |= carry as u64;
                    }
                }
                carry >>= 64;
            }
            overflow |= carry as u64;
        }
        overflow != 0
    }

    #[test]
    fn test_addmul() {
        let limbs = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &limbs, a in &limbs, b in &limbs)| {
            // Reference result, computed on a copy.
            let mut expected = lhs.clone();
            let expected_overflow = addmul_ref(&mut expected, &a, &b);

            // Implementation under test.
            let actual_overflow = addmul(&mut lhs, &a, &b);
            assert_eq!(lhs, expected);
            assert_eq!(actual_overflow, expected_overflow);
        });
    }

    /// Runs `addmul` on a zeroed buffer of `expected.len()` limbs and checks
    /// both the overflow flag and the resulting limbs.
    fn test_vals(lhs: &[u64], rhs: &[u64], expected: &[u64], expected_overflow: bool) {
        let mut acc = vec![0; expected.len()];
        let got_overflow = addmul(&mut acc, lhs, rhs);
        assert_eq!(got_overflow, expected_overflow);
        assert_eq!(acc, expected);
    }

    #[test]
    fn test_empty() {
        // (lhs, rhs, expected limbs, expected overflow)
        let cases: &[(&[u64], &[u64], &[u64], bool)] = &[
            // Zero-length result buffer.
            (&[], &[], &[], false),
            (&[], &[1], &[], false),
            (&[1], &[], &[], false),
            (&[1], &[1], &[], true),
            // One-limb result buffer.
            (&[], &[], &[0], false),
            (&[], &[1], &[0], false),
            (&[1], &[], &[0], false),
            (&[1], &[1], &[1], false),
        ];
        for &(lhs, rhs, expected, overflow) in cases {
            test_vals(lhs, rhs, expected, overflow);
        }
    }

    #[test]
    fn test_submul_nx1() {
        let mut acc = [
            15520854688669198950,
            13760048731709406392,
            14363314282014368551,
            13263184899940581802,
        ];
        let factor = [
            7955980792890017645,
            6297379555503105007,
            2473663400150304794,
            18362433840513668572,
        ];
        let multiplier = 17275533833223164845;

        let borrow = submul_nx1(&mut acc, &factor, multiplier);

        assert_eq!(acc, [
            2427453526388035261,
            7389014268281543265,
            6670181329660292018,
            8411211985208067428
        ]);
        assert_eq!(borrow, 17196576577663999042);
    }
}