use super::multicore;
pub use ff::Field;
use group::{
ff::{BatchInvert, PrimeField},
Group as _,
};
pub use pasta_curves::arithmetic::*;
/// Adds `sum_i coeffs[i] * bases[i]` into `acc` using a bucketed, windowed
/// (Pippenger-style) multi-scalar multiplication on the calling thread.
///
/// Each scalar's byte representation is read as little-endian and split into
/// `c`-bit windows. Windows are processed from most to least significant. For
/// each window, `acc` is doubled `c` times. Every base is then added to the
/// bucket for its window value. Finally the buckets are folded with a running
/// sum, so the bucket for value `k` ends up added into `acc` exactly `k` times.
///
/// Only `min(coeffs.len(), bases.len())` pairs are used, because the inputs are
/// zipped.
///
/// NOTE(review): scalar representations are assumed to be at most 256 bits
/// (32 bytes). This is implied by the `256 / c` segment count and the
/// `skip_bytes >= 32` cutoff in `get_at`; confirm before using with other fields.
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

    // Window width in bits: 1 for < 4 bases, 3 for < 32 bases, otherwise
    // ceil(ln(#bases)). NOTE(review): `as u32` truncates above u32::MAX bases.
    let c = if bases.len() < 4 {
        1
    } else if bases.len() < 32 {
        3
    } else {
        (f64::from(bases.len() as u32)).ln().ceil() as usize
    };

    /// Returns the `segment`-th `c`-bit window of the little-endian `bytes`.
    /// Returns 0 if that window starts at or beyond byte 32.
    ///
    /// Up to 8 bytes are loaded starting at the window's first byte; bytes past
    /// the end of the representation read as zero. NOTE(review): assumes `c` is
    /// small enough that the window fits inside those 8 bytes.
    fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
        let skip_bits = segment * c;
        let skip_bytes = skip_bits / 8;
        if skip_bytes >= 32 {
            return 0;
        }
        let mut v = [0; 8];
        for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
            *v = *o;
        }
        let mut tmp = u64::from_le_bytes(v);
        // Drop the bits of the first loaded byte that precede this window.
        tmp >>= skip_bits - (skip_bytes * 8);
        tmp = tmp % (1 << c);
        tmp as usize
    }

    // One extra segment covers the remainder when 256 is not a multiple of c.
    let segments = (256 / c) + 1;
    for current_segment in (0..segments).rev() {
        // Shift the accumulated result left by one window.
        for _ in 0..c {
            *acc = acc.double();
        }

        /// A bucket accumulator. It stays in affine form while it holds at most
        /// one point, and switches to projective form once a second point is added.
        #[derive(Clone, Copy)]
        enum Bucket<C: CurveAffine> {
            None,
            Affine(C),
            Projective(C::Curve),
        }

        impl<C: CurveAffine> Bucket<C> {
            /// Adds the affine point `other` into this bucket.
            fn add_assign(&mut self, other: &C) {
                *self = match *self {
                    Bucket::None => Bucket::Affine(*other),
                    Bucket::Affine(a) => Bucket::Projective(a + *other),
                    Bucket::Projective(mut a) => {
                        a += *other;
                        Bucket::Projective(a)
                    }
                }
            }

            /// Returns `other` plus the contents of this bucket.
            fn add(self, mut other: C::Curve) -> C::Curve {
                match self {
                    Bucket::None => other,
                    Bucket::Affine(a) => {
                        other += a;
                        other
                    }
                    Bucket::Projective(a) => other + &a,
                }
            }
        }

        // buckets[v - 1] collects the bases whose current window value is v;
        // a window value of 0 contributes nothing and has no bucket.
        let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
            let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
            if coeff != 0 {
                buckets[coeff - 1].add_assign(base);
            }
        }

        // Walking from the highest value down, running_sum is the sum of all
        // buckets with value >= the current one. Adding it to acc once per step
        // therefore adds bucket k exactly k times.
        let mut running_sum = C::Curve::identity();
        for exp in buckets.into_iter().rev() {
            running_sum = exp.add(running_sum);
            *acc = *acc + &running_sum;
        }
    }
}
/// Computes `sum_i coeffs[i] * bases[i]` with plain double-and-add.
///
/// All scalars are scanned together, from the most significant bit of byte 31
/// down to bit 0 of byte 0. Their representations are read as 32-byte
/// little-endian byte strings. Presumably meant for small inputs, where the
/// bucket setup of `best_multiexp` does not pay off.
///
/// # Panics
/// Panics if a coefficient at index `i` has any set bit but `bases` has no
/// element `i`.
pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    let reprs: Vec<_> = coeffs.iter().map(|s| s.to_repr()).collect();
    let mut sum = C::Curve::identity();

    for byte_pos in (0..32).rev() {
        for shift in (0..8).rev() {
            sum = sum.double();
            for (i, repr) in reprs.iter().enumerate() {
                let bit_set = (repr.as_ref()[byte_pos] >> shift) & 1 == 1;
                if bit_set {
                    sum += bases[i];
                }
            }
        }
    }

    sum
}
/// Computes the multi-scalar multiplication `sum_i coeffs[i] * bases[i]`.
///
/// If there are more terms than worker threads, the inputs are split with
/// `chunks` into pieces of `len / num_threads` elements (the last piece may be
/// shorter). Each piece is evaluated by `multiexp_serial` in its own scoped
/// task, and the partial results are summed. Otherwise the whole input is
/// evaluated serially on the calling thread.
///
/// # Panics
/// Panics if `coeffs` and `bases` have different lengths.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    assert_eq!(coeffs.len(), bases.len());

    let num_threads = multicore::current_num_threads();
    if coeffs.len() <= num_threads {
        let mut acc = C::Curve::identity();
        multiexp_serial(coeffs, bases, &mut acc);
        return acc;
    }

    let chunk_size = coeffs.len() / num_threads;
    let mut partials = vec![C::Curve::identity(); coeffs.chunks(chunk_size).len()];
    multicore::scope(|scope| {
        let pairs = coeffs.chunks(chunk_size).zip(bases.chunks(chunk_size));
        for ((coeff_chunk, base_chunk), partial) in pairs.zip(partials.iter_mut()) {
            scope.spawn(move |_| {
                multiexp_serial(coeff_chunk, base_chunk, partial);
            });
        }
    });

    partials
        .iter()
        .fold(C::Curve::identity(), |total, part| total + part)
}
/// Performs an in-place radix-2 FFT of `a`, using powers of `omega` as twiddles.
///
/// NOTE(review): `omega` is presumably required to be a primitive
/// `2^log_n`-th root of unity. This is not checked here.
///
/// Steps:
/// 1. Permute `a` into bit-reversed index order.
/// 2. Precompute the twiddle factors `omega^0 .. omega^(n/2 - 1)`.
/// 3. Apply the butterfly stages. If `log_n` is at most
///    `log2_floor(num_threads)`, they run serially below. Otherwise they are
///    delegated to `recursive_butterfly_arithmetic`, which recurses with
///    `rayon::join`.
///
/// # Panics
/// Panics if `a.len() != 1 << log_n`.
pub fn best_fft<G: Group>(a: &mut [G], omega: G::Scalar, log_n: u32) {
    /// Reverses the low `l` bits of `n`.
    fn bitreverse(mut n: usize, l: usize) -> usize {
        let mut r = 0;
        for _ in 0..l {
            r = (r << 1) | (n & 1);
            n >>= 1;
        }
        r
    }

    let threads = multicore::current_num_threads();
    let log_threads = log2_floor(threads);
    let n = a.len() as usize;
    assert_eq!(n, 1 << log_n);

    // Bit-reversal permutation. Swapping only when k < rk handles each pair once.
    for k in 0..n {
        let rk = bitreverse(k, log_n as usize);
        if k < rk {
            a.swap(rk, k);
        }
    }

    // twiddles[i] is built by starting from one and applying
    // group_scale(&omega) i times, i.e. omega^i, for i in 0..n/2.
    let twiddles: Vec<_> = (0..(n / 2) as usize)
        .scan(G::Scalar::one(), |w, _| {
            let tw = *w;
            w.group_scale(&omega);
            Some(tw)
        })
        .collect();

    if log_n <= log_threads {
        // Iterative stages. Each butterfly group spans `chunk` elements, and the
        // stride into `twiddles` halves each time the span doubles.
        let mut chunk = 2_usize;
        let mut twiddle_chunk = (n / 2) as usize;
        for _ in 0..log_n {
            a.chunks_mut(chunk).for_each(|coeffs| {
                let (left, right) = coeffs.split_at_mut(chunk / 2);

                // The first butterfly's twiddle is twiddles[0] (the initial
                // one), so it is applied without scaling.
                let (a, left) = left.split_at_mut(1);
                let (b, right) = right.split_at_mut(1);
                let t = b[0];
                b[0] = a[0];
                a[0].group_add(&t);
                b[0].group_sub(&t);

                // Remaining butterflies: (a, b) -> (a + w*b, a - w*b).
                left.iter_mut()
                    .zip(right.iter_mut())
                    .enumerate()
                    .for_each(|(i, (a, b))| {
                        let mut t = *b;
                        t.group_scale(&twiddles[(i + 1) * twiddle_chunk]);
                        *b = *a;
                        a.group_add(&t);
                        b.group_sub(&t);
                    });
            });
            chunk *= 2;
            twiddle_chunk /= 2;
        }
    } else {
        recursive_butterfly_arithmetic(a, n, 1, &twiddles)
    }
}
/// Applies the radix-2 FFT butterfly stages to `a` recursively, transforming
/// the two halves in parallel with `rayon::join` before combining them.
///
/// The elements of `a` must already be in bit-reversed order.
///
/// * `a` - data transformed in place; expected to hold `n` elements.
/// * `n` - size of this sub-transform; expected to be a power of two >= 2.
/// * `twiddle_chunk` - stride into `twiddles` at this level. Each recursion
///   level doubles it.
/// * `twiddles` - twiddle factors, read at indices `i * twiddle_chunk`.
pub fn recursive_butterfly_arithmetic<G: Group>(
    a: &mut [G],
    n: usize,
    twiddle_chunk: usize,
    twiddles: &[G::Scalar],
) {
    if n == 2 {
        let t = a[1];
        a[1] = a[0];
        a[0].group_add(&t);
        a[1].group_sub(&t);
        return;
    }

    let half = n / 2;
    let next_stride = twiddle_chunk * 2;
    let (lo, hi) = a.split_at_mut(half);
    rayon::join(
        || recursive_butterfly_arithmetic(lo, half, next_stride, twiddles),
        || recursive_butterfly_arithmetic(hi, half, next_stride, twiddles),
    );

    for (i, (x, y)) in lo.iter_mut().zip(hi.iter_mut()).enumerate() {
        let mut t = *y;
        // Pair 0 uses the first twiddle unscaled; every later pair is scaled.
        if i != 0 {
            t.group_scale(&twiddles[i * twiddle_chunk]);
        }
        *y = *x;
        x.group_add(&t);
        y.group_sub(&t);
    }
}
/// Evaluates the polynomial whose coefficients are `poly` (lowest degree first)
/// at `point`, using Horner's rule. Returns zero for an empty slice.
pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
    let mut result = F::zero();
    for coeff in poly.iter().rev() {
        result = result * point + coeff;
    }
    result
}
/// Returns the inner product `sum_i a[i] * b[i]`. Returns zero for empty input.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).fold(F::zero(), |mut sum, (x, y)| {
        sum += (*x) * (*y);
        sum
    })
}
/// Divides the polynomial `a(X)` (coefficients lowest degree first) by `(X - b)`
/// using synthetic division, and returns the quotient.
///
/// The remainder, which equals `a(b)`, is discarded. The quotient has
/// `len(a) - 1` coefficients. An empty or constant input yields an empty
/// quotient instead of underflowing the length computation.
pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
where
    I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
{
    // Negate once, so each step can compute `r - (-b) * q` with `mul_assign`.
    b = -b;
    let a = a.into_iter();

    // `saturating_sub` keeps an empty input from underflowing `0 - 1`.
    let mut q = vec![F::zero(); a.len().saturating_sub(1)];

    // Walk from the highest coefficient down. Each quotient coefficient is the
    // current input coefficient plus b times the previous quotient coefficient.
    // `zip` stops before the constant term, which would only feed the remainder.
    let mut tmp = F::zero();
    for (q, r) in q.iter_mut().rev().zip(a.rev()) {
        let mut lead_coeff = *r;
        lead_coeff.sub_assign(&tmp);
        *q = lead_coeff;
        tmp = lead_coeff;
        tmp.mul_assign(&b);
    }
    q
}
/// Splits `v` into contiguous chunks and calls `f(chunk, start)` on each chunk
/// in its own scoped task. `start` is the index in `v` of the chunk's first
/// element.
///
/// The chunk size is `v.len() / num_threads`. If that is smaller than the
/// thread count, the whole slice is handled as a single chunk. An empty slice
/// results in no calls to `f`.
pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
    let n = v.len();
    if n == 0 {
        // Nothing to do. Without this, `chunk` would be 0 and `chunks_mut(0)` panics.
        return;
    }

    let num_threads = multicore::current_num_threads();
    let mut chunk = n / num_threads;
    if chunk < num_threads {
        // Presumably too little work per thread to be worth splitting.
        chunk = n;
    }

    multicore::scope(|scope| {
        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
            let f = f.clone();
            scope.spawn(move |_| {
                let start = chunk_num * chunk;
                f(v, start);
            });
        }
    });
}
/// Returns `floor(log2(num))`, the index of the highest set bit of `num`.
///
/// # Panics
/// Panics if `num == 0`.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);
    // Computing the highest set bit directly avoids the shift-and-compare loop,
    // which overflowed `1 << BITS` once `num >= 2^(BITS - 1)`.
    (std::mem::size_of::<usize>() * 8) as u32 - 1 - num.leading_zeros()
}
/// Returns the coefficients (lowest degree first) of the polynomial of degree
/// `< points.len()` that takes the value `evals[i]` at `points[i]`.
///
/// Uses the Lagrange basis. For each `j`, the basis polynomial
/// `prod_{k != j} (X - x_k) / (x_j - x_k)` is expanded one factor at a time,
/// scaled by `evals[j]`, and added to the result. All denominators are inverted
/// together with a single batch inversion.
///
/// The points must be pairwise distinct. A repeated point makes a denominator
/// zero, and the output is then not an interpolant. Empty input returns an
/// empty vector.
///
/// # Panics
/// Panics if `points` and `evals` have different lengths.
pub fn lagrange_interpolate<F: FieldExt>(points: &[F], evals: &[F]) -> Vec<F> {
    assert_eq!(points.len(), evals.len());
    if points.len() == 1 {
        // A single sample determines a constant polynomial.
        return vec![evals[0]];
    }

    // denoms[j] holds (x_j - x_k) for every k != j, in increasing k order.
    let mut denoms = Vec::with_capacity(points.len());
    for (j, x_j) in points.iter().enumerate() {
        let mut denom = Vec::with_capacity(points.len() - 1);
        for x_k in points
            .iter()
            .enumerate()
            .filter(|&(k, _)| k != j)
            .map(|a| a.1)
        {
            denom.push(*x_j - x_k);
        }
        denoms.push(denom);
    }
    // Replace every denominator with its inverse in one batch.
    denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();

    let mut final_poly = vec![F::zero(); points.len()];
    for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
        // tmp holds the partially expanded basis polynomial, starting at 1.
        // product is scratch space for the next expansion step.
        let mut tmp: Vec<F> = Vec::with_capacity(points.len());
        let mut product = Vec::with_capacity(points.len() - 1);
        tmp.push(F::one());
        for (x_k, denom) in points
            .iter()
            .enumerate()
            .filter(|&(k, _)| k != j)
            .map(|a| a.1)
            .zip(denoms.into_iter())
        {
            // Multiply tmp by (X - x_k) * denom, where denom = 1 / (x_j - x_k):
            //   new[i] = tmp[i] * (-denom * x_k) + tmp[i - 1] * denom.
            // The constant factor does not depend on i, so compute it once.
            let constant_factor = -denom * x_k;
            product.resize(tmp.len() + 1, F::zero());
            for ((a, b), product) in tmp
                .iter()
                .chain(std::iter::once(&F::zero()))
                .zip(std::iter::once(&F::zero()).chain(tmp.iter()))
                .zip(product.iter_mut())
            {
                *product = *a * constant_factor + *b * denom;
            }
            std::mem::swap(&mut tmp, &mut product);
        }
        assert_eq!(tmp.len(), points.len());
        assert_eq!(product.len(), points.len() - 1);
        for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
            *final_coeff += interpolation_coeff * eval;
        }
    }
    final_poly
}
#[cfg(test)]
use rand_core::OsRng;
#[cfg(test)]
use crate::pasta::Fp;
/// Checks that `lagrange_interpolate` returns one coefficient per point and
/// reproduces every sample, for each prefix of 0 through 5 random points.
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;
    let points = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();
    let evals = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();

    // Inclusive range, so the full 5-point case is exercised as well;
    // `0..5` stopped at 4 points and never used the last sample.
    for num_points in 0..=points.len() {
        let points = &points[0..num_points];
        let evals = &evals[0..num_points];
        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());
        for (point, eval) in points.iter().zip(evals) {
            assert_eq!(eval_polynomial(&poly, *point), *eval);
        }
    }
}