// ark_ff_macros/montgomery/sum_of_products.rs
use quote::quote;
/// Builds the token stream for the body of a `sum_of_products` implementation,
/// specialised to a modulus given as `num_limbs` 64-bit limbs.
///
/// The emitted tokens are not self-contained: they refer to `a`, `b`, `M`, `F`,
/// `fa`, `BigInt`, `Self::INV`, `Self::sum_of_products` and `__subtract_modulus`,
/// all of which must be in scope wherever the tokens are spliced in.
///
/// Two shapes of code can be produced:
/// - When the modulus leaves fewer than two unused bits in its top limb, a plain
///   "multiply each pair, then sum" expression is emitted.
/// - Otherwise an interleaved multiply/reduce routine is emitted for inputs of at
///   most `2 * spare_bits - 1` pairs; longer inputs are split into chunks of that
///   size, with a shorter trailing chunk handled by the plain expression.
///
/// `modulus[num_limbs - 1]` is used as the most significant limb when measuring
/// the bit length of the modulus.
///
/// # Panics
///
/// Panics if `num_limbs` is zero or if `modulus` holds fewer than `num_limbs`
/// limbs.
pub(super) fn sum_of_products_impl(num_limbs: usize, modulus: &[u64]) -> proc_macro2::TokenStream {
    // Bit length of the modulus: every limb below the top one counts for 64 bits,
    // and the top limb counts for its significant bits only.
    let full_limb_bits = ((num_limbs - 1) * 64) as u32;
    let top_limb_bits = 64 - modulus[num_limbs - 1].leading_zeros();
    let modulus_size = (full_limb_bits + top_limb_bits) as usize;

    // Adapted from https://github.com/zkcrypto/bls12_381/pull/84 by @str4d.
    //
    // A single `a x b` product computed by operand scanning (schoolbook) walks the
    // limbs of `a` one at a time, multiplying each by every limb of `b`, and builds
    // a double-width intermediate value that is fully reduced at the end. Here we
    // instead have several pairs (a_i, b_i) whose products are added together.
    //
    // Two observations drive the generated algorithm:
    // - The operand scanning of all pairs can be interleaved by handling the jth
    //   limb of every `a_i` in the same pass. Those partial products sit at the
    //   same offset in the overall flow, so they can be accumulated directly.
    // - Multiplication and reduction can be interleaved as well, which amounts to
    //   one shift by the limb size per pass. Only a single extra limb has to be
    //   kept, rather than all intermediate results at twice the limb count.

    // Fewer than two spare bits: fall back to the straightforward formulation.
    if modulus_size >= 64 * num_limbs - 1 {
        return quote! {
            a.iter().zip(b).map(|(a, b)| *a * b).sum()
        };
    }

    // Unrolled multiply-accumulate statements for limbs 1..num_limbs of `b`;
    // limb 0 is written out by hand in the template below.
    let mac_steps: proc_macro2::TokenStream = (1..num_limbs)
        .map(|limb| {
            quote! {
                result.0[#limb] = fa::mac_with_carry(result.0[#limb], a.0[j], b.0[#limb], &mut carry2);
            }
        })
        .collect();

    // Unrolled reduction statements for modulus limbs 1..num_limbs. Each one
    // stores into the limb below the one it reads, shifting the accumulator down.
    let reduction_steps: proc_macro2::TokenStream = modulus
        .iter()
        .enumerate()
        .take(num_limbs)
        .skip(1)
        .map(|(idx, limb)| {
            quote! {
                result.0[#idx - 1] = fa::mac_with_carry(result.0[#idx], k, #limb, &mut carry2);
            }
        })
        .collect();

    let modulus_lo = modulus[0];
    // Number of unused bits above the modulus in its limb representation.
    let spare_bits = num_limbs * 64 - modulus_size;
    // Largest number of pairs the interleaved routine is emitted for.
    let chunk_len = 2 * spare_bits - 1;

    quote! {
        if M <= #chunk_len {
            // Algorithm 2, line 2
            let result = (0..#num_limbs).fold(BigInt::zero(), |mut result, j| {
                // Algorithm 2, line 3
                let mut carry_a = 0;
                let mut carry_b = 0;
                for (a, b) in a.iter().zip(b) {
                    let a = &a.0;
                    let b = &b.0;
                    let mut carry2 = 0;
                    result.0[0] = fa::mac(result.0[0], a.0[j], b.0[0], &mut carry2);
                    #mac_steps
                    // Fold this pair's outgoing carry into the (carry_a, carry_b) pair.
                    carry_b = fa::adc(&mut carry_a, carry_b, carry2);
                }
                let k = result.0[0].wrapping_mul(Self::INV);
                let mut carry2 = 0;
                fa::mac_discard(result.0[0], k, #modulus_lo, &mut carry2);
                #reduction_steps
                result.0[#num_limbs - 1] = fa::adc_no_carry(carry_a, carry_b, &mut carry2);
                result
            });
            let mut result = F::new_unchecked(result);
            __subtract_modulus(&mut result);
            debug_assert_eq!(
                a.iter().zip(b).map(|(a, b)| *a * b).sum::<F>(),
                result
            );
            result
        } else {
            a.chunks(#chunk_len).zip(b.chunks(#chunk_len)).map(|(a, b)| {
                if a.len() == #chunk_len {
                    Self::sum_of_products::<#chunk_len>(a.try_into().unwrap(), b.try_into().unwrap())
                } else {
                    a.iter().zip(b).map(|(a, b)| *a * b).sum()
                }
            }).sum()
        }
    }
}