// openvm_stark_backend/poly/multi.rs
1use std::{
3 iter::zip,
4 ops::{Deref, DerefMut},
5};
6
7use p3_field::{ExtensionField, Field};
8
9use super::uni::UnivariatePolynomial;
10
11pub trait MultivariatePolyOracle<F> {
13 fn arity(&self) -> usize;
15
16 fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F>;
18
19 fn partial_evaluation(self, alpha: F) -> Self;
21}
22
/// A multilinear extension (MLE) stored as its evaluations over the boolean hypercube.
///
/// The first variable selects the half of `evals`: `evals[..len / 2]` holds the evaluations with
/// the first variable equal to 0, and `evals[len / 2..]` those with it equal to 1.
#[derive(Debug, Clone)]
pub struct Mle<F> {
    evals: Vec<F>,
}

impl<F> Mle<F> {
    /// Wraps a vector of hypercube evaluations as an MLE.
    ///
    /// # Panics
    ///
    /// Panics if `evals.len()` is not a power of two. This includes the empty vector, since
    /// `0` is not a power of two.
    pub fn new(evals: Vec<F>) -> Self {
        assert!(
            evals.len().is_power_of_two(),
            "MLE evaluation count must be a power of two, got {}",
            evals.len()
        );
        Self { evals }
    }

    /// Consumes the MLE and hands back its evaluation vector.
    pub fn into_evals(self) -> Vec<F> {
        self.evals
    }
}
46
47impl<F: Field> MultivariatePolyOracle<F> for Mle<F> {
48 fn arity(&self) -> usize {
49 self.evals.len().ilog2() as usize
50 }
51
52 fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F> {
53 let x0 = F::ZERO;
54 let x1 = F::ONE;
55
56 let y0 = self[0..self.len() / 2]
57 .iter()
58 .fold(F::ZERO, |acc, x| acc + *x);
59 let y1 = claim - y0;
60
61 UnivariatePolynomial::from_interpolation(&[(x0, y0), (x1, y1)])
62 }
63
64 fn partial_evaluation(self, alpha: F) -> Self {
65 let midpoint = self.len() / 2;
66 let (lhs_evals, rhs_evals) = self.split_at(midpoint);
67
68 let res = zip(lhs_evals, rhs_evals)
69 .map(|(&lhs_eval, &rhs_eval)| alpha * (rhs_eval - lhs_eval) + lhs_eval)
70 .collect();
71
72 Mle::new(res)
73 }
74}
75
76impl<F> Deref for Mle<F> {
77 type Target = [F];
78
79 fn deref(&self) -> &Self::Target {
80 &self.evals
81 }
82}
83
84impl<F: Field> DerefMut for Mle<F> {
85 fn deref_mut(&mut self) -> &mut Self::Target {
86 &mut self.evals
87 }
88}
89
90pub fn hypercube_eq<F: Field>(x: &[F], y: &[F]) -> F {
101 assert_eq!(x.len(), y.len());
102 zip(x, y)
103 .map(|(&xi, &yi)| xi * yi + (xi - F::ONE) * (yi - F::ONE))
104 .product()
105}
106
107pub fn fold_mle_evals<F, EF>(assignment: EF, eval0: F, eval1: F) -> EF
109where
110 F: Field,
111 EF: ExtensionField<F>,
112{
113 assignment * (eval1 - eval0) + eval0
114}
115
#[cfg(test)]
mod test {
    use p3_baby_bear::BabyBear;
    use p3_field::{Field, FieldAlgebra};

    use super::*;

    impl<F: Field> Mle<F> {
        /// Test-only reference evaluator for the MLE at an arbitrary `point`.
        ///
        /// It folds recursively on the first variable, which selects the lower/upper half of the
        /// evaluations.
        ///
        /// # Panics
        ///
        /// Panics if `point.len()` differs from the MLE's arity.
        pub(crate) fn eval(&self, point: &[F]) -> F {
            fn eval_rec<F: Field>(mle_evals: &[F], p: &[F]) -> F {
                match p {
                    [] => mle_evals[0],
                    &[p_i, ref p @ ..] => {
                        let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2);
                        let lhs_eval = eval_rec(lhs, p);
                        let rhs_eval = eval_rec(rhs, p);
                        p_i * (rhs_eval - lhs_eval) + lhs_eval
                    }
                }
            }

            // A short point would silently return a partial fold, and a long one would index an
            // empty slice, so reject mismatched lengths up front.
            assert_eq!(
                point.len(),
                self.arity(),
                "point length must equal the MLE arity"
            );
            // Borrow the evaluations directly; cloning the whole vector is unnecessary.
            eval_rec(&self.evals, point)
        }
    }

    #[test]
    fn test_mle_evaluation() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        let mle = Mle::new(evals);

        let point = vec![
            BabyBear::from_canonical_u32(0),
            BabyBear::from_canonical_u32(0),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(1));

        let point = vec![
            BabyBear::from_canonical_u32(0),
            BabyBear::from_canonical_u32(1),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(2));

        let point = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(0),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(3));

        let point = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(1),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(4));

        // Off the hypercube: f(x0, x1) = 1 + 2*x0 + x1, so f(2, 2) = 7.
        let point = vec![
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(2),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(7));
    }

    #[test]
    #[should_panic]
    fn test_mle_eval_rejects_wrong_point_length() {
        let mle = Mle::new(vec![BabyBear::ONE, BabyBear::ZERO]);
        mle.eval(&[BabyBear::ONE, BabyBear::ONE]);
    }

    #[test]
    fn test_mle_marginalize_first() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        let sum = BabyBear::from_canonical_u32(10);

        let mle = Mle::new(evals);
        let poly = mle.marginalize_first(sum);

        assert_eq!(
            poly.evaluate(BabyBear::ZERO),
            BabyBear::from_canonical_u32(3)
        );
        assert_eq!(
            poly.evaluate(BabyBear::ONE),
            BabyBear::from_canonical_u32(7)
        );
    }

    #[test]
    fn test_mle_partial_evaluation() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        let mle = Mle::new(evals);
        let alpha = BabyBear::from_canonical_u32(2);
        let partial_eval = mle.partial_evaluation(alpha);

        assert_eq!(
            partial_eval.eval(&[BabyBear::ZERO]),
            BabyBear::from_canonical_u32(5)
        );
        assert_eq!(
            partial_eval.eval(&[BabyBear::ONE]),
            BabyBear::from_canonical_u32(6)
        );
    }

    #[test]
    fn test_mle_partial_evaluation_agrees_with_eval() {
        let evals: Vec<BabyBear> = (1..=8).map(BabyBear::from_canonical_u32).collect();
        let mle = Mle::new(evals);
        let alpha = BabyBear::from_canonical_u32(5);
        let rest = [
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(7),
        ];

        let expected = mle.eval(&[alpha, rest[0], rest[1]]);
        let partial = mle.partial_evaluation(alpha);
        assert_eq!(partial.arity(), 2);
        assert_eq!(partial.eval(&rest), expected);
    }

    #[test]
    #[should_panic]
    fn test_mle_partial_evaluation_of_constant_panics() {
        let mle = Mle::new(vec![BabyBear::ONE]);
        let _ = mle.partial_evaluation(BabyBear::ONE);
    }

    #[test]
    fn eq_identical_hypercube_points_returns_one() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;
        let a = &[one, zero, one];

        let eq_eval = hypercube_eq(a, a);

        assert_eq!(eq_eval, one);
    }

    #[test]
    fn eq_different_hypercube_points_returns_zero() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;
        let a = &[one, zero, one];
        let b = &[one, zero, zero];

        let eq_eval = hypercube_eq(a, b);

        assert_eq!(eq_eval, zero);
    }

    #[test]
    #[should_panic]
    fn eq_different_size_points() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;

        hypercube_eq(&[zero, one], &[zero]);
    }
}