use super::{borrowing_sub, carrying_add, cmp};
use core::{cmp::Ordering, iter::zip};
/// Montgomery multiplication: returns `a * b * 2^(-64 * N) mod modulus`.
///
/// The operands are little-endian arrays of 64-bit limbs. One limb of `b` is
/// consumed per outer iteration: its product with `a` is accumulated, a
/// multiple of `modulus` is added so the lowest limb becomes zero, and the
/// accumulator is shifted down by one limb.
///
/// Preconditions (only checked in debug builds):
/// - `inv * modulus[0] == -1 (mod 2^64)`;
/// - `a < modulus` and `b < modulus`.
#[inline]
#[must_use]
pub fn mul_redc<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
    debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
    debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);
    debug_assert_eq!(cmp(&b, &modulus), Ordering::Less);

    let mut acc = [0; N];
    let mut top_carry = false;
    for b_limb in b {
        // Lowest column: accumulate the product limb, derive the reduction
        // factor from it, and add `factor * modulus[0]` so the limb cancels.
        let (low, mut product_carry) = carrying_mul_add(a[0], b_limb, acc[0], 0);
        let factor = low.wrapping_mul(inv);
        let (cleared, mut reduce_carry) = carrying_mul_add(modulus[0], factor, low, 0);
        debug_assert_eq!(cleared, 0);

        // Remaining columns: same two multiply-adds, with each result stored
        // one limb lower (the shift by one limb).
        for i in 1..N {
            let (partial, carry_out) = carrying_mul_add(a[i], b_limb, acc[i], product_carry);
            product_carry = carry_out;
            let (limb, carry_out) = carrying_mul_add(modulus[i], factor, partial, reduce_carry);
            reduce_carry = carry_out;
            acc[i - 1] = limb;
        }

        // Fold both carry chains (and the carry of the previous round) into
        // the top limb.
        let (limb, overflow) = carrying_add(product_carry, reduce_carry, top_carry);
        acc[N - 1] = limb;
        if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff {
            top_carry = overflow;
        } else {
            // For smaller top limbs the sum is asserted never to overflow.
            debug_assert!(!overflow);
        }
    }

    // Final conditional subtraction of the modulus.
    reduce1_carry(acc, modulus, top_carry)
}
/// Montgomery squaring: returns `a * a * 2^(-64 * N) mod modulus`.
///
/// Each outer iteration adds the diagonal product `a[i] * a[i]` and the
/// doubled cross products `2 * a[i] * a[j]` for `j > i`, then adds a multiple
/// of `modulus` that zeroes the lowest limb and shifts the accumulator down
/// by one limb.
///
/// Preconditions (only checked in debug builds):
/// - `inv * modulus[0] == -1 (mod 2^64)`;
/// - `a < modulus`.
#[inline]
#[must_use]
// Truncating casts deliberately select the low / high 64-bit halves.
#[allow(clippy::cast_possible_truncation)]
pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] {
    debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX);
    debug_assert_eq!(cmp(&a, &modulus), Ordering::Less);

    let mut acc = [0; N];
    let mut outer_carry: u64 = 0;
    for i in 0..N {
        let a_i = a[i];

        // Diagonal term, followed by the doubled cross terms. Doubling can
        // produce a carry one bit wider than a limb, tracked in `high`.
        let (square_low, mut low) = carrying_mul_add(a_i, a_i, acc[i], 0);
        acc[i] = square_low;
        let mut high = false;
        for j in (i + 1)..N {
            let (limb, low_out, high_out) =
                carrying_double_mul_add(a_i, a[j], acc[j], low, high);
            acc[j] = limb;
            low = low_out;
            high = high_out;
        }

        // Add `factor * modulus`, which clears the lowest limb, and shift the
        // accumulator down by one limb.
        let factor = acc[0].wrapping_mul(inv);
        let (cleared, mut reduce_carry) = carrying_mul_add(factor, modulus[0], acc[0], 0);
        debug_assert_eq!(cleared, 0);
        for j in 1..N {
            let (limb, carry_out) = carrying_mul_add(modulus[j], factor, acc[j], reduce_carry);
            acc[j - 1] = limb;
            reduce_carry = carry_out;
        }

        // Fold all outstanding carries into the top limb.
        if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff {
            let wide = u128::from(outer_carry)
                .wrapping_add(u128::from(low))
                .wrapping_add(u128::from(high) << 64)
                .wrapping_add(u128::from(reduce_carry));
            acc[N - 1] = wide as u64;
            outer_carry = (wide >> 64) as u64;
            debug_assert!(outer_carry <= 2);
        } else {
            // For smaller top limbs the carries are asserted to fit in a
            // single limb without overflow.
            debug_assert!(!high);
            debug_assert_eq!(outer_carry, 0);
            let (limb, overflow) = low.overflowing_add(reduce_carry);
            debug_assert!(!overflow);
            acc[N - 1] = limb;
        }
    }

    // Final conditional subtraction of the modulus.
    debug_assert!(outer_carry <= 1);
    reduce1_carry(acc, modulus, outer_carry > 0)
}
/// Performs at most one subtraction of `modulus` from `value`.
///
/// Returns the (wrapping) difference `value - modulus` when `carry` is set or
/// when the subtraction did not borrow; otherwise returns `value` unchanged.
#[inline]
#[must_use]
// The non-short-circuiting `&` below is deliberate.
#[allow(clippy::needless_bitwise_bool)]
fn reduce1_carry<const N: usize>(value: [u64; N], modulus: [u64; N], carry: bool) -> [u64; N] {
    let (difference, underflow) = sub(value, modulus);
    // Keep the input only when there is no incoming carry and the
    // subtraction borrowed.
    let keep_value = !carry & underflow;
    if keep_value {
        value
    } else {
        difference
    }
}
/// Limb-wise subtraction `lhs - rhs` using `borrowing_sub`, threading the
/// borrow from limb 0 upward.
///
/// Returns the resulting limbs together with the borrow left over after the
/// last limb.
#[inline]
#[must_use]
fn sub<const N: usize>(lhs: [u64; N], rhs: [u64; N]) -> ([u64; N], bool) {
    let mut difference = [0; N];
    let mut borrow = false;
    for (index, (minuend, subtrahend)) in zip(lhs, rhs).enumerate() {
        let (limb, borrow_out) = borrowing_sub(minuend, subtrahend, borrow);
        difference[index] = limb;
        borrow = borrow_out;
    }
    (difference, borrow)
}
/// Computes `lhs * rhs + add + carry` as a 128-bit value and returns it split
/// into `(low 64 bits, high 64 bits)`.
#[inline]
#[must_use]
// Truncating casts deliberately select the low / high 64-bit halves.
#[allow(clippy::cast_possible_truncation)]
const fn carrying_mul_add(lhs: u64, rhs: u64, add: u64, carry: u64) -> (u64, u64) {
    let product = (lhs as u128).wrapping_mul(rhs as u128);
    let total = product
        .wrapping_add(add as u128)
        .wrapping_add(carry as u128);
    let low = total as u64;
    let high = (total >> 64) as u64;
    (low, high)
}
/// Computes `2 * lhs * rhs + add + carry_lo + (carry_hi << 64)`.
///
/// Returns `(low 64 bits, next 64 bits, overflow)`, where `overflow` reports
/// whether the value did not fit in 128 bits.
#[inline]
#[must_use]
// Truncating casts deliberately select the low / high 64-bit halves.
#[allow(clippy::cast_possible_truncation)]
const fn carrying_double_mul_add(
    lhs: u64,
    rhs: u64,
    add: u64,
    carry_lo: u64,
    carry_hi: bool,
) -> (u64, u64, bool) {
    let product = (lhs as u128).wrapping_mul(rhs as u128);

    // Doubling is a left shift by one; the bit shifted out is the overflow.
    let doubling_overflow = (product >> 127) != 0;
    let doubled = product << 1;

    // Incoming carries: a full limb plus one extra bit at position 64.
    let incoming = ((carry_hi as u128) << 64)
        .wrapping_add(carry_lo as u128)
        .wrapping_add(add as u128);
    let (total, addition_overflow) = doubled.overflowing_add(incoming);

    let low = total as u64;
    let high = (total >> 64) as u64;
    (low, high, doubling_overflow | addition_overflow)
}
#[cfg(test)]
mod test {
    use core::ops::Neg;
    use super::{
        super::{addmul, div},
        *,
    };
    use crate::{aliases::U64, const_for, nlimbs, Uint};
    use proptest::{prop_assert_eq, proptest};

    /// Reference value for `a * b mod modulus`, built from `addmul` and `div`.
    fn modmul<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N]) -> [u64; N] {
        // Double-width buffer receiving the full product.
        let mut wide = vec![0; 2 * N];
        addmul(&mut wide, &a, &b);
        // NOTE(review): `div` presumably leaves the remainder in its second
        // argument; confirm against its documentation.
        let mut remainder = modulus;
        div(&mut wide, &mut remainder);
        remainder
    }

    /// Reference value for `a * 2^(64 * N) mod modulus`: `a` is placed in the
    /// upper half of a double-width buffer, which is then reduced with `div`.
    fn mul_base<const N: usize>(a: [u64; N], modulus: [u64; N]) -> [u64; N] {
        let mut shifted = vec![0; 2 * N];
        shifted[N..].copy_from_slice(&a);
        let mut remainder = modulus;
        div(&mut shifted, &mut remainder);
        remainder
    }

    /// Negated ring inverse of `limb` as a single `u64`, the `inv` argument
    /// expected by `mul_redc` and `square_redc`.
    fn neg_inv(limb: u64) -> u64 {
        U64::from(limb).inv_ring().unwrap().neg().as_limbs()[0]
    }

    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut b: U, mut m: U)| {
                // Force an odd modulus, then bring both operands below it.
                m |= U::from(1_u64);
                a %= m;
                b %= m;

                let a_limbs = *a.as_limbs();
                let b_limbs = *b.as_limbs();
                let m_limbs = *m.as_limbs();
                let inv = neg_inv(m_limbs[0]);

                // Multiplying the output by the base must give the plain
                // modular product.
                let result = mul_base(mul_redc(a_limbs, b_limbs, m_limbs, inv), m_limbs);
                let expected = modmul(a_limbs, b_limbs, m_limbs);
                prop_assert_eq!(result, expected);
            });
        });
    }

    #[test]
    fn test_square_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U, mut m: U)| {
                // Force an odd modulus, then bring the operand below it.
                m |= U::from(1_u64);
                a %= m;

                let a_limbs = *a.as_limbs();
                let m_limbs = *m.as_limbs();
                let inv = neg_inv(m_limbs[0]);

                // Multiplying the output by the base must give the plain
                // modular square.
                let result = mul_base(square_redc(a_limbs, m_limbs, inv), m_limbs);
                let expected = modmul(a_limbs, a_limbs, m_limbs);
                prop_assert_eq!(result, expected);
            });
        });
    }
}