use std::{
cmp,
collections::{HashMap, HashSet},
fmt,
hash::{Hash, Hasher},
marker::PhantomData,
ops::{Add, Mul, MulAssign, Neg, Sub},
sync::Arc,
};
use group::ff::Field;
use pasta_curves::arithmetic::FieldExt;
use super::{
Basis, Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation,
};
use crate::{arithmetic::parallelize, multicore};
/// Chooses how to split a polynomial of `poly_len` elements into chunks for parallel
/// evaluation, returning `(chunk_size, num_chunks)`.
///
/// The target is four chunks per thread reported by `multicore::current_num_threads()`.
/// Because `chunk_size` is rounded up, the final chunk may be shorter than `chunk_size` and
/// `num_chunks` may be smaller than the target.
///
/// `chunk_size` is always at least 1, so an empty polynomial yields `(1, 0)` instead of
/// underflowing / dividing by zero.
fn get_chunk_params(poly_len: usize) -> (usize, usize) {
    let num_threads = multicore::current_num_threads();
    // Guard against a reported thread count of zero so the division below is always defined.
    let target_chunks = num_threads.max(1) * 4;
    // Round up so `target_chunks` chunks cover every element; clamp to 1 so that
    // `poly_len == 0` does not produce a zero chunk size (which would divide by zero below).
    let chunk_size = ((poly_len + target_chunks - 1) / target_chunks).max(1);
    let num_chunks = (poly_len + chunk_size - 1) / chunk_size;
    (chunk_size, num_chunks)
}
/// A handle to a polynomial registered with an [`Evaluator`], read at a particular rotation.
///
/// Leaves are created by [`Evaluator::register_poly`] at `Rotation::cur()` and can be
/// re-targeted with [`AstLeaf::with_rotation`]. Equality and hashing look only at the
/// polynomial index and the rotation amount.
#[derive(Clone, Copy)]
pub(crate) struct AstLeaf<E, B: Basis> {
    /// Position of the polynomial in the owning evaluator's `polys` list.
    index: usize,
    /// Rotation applied to the polynomial before it is read.
    rotation: Rotation,
    /// Ties the leaf to its evaluator's context type `E` and basis `B` without storing them.
    _evaluator: PhantomData<(E, B)>,
}

impl<E, B: Basis> fmt::Debug for AstLeaf<E, B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The zero-sized `_evaluator` marker carries no information and is omitted.
        let mut out = f.debug_struct("AstLeaf");
        out.field("index", &self.index);
        out.field("rotation", &self.rotation);
        out.finish()
    }
}

impl<E, B: Basis> PartialEq for AstLeaf<E, B> {
    fn eq(&self, rhs: &Self) -> bool {
        (self.index, self.rotation.0) == (rhs.index, rhs.rotation.0)
    }
}

impl<E, B: Basis> Eq for AstLeaf<E, B> {}

impl<E, B: Basis> Hash for AstLeaf<E, B> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash exactly the fields `eq` compares, in the same order, so that `Hash` agrees
        // with `Eq`.
        (self.index, self.rotation.0).hash(state);
    }
}

impl<E, B: Basis> AstLeaf<E, B> {
    /// Returns a copy of this leaf that reads the same polynomial at `rotation` instead.
    pub(crate) fn with_rotation(&self, rotation: Rotation) -> Self {
        Self { rotation, ..*self }
    }
}
/// Owns the polynomials that [`AstLeaf`]s refer to, and evaluates ASTs built from them.
///
/// `E` is the type of the context value passed to [`new_evaluator`]. It is stored here but
/// never read in this file; it also appears as a type parameter of the leaves handed out.
pub(crate) struct Evaluator<E, F: Field, B: Basis> {
    polys: Vec<Polynomial<F, B>>,
    _context: E,
}

/// Creates an [`Evaluator`] with no registered polynomials, tagged with `context`.
///
/// NOTE(review): `context` is presumably a fresh closure (e.g. `|| {}`) whose unique type
/// distinguishes evaluators at the type level — confirm against callers.
pub(crate) fn new_evaluator<E: Fn() + Clone, F: Field, B: Basis>(context: E) -> Evaluator<E, F, B> {
    let polys = Vec::new();
    Evaluator {
        polys,
        _context: context,
    }
}
impl<E, F: Field, B: Basis> Evaluator<E, F, B> {
    /// Registers `poly` with this evaluator and returns a leaf that refers to it at
    /// `Rotation::cur()`.
    ///
    /// The leaf records the polynomial's position in this evaluator's list, so it should only
    /// appear in ASTs passed to this same evaluator's [`Evaluator::evaluate`].
    pub(crate) fn register_poly(&mut self, poly: Polynomial<F, B>) -> AstLeaf<E, B> {
        let index = self.polys.len();
        self.polys.push(poly);
        AstLeaf {
            index,
            rotation: Rotation::cur(),
            _evaluator: PhantomData::default(),
        }
    }

    /// Evaluates `ast` over `domain` and returns the resulting polynomial in basis `B`.
    ///
    /// Every distinct leaf (polynomial index + rotation) referenced by `ast` is rotated
    /// exactly once with `B::rotate`. The output is then split into chunks (see
    /// `get_chunk_params`), and each chunk is evaluated on its own task inside
    /// `multicore::scope`.
    ///
    /// If no polynomial has been registered, the chunking length is taken from
    /// `B::empty_poly(domain)`. ASTs that reference no registered polynomial (only constant
    /// and linear terms) can therefore still be evaluated.
    ///
    /// # Panics
    ///
    /// Panics if `ast` contains any `Poly` leaf and `B::rotate` panics. `Coeff::rotate`
    /// always panics, so in the coefficient basis only leaf-free ASTs can be evaluated.
    ///
    /// NOTE(review): registered polynomial lengths are not checked against
    /// `B::empty_poly(domain)`. Callers are presumably expected to keep them equal.
    pub(crate) fn evaluate(
        &self,
        ast: &Ast<E, F, B>,
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, B>
    where
        E: Copy + Send + Sync,
        F: FieldExt,
        B: BasisOps,
    {
        /// Collects the distinct leaves (polynomial index + rotation) that `ast` references.
        fn collect_rotations<E: Copy, F: Field, B: Basis>(
            ast: &Ast<E, F, B>,
        ) -> HashSet<AstLeaf<E, B>> {
            match ast {
                Ast::Poly(leaf) => {
                    let mut set = HashSet::new();
                    set.insert(*leaf);
                    set
                }
                Ast::Add(a, b) | Ast::Mul(AstMul(a, b)) => {
                    // Grow the left set in place rather than allocating a third set for the
                    // union.
                    let mut set = collect_rotations(a);
                    set.extend(collect_rotations(b));
                    set
                }
                Ast::Scale(a, _) => collect_rotations(a),
                Ast::DistributePowers(terms, _) => terms
                    .iter()
                    .flat_map(|term| collect_rotations(term).into_iter())
                    .collect(),
                Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(),
            }
        }

        // Rotate each distinct leaf once, up front, so that the per-chunk tasks below only
        // slice into already-rotated data.
        let leaves = collect_rotations(ast);
        let rotated: HashMap<_, _> = leaves
            .into_iter()
            .map(|leaf| {
                (
                    leaf,
                    B::rotate(domain, &self.polys[leaf.index], leaf.rotation),
                )
            })
            .collect();

        let mut result = B::empty_poly(domain);

        // Use the first registered polynomial's length (polynomials in one basis are expected
        // to share a length). With nothing registered, fall back to the output buffer's
        // length instead of panicking on an empty list.
        let poly_len = self
            .polys
            .first()
            .map_or(result.len(), |poly| poly.len());
        let (chunk_size, num_chunks) = get_chunk_params(poly_len);

        // For each chunk index, map every leaf to the matching slice of its rotated
        // polynomial. The final chunk may be shorter than `chunk_size`.
        let chunks: Vec<HashMap<_, _>> = (0..num_chunks)
            .map(|i| {
                rotated
                    .iter()
                    .map(|(leaf, poly)| {
                        (
                            *leaf,
                            poly.chunks(chunk_size)
                                .nth(i)
                                .expect("num_chunks was calculated correctly"),
                        )
                    })
                    .collect()
            })
            .collect();

        /// Everything a single chunk's evaluation needs.
        struct AstContext<'a, E, F: FieldExt, B: Basis> {
            domain: &'a EvaluationDomain<F>,
            poly_len: usize,
            chunk_size: usize,
            chunk_index: usize,
            leaves: &'a HashMap<AstLeaf<E, B>, &'a [F]>,
        }

        /// Evaluates `ast` on the chunk described by `ctx`, returning that chunk's values.
        fn recurse<E, F: FieldExt, B: BasisOps>(
            ast: &Ast<E, F, B>,
            ctx: &AstContext<'_, E, F, B>,
        ) -> Vec<F> {
            match ast {
                Ast::Poly(leaf) => ctx.leaves.get(leaf).expect("We prepared this").to_vec(),
                Ast::Add(a, b) => {
                    let mut lhs = recurse(a, ctx);
                    let rhs = recurse(b, ctx);
                    for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
                        *lhs += *rhs;
                    }
                    lhs
                }
                Ast::Mul(AstMul(a, b)) => {
                    let mut lhs = recurse(a, ctx);
                    let rhs = recurse(b, ctx);
                    for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
                        *lhs *= *rhs;
                    }
                    lhs
                }
                Ast::Scale(a, scalar) => {
                    let mut lhs = recurse(a, ctx);
                    for lhs in lhs.iter_mut() {
                        *lhs *= scalar;
                    }
                    lhs
                }
                // Horner's rule: acc = ((t_0 * base + t_1) * base + ...) + t_(n-1).
                Ast::DistributePowers(terms, base) => terms.iter().fold(
                    B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, F::zero()),
                    |mut acc, term| {
                        let term = recurse(term, ctx);
                        for (acc, term) in acc.iter_mut().zip(term) {
                            *acc *= base;
                            *acc += term;
                        }
                        acc
                    },
                ),
                Ast::LinearTerm(scalar) => B::linear_term(
                    ctx.domain,
                    ctx.poly_len,
                    ctx.chunk_size,
                    ctx.chunk_index,
                    *scalar,
                ),
                Ast::ConstantTerm(scalar) => {
                    B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, *scalar)
                }
            }
        }

        // Each task writes a disjoint `chunks_mut` slice of `result`, so no synchronization
        // is needed beyond the scope joining all tasks before `result` is returned.
        multicore::scope(|scope| {
            for (chunk_index, (out, leaves)) in
                result.chunks_mut(chunk_size).zip(chunks.iter()).enumerate()
            {
                scope.spawn(move |_| {
                    let ctx = AstContext {
                        domain,
                        poly_len,
                        chunk_size,
                        chunk_index,
                        leaves,
                    };
                    out.copy_from_slice(&recurse(ast, &ctx));
                });
            }
        });
        result
    }
}
/// The two operands of an [`Ast::Mul`] node, multiplied elementwise during evaluation.
#[derive(Clone)]
pub(crate) struct AstMul<E, F: Field, B: Basis>(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>);

impl<E, F: Field, B: Basis> fmt::Debug for AstMul<E, F, B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let AstMul(lhs, rhs) = self;
        let mut tuple = f.debug_tuple("AstMul");
        tuple.field(lhs);
        tuple.field(rhs);
        tuple.finish()
    }
}
/// An expression tree over polynomials in basis `B`, evaluated by `Evaluator::evaluate`.
///
/// Sub-expressions are held behind `Arc`, so cloning an `Ast` only copies its top node and
/// bumps the reference counts of its immediate children.
#[derive(Clone)]
pub(crate) enum Ast<E, F: Field, B: Basis> {
    /// A registered polynomial, read at the leaf's rotation.
    Poly(AstLeaf<E, B>),
    /// Elementwise sum of two sub-expressions.
    Add(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>),
    /// Elementwise product of two sub-expressions. The `Mul` operator impls in this file
    /// only build it for `LagrangeCoeff` and `ExtendedLagrangeCoeff`.
    Mul(AstMul<E, F, B>),
    /// Every element of the sub-expression multiplied by the given scalar.
    Scale(Arc<Ast<E, F, B>>, F),
    /// Horner-style combination `t_0 * base^(n-1) + ... + t_(n-2) * base + t_(n-1)` of the
    /// listed terms `t_i` with the given `base`.
    DistributePowers(Arc<Vec<Ast<E, F, B>>>, F),
    /// The polynomial `scalar * X`, represented in basis `B` (see `BasisOps::linear_term`).
    LinearTerm(F),
    /// The constant polynomial `scalar`, represented in basis `B`.
    ConstantTerm(F),
}
impl<E, F: Field, B: Basis> Ast<E, F, B> {
    /// Builds a [`Ast::DistributePowers`] node over the terms yielded by `i`.
    ///
    /// For terms `t_0, ..., t_(n-1)` this evaluates to
    /// `t_0 * base^(n-1) + t_1 * base^(n-2) + ... + t_(n-1)`.
    pub fn distribute_powers<I: IntoIterator<Item = Self>>(i: I, base: F) -> Self {
        let terms: Vec<Self> = i.into_iter().collect();
        Self::DistributePowers(Arc::new(terms), base)
    }
}
impl<E, F: Field, B: Basis> fmt::Debug for Ast<E, F, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Poly(leaf) => f.debug_tuple("Poly").field(leaf).finish(),
Self::Add(lhs, rhs) => f.debug_tuple("Add").field(lhs).field(rhs).finish(),
Self::Mul(x) => f.debug_tuple("Mul").field(x).finish(),
Self::Scale(base, scalar) => f.debug_tuple("Scale").field(base).field(scalar).finish(),
Self::DistributePowers(terms, base) => f
.debug_tuple("DistributePowers")
.field(terms)
.field(base)
.finish(),
Self::LinearTerm(x) => f.debug_tuple("LinearTerm").field(x).finish(),
Self::ConstantTerm(x) => f.debug_tuple("ConstantTerm").field(x).finish(),
}
}
}
/// Lifts a leaf into a standalone `Ast::Poly` expression.
impl<E, F: Field, B: Basis> From<AstLeaf<E, B>> for Ast<E, F, B> {
    fn from(leaf: AstLeaf<E, B>) -> Self {
        Self::Poly(leaf)
    }
}
impl<E, F: Field, B: Basis> Ast<E, F, B> {
pub(crate) fn one() -> Self {
Self::ConstantTerm(F::one())
}
}
/// Negation is represented as scaling by `-1`.
impl<E, F: Field, B: Basis> Neg for Ast<E, F, B> {
    type Output = Ast<E, F, B>;
    fn neg(self) -> Self::Output {
        let minus_one = -F::one();
        Ast::Scale(Arc::new(self), minus_one)
    }
}

/// Negates a borrowed expression by cloning it (cheap: children are behind `Arc`).
impl<E: Clone, F: Field, B: Basis> Neg for &Ast<E, F, B> {
    type Output = Ast<E, F, B>;
    fn neg(self) -> Self::Output {
        let owned: Ast<E, F, B> = self.clone();
        owned.neg()
    }
}
/// `a + b` builds an elementwise-sum node.
impl<E, F: Field, B: Basis> Add for Ast<E, F, B> {
    type Output = Ast<E, F, B>;
    fn add(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (Arc::new(self), Arc::new(other));
        Ast::Add(lhs, rhs)
    }
}

/// `&a + &b` clones both operands into a new sum node.
impl<'a, E: Clone, F: Field, B: Basis> Add<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
    type Output = Ast<E, F, B>;
    fn add(self, other: &'a Ast<E, F, B>) -> Self::Output {
        Ast::Add(Arc::new(self.clone()), Arc::new(other.clone()))
    }
}

/// `a + leaf` adds a registered polynomial directly.
impl<E, F: Field, B: Basis> Add<AstLeaf<E, B>> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;
    fn add(self, other: AstLeaf<E, B>) -> Self::Output {
        let rhs: Self = other.into();
        Ast::Add(Arc::new(self), Arc::new(rhs))
    }
}
impl<E, F: Field, B: Basis> Sub for Ast<E, F, B> {
type Output = Ast<E, F, B>;
fn sub(self, other: Self) -> Self::Output {
self + (-other)
}
}
impl<'a, E: Clone, F: Field, B: Basis> Sub<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
type Output = Ast<E, F, B>;
fn sub(self, other: &'a Ast<E, F, B>) -> Self::Output {
self + &(-other)
}
}
impl<E, F: Field, B: Basis> Sub<AstLeaf<E, B>> for Ast<E, F, B> {
type Output = Ast<E, F, B>;
fn sub(self, other: AstLeaf<E, B>) -> Self::Output {
self + (-Ast::from(other))
}
}
/// Elementwise product of two Lagrange-basis expressions. There is no AST-by-AST `Mul`
/// impl for the `Coeff` basis.
impl<E, F: Field> Mul for Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;
    fn mul(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (Arc::new(self), Arc::new(other));
        Ast::Mul(AstMul(lhs, rhs))
    }
}

/// `&a * &b` clones both operands into a new product node.
impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, LagrangeCoeff>> for &'a Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;
    fn mul(self, other: &'a Ast<E, F, LagrangeCoeff>) -> Self::Output {
        Ast::Mul(AstMul(Arc::new(self.clone()), Arc::new(other.clone())))
    }
}

/// `a * leaf` multiplies by a registered Lagrange-basis polynomial.
impl<E, F: Field> Mul<AstLeaf<E, LagrangeCoeff>> for Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;
    fn mul(self, other: AstLeaf<E, LagrangeCoeff>) -> Self::Output {
        let rhs: Self = other.into();
        Ast::Mul(AstMul(Arc::new(self), Arc::new(rhs)))
    }
}
/// Elementwise product of two extended-Lagrange-basis expressions.
impl<E, F: Field> Mul for Ast<E, F, ExtendedLagrangeCoeff> {
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;
    fn mul(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (Arc::new(self), Arc::new(other));
        Ast::Mul(AstMul(lhs, rhs))
    }
}

/// `&a * &b` clones both operands into a new product node.
impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, ExtendedLagrangeCoeff>>
    for &'a Ast<E, F, ExtendedLagrangeCoeff>
{
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;
    fn mul(self, other: &'a Ast<E, F, ExtendedLagrangeCoeff>) -> Self::Output {
        Ast::Mul(AstMul(Arc::new(self.clone()), Arc::new(other.clone())))
    }
}

/// `a * leaf` multiplies by a registered extended-Lagrange-basis polynomial.
impl<E, F: Field> Mul<AstLeaf<E, ExtendedLagrangeCoeff>> for Ast<E, F, ExtendedLagrangeCoeff> {
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;
    fn mul(self, other: AstLeaf<E, ExtendedLagrangeCoeff>) -> Self::Output {
        let rhs: Self = other.into();
        Ast::Mul(AstMul(Arc::new(self), Arc::new(rhs)))
    }
}
/// `a * scalar` scales every element of `a`; available in every basis.
impl<E, F: Field, B: Basis> Mul<F> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;
    fn mul(self, other: F) -> Self::Output {
        let inner = Arc::new(self);
        Ast::Scale(inner, other)
    }
}

/// `&a * scalar` clones `a` (cheap: children are behind `Arc`) and scales it.
impl<E: Clone, F: Field, B: Basis> Mul<F> for &Ast<E, F, B> {
    type Output = Ast<E, F, B>;
    fn mul(self, other: F) -> Self::Output {
        let owned: Ast<E, F, B> = self.clone();
        Ast::Scale(Arc::new(owned), other)
    }
}
/// `a *= b` replaces `a` with the elementwise product node `a * b`.
impl<E: Clone, F: Field> MulAssign for Ast<E, F, ExtendedLagrangeCoeff> {
    fn mul_assign(&mut self, rhs: Self) {
        let lhs = self.clone();
        *self = lhs * rhs;
    }
}
/// Basis-specific operations used by `Evaluator::evaluate`.
///
/// `constant_term` and `linear_term` return only the part of the full polynomial that falls
/// in chunk `chunk_index`. That is indices
/// `chunk_size * chunk_index .. min(chunk_size * (chunk_index + 1), poly_len)`, so the final
/// chunk may be shorter than `chunk_size`.
pub(crate) trait BasisOps: Basis {
    /// Returns an empty polynomial of this basis for `domain`. The evaluator uses it as its
    /// output buffer (presumably zero-filled — confirm against `EvaluationDomain`).
    fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self>;
    /// Returns chunk `chunk_index` of the constant polynomial `scalar` in this basis.
    fn constant_term<F: FieldExt>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    /// Returns chunk `chunk_index` of the polynomial `scalar * X` in this basis.
    fn linear_term<F: FieldExt>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    /// Returns `poly` rotated by `rotation`. Implementations may panic if the basis does not
    /// support rotation (the `Coeff` implementation always does).
    fn rotate<F: FieldExt>(
        domain: &EvaluationDomain<F>,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Polynomial<F, Self>;
}
impl BasisOps for Coeff {
    fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
        domain.empty_coeff()
    }

    /// The constant polynomial `scalar` has `scalar` as its degree-0 coefficient and zeros
    /// elsewhere, so only chunk 0 holds a non-zero value.
    fn constant_term<F: FieldExt>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let mut chunk = vec![F::zero(); cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
        if chunk_index == 0 {
            chunk[0] = scalar;
        }
        chunk
    }

    /// The polynomial `scalar * X` has `scalar` as its degree-1 coefficient (global index 1)
    /// and zeros elsewhere. That index lies in chunk 0 when `chunk_size > 1`, but it is the
    /// sole element of chunk 1 when `chunk_size == 1`.
    fn linear_term<F: FieldExt>(
        _: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        // Global coefficient index of `X`.
        const LINEAR_INDEX: usize = 1;
        let start = chunk_size * chunk_index;
        let mut chunk = vec![F::zero(); cmp::min(chunk_size, poly_len - start)];
        // Check the index against this chunk's span instead of special-casing chunk sizes.
        // The previous `chunk_size == 1 && chunk_index == 1` / `chunk_index == 0` chain wrote
        // `chunk[1]` into a length-1 chunk 0 whenever `chunk_size == 1`, which panicked.
        // A one-coefficient polynomial has no degree-1 slot, so nothing is written for it.
        if (start..start + chunk.len()).contains(&LINEAR_INDEX) {
            chunk[LINEAR_INDEX - start] = scalar;
        }
        chunk
    }

    /// Rotation is not defined for coefficient-basis polynomials; this always panics.
    fn rotate<F: FieldExt>(
        _: &EvaluationDomain<F>,
        _: &Polynomial<F, Self>,
        _: Rotation,
    ) -> Polynomial<F, Self> {
        panic!("Can't rotate polynomials in the standard basis")
    }
}
impl BasisOps for LagrangeCoeff {
fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
domain.empty_lagrange()
}
fn constant_term<F: FieldExt>(
poly_len: usize,
chunk_size: usize,
chunk_index: usize,
scalar: F,
) -> Vec<F> {
vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
}
fn linear_term<F: FieldExt>(
domain: &EvaluationDomain<F>,
poly_len: usize,
chunk_size: usize,
chunk_index: usize,
scalar: F,
) -> Vec<F> {
let omega = domain.get_omega();
let start = chunk_size * chunk_index;
(0..cmp::min(chunk_size, poly_len - start))
.scan(omega.pow_vartime(&[start as u64]) * scalar, |acc, _| {
let ret = *acc;
*acc *= omega;
Some(ret)
})
.collect()
}
fn rotate<F: FieldExt>(
_: &EvaluationDomain<F>,
poly: &Polynomial<F, Self>,
rotation: Rotation,
) -> Polynomial<F, Self> {
poly.rotate(rotation)
}
}
impl BasisOps for ExtendedLagrangeCoeff {
fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
domain.empty_extended()
}
fn constant_term<F: FieldExt>(
poly_len: usize,
chunk_size: usize,
chunk_index: usize,
scalar: F,
) -> Vec<F> {
vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
}
fn linear_term<F: FieldExt>(
domain: &EvaluationDomain<F>,
poly_len: usize,
chunk_size: usize,
chunk_index: usize,
scalar: F,
) -> Vec<F> {
let omega = domain.get_extended_omega();
let start = chunk_size * chunk_index;
(0..cmp::min(chunk_size, poly_len - start))
.scan(
omega.pow_vartime(&[start as u64]) * F::ZETA * scalar,
|acc, _| {
let ret = *acc;
*acc *= omega;
Some(ret)
},
)
.collect()
}
fn rotate<F: FieldExt>(
domain: &EvaluationDomain<F>,
poly: &Polynomial<F, Self>,
rotation: Rotation,
) -> Polynomial<F, Self> {
domain.rotate_extended(poly, rotation)
}
}
#[cfg(test)]
mod tests {
    use std::iter;
    use pasta_curves::pallas;
    use super::{get_chunk_params, new_evaluator, Ast, BasisOps, Evaluator};
    use crate::{
        multicore,
        poly::{Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff},
    };

    /// Returns the smallest `k` in `1..16` for which `get_chunk_params(2^k)` yields chunks
    /// whose total capacity exceeds `2^k`, i.e. whose final chunk is short.
    fn short_chunk_k() -> Option<u32> {
        (1..16u32).find(|&k| {
            let poly_len = 1usize << k;
            let (chunk_size, num_chunks) = get_chunk_params(poly_len);
            poly_len < chunk_size * num_chunks
        })
    }

    /// Registers one empty polynomial with `evaluator`, then evaluates a constant term and
    /// a linear term over a size-`k` domain. Succeeds as long as neither evaluation panics.
    fn evaluate_terms<E: Copy + Send + Sync, B: BasisOps>(
        k: u32,
        mut evaluator: Evaluator<E, pallas::Base, B>,
    ) {
        let domain = EvaluationDomain::new(1, k);
        evaluator.register_poly(B::empty_poly(&domain));
        let constant = Ast::ConstantTerm(pallas::Base::zero());
        let linear = Ast::LinearTerm(pallas::Base::zero());
        for ast in [constant, linear].iter() {
            let _ = evaluator.evaluate(ast, &domain);
        }
    }

    #[test]
    fn short_chunk_regression_test() {
        let k = match short_chunk_k() {
            Some(k) => k,
            None => {
                eprintln!(
                    "can't find a polynomial length for short_chunk_regression_test; skipping"
                );
                return;
            }
        };
        eprintln!("Testing short-chunk regression with k = {}", k);

        evaluate_terms(k, new_evaluator::<_, _, Coeff>(|| {}));
        evaluate_terms(k, new_evaluator::<_, _, LagrangeCoeff>(|| {}));
        evaluate_terms(k, new_evaluator::<_, _, ExtendedLagrangeCoeff>(|| {}));
    }
}