openvm_stark_backend/poly/
multi.rsuse std::{
iter::zip,
ops::{Deref, DerefMut},
};
use p3_field::{ExtensionField, Field};
use super::uni::UnivariatePolynomial;
/// Oracle access to a multivariate polynomial, processed one variable at a time
/// starting from the first.
///
/// NOTE(review): the `marginalize_first` / `partial_evaluation` pair has the
/// shape of a sum-check round (send the round polynomial, then bind the
/// variable to the verifier's challenge); presumably that is the consumer —
/// confirm against callers.
pub trait MultivariatePolyOracle<F> {
/// Number of variables of the polynomial.
fn arity(&self) -> usize;
/// Returns the univariate polynomial in the first variable obtained by summing
/// out all remaining variables over `{0, 1}`.
///
/// `claim` is the expected sum of the polynomial over the whole hypercube; the
/// `Mle` implementation relies on it to derive the value at `1` as
/// `claim - value_at_0` instead of summing the second half.
fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F>;
/// Binds the first variable to `alpha`, returning a polynomial with one fewer
/// variable.
fn partial_evaluation(self, alpha: F) -> Self;
}
/// A multilinear polynomial stored as its evaluations over the Boolean
/// hypercube `{0,1}^n`, where `n` is [`MultivariatePolyOracle::arity`].
///
/// The first variable corresponds to the most significant bit of the index:
/// `evals[..len / 2]` are the evaluations with the first variable set to `0`,
/// and `evals[len / 2..]` are those with it set to `1`.
#[derive(Debug, Clone)]
pub struct Mle<F> {
/// Evaluation table; `Mle::new` guarantees its length is a power of two.
evals: Vec<F>,
}
impl<F: Field> Mle<F> {
    /// Creates a multilinear extension from its evaluations over the Boolean
    /// hypercube.
    ///
    /// # Panics
    /// Panics if `evals.len()` is not a power of two. This includes the empty
    /// vector, since `0` is not a power of two.
    pub fn new(evals: Vec<F>) -> Self {
        assert!(
            evals.len().is_power_of_two(),
            "MLE evaluation count must be a power of two, got {}",
            evals.len()
        );
        Self { evals }
    }

    /// Consumes the MLE and returns its evaluation table without copying.
    pub fn into_evals(self) -> Vec<F> {
        self.evals
    }
}
impl<F: Field> MultivariatePolyOracle<F> for Mle<F> {
    /// Number of variables: `log2` of the number of stored evaluations.
    fn arity(&self) -> usize {
        self.evals.len().ilog2() as usize
    }

    /// Returns the degree-≤1 polynomial `g(t) = Σ_{x'} f(t, x')` obtained by
    /// summing out every variable except the first.
    ///
    /// `g(0)` is the sum of the lower half of the evaluation table (first
    /// variable = 0). `g(1)` is *not* summed directly: it is derived as
    /// `claim - g(0)`, which is only correct when `claim` equals the sum of all
    /// evaluations over the hypercube.
    fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F> {
        let x0 = F::ZERO;
        let x1 = F::ONE;
        let y0 = self[0..self.len() / 2]
            .iter()
            .fold(F::ZERO, |acc, x| acc + *x);
        let y1 = claim - y0;
        UnivariatePolynomial::from_interpolation(&[(x0, y0), (x1, y1)])
    }

    /// Binds the first variable to `alpha`, halving the evaluation table.
    ///
    /// Each output entry is `lhs + alpha * (rhs - lhs)`, where `lhs` and `rhs`
    /// are the matching entries of the lower and upper halves. The consumed
    /// evaluation buffer is reused in place rather than allocating a new one;
    /// its capacity is retained after truncation.
    ///
    /// # Panics
    /// Panics if the MLE has no variables left to bind (a single evaluation).
    fn partial_evaluation(self, alpha: F) -> Self {
        assert!(
            self.evals.len() >= 2,
            "cannot partially evaluate an MLE with no variables"
        );
        let mut evals = self.evals;
        let midpoint = evals.len() / 2;
        let (lhs_evals, rhs_evals) = evals.split_at_mut(midpoint);
        for (lhs_eval, &rhs_eval) in zip(lhs_evals.iter_mut(), rhs_evals.iter()) {
            *lhs_eval = alpha * (rhs_eval - *lhs_eval) + *lhs_eval;
        }
        // The folded values now live in the lower half; drop the upper half.
        evals.truncate(midpoint);
        Mle::new(evals)
    }
}
/// Exposes the evaluation table as a read-only slice, so an `Mle` can be
/// indexed, sliced and iterated exactly like `[F]`.
impl<F> Deref for Mle<F> {
    type Target = [F];

    fn deref(&self) -> &[F] {
        self.evals.as_slice()
    }
}
impl<F: Field> DerefMut for Mle<F> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.evals
}
}
pub fn hypercube_eq<F: Field>(x: &[F], y: &[F]) -> F {
assert_eq!(x.len(), y.len());
zip(x, y)
.map(|(&xi, &yi)| xi * yi + (xi - F::ONE) * (yi - F::ONE))
.product()
}
/// Linearly interpolates between two base-field evaluations at an
/// extension-field point.
///
/// Returns `eval0 + assignment · (eval1 - eval0)`, which equals `eval0` at
/// `assignment = 0` and `eval1` at `assignment = 1`.
pub fn fold_mle_evals<F, EF>(assignment: EF, eval0: F, eval1: F) -> EF
where
    F: Field,
    EF: ExtensionField<F>,
{
    let slope = eval1 - eval0;
    let scaled = assignment * slope;
    scaled + eval0
}
#[cfg(test)]
mod test {
    use p3_baby_bear::BabyBear;
    use p3_field::{AbstractField, Field};

    use super::*;

    impl<F: Field> Mle<F> {
        /// Reference evaluation of the MLE at an arbitrary `point`, used to
        /// cross-check the optimized methods in tests.
        ///
        /// `point[0]` binds the first variable, i.e. it selects between the
        /// lower and upper halves of the evaluation table.
        ///
        /// # Panics
        /// Panics if `point.len()` differs from the MLE's arity.
        pub(crate) fn eval(&self, point: &[F]) -> F {
            fn eval_rec<F: Field>(mle_evals: &[F], p: &[F]) -> F {
                match p {
                    [] => mle_evals[0],
                    &[p_i, ref p @ ..] => {
                        let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2);
                        let lhs_eval = eval_rec(lhs, p);
                        let rhs_eval = eval_rec(rhs, p);
                        p_i * (rhs_eval - lhs_eval) + lhs_eval
                    }
                }
            }

            assert_eq!(
                point.len(),
                self.arity(),
                "point dimension must match MLE arity"
            );
            // Borrow the table directly; cloning the whole MLE is unnecessary.
            eval_rec(&self.evals, point)
        }
    }

    #[test]
    fn test_mle_evaluation() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        let mle = Mle::new(evals);
        let point = vec![
            BabyBear::from_canonical_u32(0),
            BabyBear::from_canonical_u32(0),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(1));
        let point = vec![
            BabyBear::from_canonical_u32(0),
            BabyBear::from_canonical_u32(1),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(2));
        let point = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(0),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(3));
        let point = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(1),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(4));
        let point = vec![
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(2),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(7));
    }

    #[test]
    fn test_mle_marginalize_first() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        let sum = BabyBear::from_canonical_u32(10);
        let mle = Mle::new(evals);
        let poly = mle.marginalize_first(sum);
        assert_eq!(
            poly.evaluate(BabyBear::ZERO),
            BabyBear::from_canonical_u32(3)
        );
        assert_eq!(
            poly.evaluate(BabyBear::ONE),
            BabyBear::from_canonical_u32(7)
        );
    }

    #[test]
    fn test_mle_partial_evaluation() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        let mle = Mle::new(evals);
        let alpha = BabyBear::from_canonical_u32(2);
        let partial_eval = mle.partial_evaluation(alpha);
        assert_eq!(
            partial_eval.eval(&[BabyBear::ZERO]),
            BabyBear::from_canonical_u32(5)
        );
        assert_eq!(
            partial_eval.eval(&[BabyBear::ONE]),
            BabyBear::from_canonical_u32(6)
        );
    }

    /// Binding every variable one at a time must agree with direct evaluation.
    #[test]
    fn test_mle_partial_evaluation_matches_eval() {
        let evals: Vec<BabyBear> = (1..=8).map(BabyBear::from_canonical_u32).collect();
        let mle = Mle::new(evals);
        let point = [
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(5),
            BabyBear::from_canonical_u32(7),
        ];
        let expected = mle.eval(&point);
        let folded = point
            .iter()
            .fold(mle, |acc, &alpha| acc.partial_evaluation(alpha));
        assert_eq!(folded.arity(), 0);
        assert_eq!(folded[0], expected);
    }

    #[test]
    #[should_panic]
    fn partial_evaluation_without_variables_panics() {
        Mle::new(vec![BabyBear::ONE]).partial_evaluation(BabyBear::ONE);
    }

    #[test]
    fn eq_identical_hypercube_points_returns_one() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;
        let a = &[one, zero, one];
        let eq_eval = hypercube_eq(a, a);
        assert_eq!(eq_eval, one);
    }

    #[test]
    fn eq_different_hypercube_points_returns_zero() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;
        let a = &[one, zero, one];
        let b = &[one, zero, zero];
        let eq_eval = hypercube_eq(a, b);
        assert_eq!(eq_eval, zero);
    }

    #[test]
    #[should_panic]
    fn eq_different_size_points() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;
        hypercube_eq(&[zero, one], &[zero]);
    }
}