#![allow(clippy::module_name_repetitions)]
use crate::algorithms::{ops::sbb, DoubleWord};
/// Accumulates the product `a * b` into `lhs` (`lhs += a * b`).
///
/// All three operands are slices of `u64` limbs with the least significant
/// limb first, and they may have any lengths. Limbs of the result that would
/// land beyond the end of `lhs` are dropped.
///
/// Returns `true` when part of the result did not fit in `lhs` (a carry left
/// the top limb, or a partial product extended past the end), and `false`
/// otherwise. The result is always `false` when either factor is zero or
/// empty, in which case `lhs` is left untouched.
#[inline(always)]
pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
    // A zero low limb in `a` only shifts the product up by one limb, so drop
    // it and advance the window into `lhs` accordingly (while there is any
    // window left to advance).
    while let Some((&0, tail)) = a.split_first() {
        a = tail;
        if !lhs.is_empty() {
            lhs = &mut lhs[1..];
        }
    }
    // Zero high limbs contribute nothing at all.
    while let Some((&0, init)) = a.split_last() {
        a = init;
    }

    // Same normalisation for `b`.
    while let Some((&0, tail)) = b.split_first() {
        b = tail;
        if !lhs.is_empty() {
            lhs = &mut lhs[1..];
        }
    }
    while let Some((&0, init)) = b.split_last() {
        b = init;
    }

    // A zero factor leaves `lhs` unchanged and cannot overflow.
    if a.is_empty() || b.is_empty() {
        return false;
    }
    // Both factors are non-zero but there is no room left for the product.
    if lhs.is_empty() {
        return true;
    }

    // Use the longer operand for the inner (row) pass and iterate over the
    // limbs of the shorter one; on equal lengths `a` is the row operand.
    let (long, short) = if b.len() > a.len() { (b, a) } else { (a, b) };

    let mut overflow = false;
    for &factor in short {
        let width = lhs.len();
        if width < long.len() {
            // The row no longer fits: whatever lies past the end is lost.
            overflow = true;
            if width == 0 {
                break;
            }
            addmul_nx1(lhs, &long[..width], factor);
        } else {
            // The row fits: accumulate it, then ripple its carry through the
            // limbs above it.
            let (head, tail) = lhs.split_at_mut(long.len());
            let carry = add_nx1(tail, addmul_nx1(head, long, factor));
            overflow |= carry != 0;
        }
        // The next row starts one limb higher.
        lhs = &mut lhs[1..];
    }
    overflow
}
/// Adds the single limb `a` to the multi-limb number `lhs` in place.
///
/// `lhs` is treated as least-significant-limb-first. The addition ripples
/// upwards only for as long as there is a carry to propagate.
///
/// Returns the carry that is left after the last limb of `lhs` (zero when the
/// sum fit). For an empty `lhs` this is `a` itself.
#[inline(always)]
pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
    for limb in lhs.iter_mut() {
        // Nothing left to add: every higher limb stays as it is.
        if a == 0 {
            break;
        }
        // `a` is reused as the running carry into the next limb.
        (*limb, a) = u128::add(*limb, a).split();
    }
    a
}
/// Computes `lhs += a * b` for three operands of the same limb count, keeping
/// only as many limbs as `lhs` holds.
///
/// Lengths 1 through 4 are routed to dedicated fixed-size routines; longer
/// operands go through the general [`addmul`], whose overflow flag is
/// discarded. An empty input is a no-op.
///
/// # Panics
///
/// Panics if `a` or `b` does not have the same length as `lhs`.
#[inline(always)]
pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    let n = lhs.len();
    assert_eq!(n, a.len());
    assert_eq!(n, b.len());

    if n == 0 {
        return;
    }
    match n {
        1 => addmul_1(lhs, a, b),
        2 => addmul_2(lhs, a, b),
        3 => addmul_3(lhs, a, b),
        4 => addmul_4(lhs, a, b),
        _ => {
            // The result is truncated by design, so the overflow flag is
            // of no interest here.
            let _overflow = addmul(lhs, a, b);
        }
    }
}
/// Single-limb case of [`addmul_n`]: `lhs[0] += a[0] * b[0]`, keeping only
/// the low limb of the result.
///
/// All three slices must hold exactly one limb; this is stated via `assume!`
/// (defined elsewhere in the crate), so callers are responsible for it.
#[inline(always)]
fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    assume!(lhs.len() == 1);
    assume!(a.len() == 1);
    assume!(b.len() == 1);

    let (a0, b0) = (a[0], b[0]);
    // The high limb falls outside the one-limb result and is dropped.
    let _high = mac(&mut lhs[0], a0, b0, 0);
}
/// Two-limb case of [`addmul_n`]: adds the low two limbs of `a * b` to `lhs`.
///
/// All three slices must hold exactly two limbs; this is stated via `assume!`
/// (defined elsewhere in the crate), so callers are responsible for it.
/// Partial products and carries that would land in limb 2 or above are
/// dropped.
#[inline(always)]
fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    assume!(lhs.len() == 2);
    assume!(a.len() == 2);
    assume!(b.len() == 2);

    let (a0, a1) = (a[0], a[1]);
    let (b0, b1) = (b[0], b[1]);

    // Row 0: a0 * (b0, b1), starting at limb 0.
    let carry = mac(&mut lhs[0], a0, b0, 0);
    mac(&mut lhs[1], a0, b1, carry);

    // Row 1: a1 * b0, starting at limb 1.
    mac(&mut lhs[1], a1, b0, 0);
}
/// Three-limb case of [`addmul_n`]: adds the low three limbs of `a * b` to
/// `lhs`.
///
/// All three slices must hold exactly three limbs; this is stated via
/// `assume!` (defined elsewhere in the crate), so callers are responsible for
/// it. Partial products and carries that would land in limb 3 or above are
/// dropped.
#[inline(always)]
fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    assume!(lhs.len() == 3);
    assume!(a.len() == 3);
    assume!(b.len() == 3);

    let (a0, a1, a2) = (a[0], a[1], a[2]);
    let (b0, b1, b2) = (b[0], b[1], b[2]);

    // Row 0: a0 * (b0, b1, b2), starting at limb 0.
    let carry = mac(&mut lhs[0], a0, b0, 0);
    let carry = mac(&mut lhs[1], a0, b1, carry);
    mac(&mut lhs[2], a0, b2, carry);

    // Row 1: a1 * (b0, b1), starting at limb 1.
    let carry = mac(&mut lhs[1], a1, b0, 0);
    mac(&mut lhs[2], a1, b1, carry);

    // Row 2: a2 * b0, starting at limb 2.
    mac(&mut lhs[2], a2, b0, 0);
}
/// Four-limb case of [`addmul_n`]: adds the low four limbs of `a * b` to
/// `lhs`.
///
/// All three slices must hold exactly four limbs; this is stated via
/// `assume!` (defined elsewhere in the crate), so callers are responsible for
/// it. Partial products and carries that would land in limb 4 or above are
/// dropped.
#[inline(always)]
fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    assume!(lhs.len() == 4);
    assume!(a.len() == 4);
    assume!(b.len() == 4);

    let (a0, a1, a2, a3) = (a[0], a[1], a[2], a[3]);
    let (b0, b1, b2, b3) = (b[0], b[1], b[2], b[3]);

    // Row 0: a0 * (b0, b1, b2, b3), starting at limb 0.
    let carry = mac(&mut lhs[0], a0, b0, 0);
    let carry = mac(&mut lhs[1], a0, b1, carry);
    let carry = mac(&mut lhs[2], a0, b2, carry);
    mac(&mut lhs[3], a0, b3, carry);

    // Row 1: a1 * (b0, b1, b2), starting at limb 1.
    let carry = mac(&mut lhs[1], a1, b0, 0);
    let carry = mac(&mut lhs[2], a1, b1, carry);
    mac(&mut lhs[3], a1, b2, carry);

    // Row 2: a2 * (b0, b1), starting at limb 2.
    let carry = mac(&mut lhs[2], a2, b0, 0);
    mac(&mut lhs[3], a2, b1, carry);

    // Row 3: a3 * b0, starting at limb 3.
    mac(&mut lhs[3], a3, b0, 0);
}
/// Multiply-accumulate step on one limb.
///
/// Feeds `a`, `b`, `c` and the current `*lhs` to `u128::muladd2` (from the
/// `DoubleWord` helper, presumably `a * b + c + *lhs` as a double word; the
/// helper is defined outside this file), stores the low limb back into `*lhs`
/// and returns the high limb as the carry for the next position.
#[inline(always)]
fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
    let (low, high) = u128::muladd2(a, b, c, *lhs).split();
    *lhs = low;
    high
}
/// Multiplies the multi-limb number `lhs` by the single limb `a` in place.
///
/// Limbs are processed from index 0 upwards, each step passing its high limb
/// on as the carry into the next one.
///
/// Returns the carry left after the last limb (zero for an empty `lhs`).
#[inline(always)]
pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
    lhs.iter_mut().fold(0_u64, |carry, limb| {
        let (low, high) = u128::muladd(*limb, a, carry).split();
        *limb = low;
        high
    })
}
/// Computes `lhs += a * b`, where `a` is a multi-limb number and `b` a single
/// limb.
///
/// `lhs` and `a` must have the same length; this is stated via `assume!`
/// (defined elsewhere in the crate), so callers are responsible for it.
///
/// Returns the carry left after the last limb (zero for empty inputs).
#[inline(always)]
pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
    assume!(lhs.len() == a.len());
    let mut carry = 0;
    for (limb, &factor) in lhs.iter_mut().zip(a) {
        // factor * b + carry + *limb; low half stays, high half carries on.
        (*limb, carry) = u128::muladd2(factor, b, carry, *limb).split();
    }
    carry
}
/// Computes `lhs -= a * b`, where `a` is a multi-limb number and `b` a single
/// limb.
///
/// `lhs` and `a` must have the same length; this is stated via `assume!`
/// (defined elsewhere in the crate), so callers are responsible for it.
///
/// Two chains run side by side: `carry` is the high limb of the running
/// product `a * b`, and `borrow` is the borrow of the running subtraction.
///
/// Returns `borrow + carry`, i.e. what still has to be taken from the limb
/// just above `lhs` (zero for empty inputs).
#[inline(always)]
pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
    assume!(lhs.len() == a.len());
    let mut carry = 0;
    let mut borrow = 0;
    for (limb, &factor) in lhs.iter_mut().zip(a) {
        // Next limb of the product `a * b`.
        let (product_low, product_high) = u128::muladd(factor, b, carry).split();
        carry = product_high;
        // Take it, and the pending borrow, away from the matching `lhs` limb.
        (*limb, borrow) = sbb(*limb, product_low, borrow);
    }
    borrow + carry
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{collection, num::u64, proptest};

    /// Straightforward reference model for [`addmul`].
    ///
    /// Adds `a * b` to `result` one row per limb of `a`, carrying through a
    /// `u128` accumulator. Every bit that would land beyond the end of
    /// `result` is OR-ed into a flag instead; the function returns whether
    /// that flag ended up non-zero.
    #[allow(clippy::cast_possible_truncation)] // `as u64` deliberately keeps the low limb only.
    fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
        let mut lost = 0;
        for (offset, &a_limb) in a.iter().enumerate() {
            let mut out = result.iter_mut().skip(offset);
            let mut rhs = b.iter();
            let mut acc = 0_u128;
            loop {
                let slot = out.next();
                let factor = rhs.next();
                // The row is finished once both the output and `b` run out.
                if slot.is_none() && factor.is_none() {
                    break;
                }
                if let Some(&b_limb) = factor {
                    acc += u128::from(a_limb) * u128::from(b_limb);
                }
                match slot {
                    Some(slot) => {
                        acc += u128::from(*slot);
                        *slot = acc as u64;
                    }
                    // No room left: remember any bits that were dropped.
                    None => lost |= acc as u64,
                }
                acc >>= 64;
            }
            // Whatever carry is left after the row did not fit either.
            lost |= acc as u64;
        }
        lost != 0
    }

    /// Checks [`addmul`] against the reference model on random inputs of up
    /// to nine limbs each.
    #[test]
    fn test_addmul() {
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in &any_vec, b in &any_vec)| {
            let mut expected = lhs.clone();
            let expected_overflow = addmul_ref(&mut expected, &a, &b);
            let overflow = addmul(&mut lhs, &a, &b);
            assert_eq!(lhs, expected);
            assert_eq!(overflow, expected_overflow);
        });
    }

    /// Runs [`addmul`] on a zeroed accumulator of `expected.len()` limbs and
    /// checks both the overflow flag and the resulting limbs.
    fn test_vals(lhs: &[u64], rhs: &[u64], expected: &[u64], expected_overflow: bool) {
        let mut acc = vec![0; expected.len()];
        let overflow = addmul(&mut acc, lhs, rhs);
        assert_eq!(overflow, expected_overflow);
        assert_eq!(acc, expected);
    }

    /// Empty and single-limb corner cases.
    #[test]
    fn test_empty() {
        // Zero-limb accumulator.
        test_vals(&[], &[], &[], false);
        test_vals(&[], &[1], &[], false);
        test_vals(&[1], &[], &[], false);
        test_vals(&[1], &[1], &[], true);
        // One-limb accumulator.
        test_vals(&[], &[], &[0], false);
        test_vals(&[], &[1], &[0], false);
        test_vals(&[1], &[], &[0], false);
        test_vals(&[1], &[1], &[1], false);
    }

    /// Fixed known-answer vector for [`submul_nx1`].
    #[test]
    fn test_submul_nx1() {
        let mut lhs = [
            15520854688669198950,
            13760048731709406392,
            14363314282014368551,
            13263184899940581802,
        ];
        let a = [
            7955980792890017645,
            6297379555503105007,
            2473663400150304794,
            18362433840513668572,
        ];
        let b = 17275533833223164845;
        let expected_lhs = [
            2427453526388035261,
            7389014268281543265,
            6670181329660292018,
            8411211985208067428,
        ];
        let expected_borrow = 17196576577663999042;

        let borrow = submul_nx1(&mut lhs, &a, b);
        assert_eq!(lhs, expected_lhs);
        assert_eq!(borrow, expected_borrow);
    }
}