openvm_stark_backend/poly/
multi.rs

1//! Copied from starkware-libs/stwo under Apache-2.0 license.
2use std::{
3    iter::zip,
4    ops::{Deref, DerefMut},
5};
6
7use p3_field::{ExtensionField, Field};
8
9use super::uni::UnivariatePolynomial;
10
11/// Represents a multivariate polynomial `g(x_1, ..., x_n)`.
12pub trait MultivariatePolyOracle<F> {
13    /// For an n-variate polynomial, returns n.
14    fn arity(&self) -> usize;
15
16    /// Returns the sum of `g(x_1, x_2, ..., x_n)` over all `(x_2, ..., x_n)` in `{0, 1}^(n-1)` as a polynomial in `x_1`.
17    fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F>;
18
19    /// Returns the multivariate polynomial `h(x_2, ..., x_n) = g(alpha, x_2, ..., x_n)`.
20    fn partial_evaluation(self, alpha: F) -> Self;
21}
22
23/// Multilinear extension of the function defined on the boolean hypercube.
24///
25/// The evaluations are stored in lexicographic order.
26#[derive(Debug, Clone)]
27pub struct Mle<F> {
28    evals: Vec<F>,
29}
30
31impl<F: Field> Mle<F> {
32    /// Creates a [`Mle`] from evaluations of a multilinear polynomial on the boolean hypercube.
33    ///
34    /// # Panics
35    ///
36    /// Panics if the number of evaluations is not a power of two.
37    pub fn new(evals: Vec<F>) -> Self {
38        assert!(evals.len().is_power_of_two());
39        Self { evals }
40    }
41
42    pub fn into_evals(self) -> Vec<F> {
43        self.evals
44    }
45}
46
47impl<F: Field> MultivariatePolyOracle<F> for Mle<F> {
48    fn arity(&self) -> usize {
49        self.evals.len().ilog2() as usize
50    }
51
52    fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F> {
53        let x0 = F::ZERO;
54        let x1 = F::ONE;
55
56        let y0 = self[0..self.len() / 2]
57            .iter()
58            .fold(F::ZERO, |acc, x| acc + *x);
59        let y1 = claim - y0;
60
61        UnivariatePolynomial::from_interpolation(&[(x0, y0), (x1, y1)])
62    }
63
64    fn partial_evaluation(self, alpha: F) -> Self {
65        let midpoint = self.len() / 2;
66        let (lhs_evals, rhs_evals) = self.split_at(midpoint);
67
68        let res = zip(lhs_evals, rhs_evals)
69            .map(|(&lhs_eval, &rhs_eval)| alpha * (rhs_eval - lhs_eval) + lhs_eval)
70            .collect();
71
72        Mle::new(res)
73    }
74}
75
76impl<F> Deref for Mle<F> {
77    type Target = [F];
78
79    fn deref(&self) -> &Self::Target {
80        &self.evals
81    }
82}
83
84impl<F: Field> DerefMut for Mle<F> {
85    fn deref_mut(&mut self) -> &mut Self::Target {
86        &mut self.evals
87    }
88}
89
90/// Evaluates the boolean Lagrange basis polynomial `eq(x, y)`.
91///
92/// Formally, the boolean Lagrange basis polynomial is defined as:
93/// ```text
94/// eq(x_1, \dots, x_n, y_1, \dots, y_n) = \prod_{i=1}^n (x_i * y_i + (1 - x_i) * (1 - y_i)).
95/// ```
96/// For boolean vectors `x` and `y`, the function returns `1` if `x` equals `y` and `0` otherwise.
97///
98/// # Panics
99/// - Panics if `x` and `y` have different lengths.
100pub fn hypercube_eq<F: Field>(x: &[F], y: &[F]) -> F {
101    assert_eq!(x.len(), y.len());
102    zip(x, y)
103        .map(|(&xi, &yi)| xi * yi + (xi - F::ONE) * (yi - F::ONE))
104        .product()
105}
106
107/// Computes `hypercube_eq(0, assignment) * eval0 + hypercube_eq(1, assignment) * eval1`.
108pub fn fold_mle_evals<F, EF>(assignment: EF, eval0: F, eval1: F) -> EF
109where
110    F: Field,
111    EF: ExtensionField<F>,
112{
113    assignment * (eval1 - eval0) + eval0
114}
115
#[cfg(test)]
mod test {
    use p3_baby_bear::BabyBear;
    use p3_field::{Field, FieldAlgebra};

    use super::*;

    impl<F: Field> Mle<F> {
        /// Evaluates the multilinear polynomial at `point` (test-only reference implementation).
        ///
        /// # Panics
        ///
        /// Panics if `point.len()` differs from the arity of the polynomial.
        pub(crate) fn eval(&self, point: &[F]) -> F {
            // Recursively fixes the first remaining variable. The first half of `mle_evals` is
            // the `p_i = 0` sub-cube and the second half is the `p_i = 1` sub-cube.
            fn eval_rec<F: Field>(mle_evals: &[F], p: &[F]) -> F {
                match p {
                    [] => mle_evals[0],
                    &[p_i, ref p @ ..] => {
                        let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2);
                        let lhs_eval = eval_rec(lhs, p);
                        let rhs_eval = eval_rec(rhs, p);
                        // Equivalent to `eq(0, p_i) * lhs_eval + eq(1, p_i) * rhs_eval`.
                        p_i * (rhs_eval - lhs_eval) + lhs_eval
                    }
                }
            }

            // A shorter point would silently read a sub-cube; a longer one would index an empty slice.
            assert_eq!(
                point.len(),
                self.arity(),
                "point dimension must match MLE arity"
            );
            // Borrow the evaluations through `Deref`; there is no need to clone the vector.
            eval_rec(self, point)
        }
    }

    #[test]
    fn test_mle_evaluation() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        // (1 - x_1)(1 - x_2) + 2 (1 - x_1) x_2 + 3 x_1 (1 - x_2) + 4 x_1 x_2
        let mle = Mle::new(evals);
        let point = vec![
            BabyBear::from_canonical_u32(0),
            BabyBear::from_canonical_u32(0),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(1));

        let point = vec![
            BabyBear::from_canonical_u32(0),
            BabyBear::from_canonical_u32(1),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(2));

        let point = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(0),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(3));

        let point = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(1),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(4));

        // Out of domain evaluation: 1 - 4 - 6 + 16 = 7
        let point = vec![
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(2),
        ];
        assert_eq!(mle.eval(&point), BabyBear::from_canonical_u32(7));
    }

    #[test]
    fn test_mle_marginalize_first() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        let sum = BabyBear::from_canonical_u32(10);

        // (1 - x_1)(1 - x_2) + 2 (1 - x_1) x_2 + 3 x_1 (1 - x_2) + 4 x_1 x_2
        let mle = Mle::new(evals);
        // Summing over x_2: (1 - x_1) + 2 (1 - x_1) + 3 x_1 + 4 x_1 = 3 (1 - x_1) + 7 x_1
        let poly = mle.marginalize_first(sum);

        assert_eq!(
            poly.evaluate(BabyBear::ZERO),
            BabyBear::from_canonical_u32(3)
        );
        assert_eq!(
            poly.evaluate(BabyBear::ONE),
            BabyBear::from_canonical_u32(7)
        );
    }

    #[test]
    fn test_mle_partial_evaluation() {
        let evals = vec![
            BabyBear::from_canonical_u32(1),
            BabyBear::from_canonical_u32(2),
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(4),
        ];
        // (1 - x_1)(1 - x_2) + 2 (1 - x_1) x_2 + 3 x_1 (1 - x_2) + 4 x_1 x_2
        let mle = Mle::new(evals);
        let alpha = BabyBear::from_canonical_u32(2);
        // -(1 - x_2) - 2 x_2 + 6 (1 - x_2) + 8 x_2 = x_2 + 5
        let partial_eval = mle.partial_evaluation(alpha);

        assert_eq!(
            partial_eval.eval(&[BabyBear::ZERO]),
            BabyBear::from_canonical_u32(5)
        );
        assert_eq!(
            partial_eval.eval(&[BabyBear::ONE]),
            BabyBear::from_canonical_u32(6)
        );
    }

    #[test]
    fn test_mle_partial_evaluation_matches_eval() {
        let mle = Mle::new((1..=8).map(BabyBear::from_canonical_u32).collect());
        let alpha = BabyBear::from_canonical_u32(5);
        let rest = [
            BabyBear::from_canonical_u32(3),
            BabyBear::from_canonical_u32(7),
        ];
        // Fixing x_1 = alpha and then evaluating must agree with a direct evaluation.
        let expected = mle.eval(&[alpha, rest[0], rest[1]]);

        let partial_eval = mle.partial_evaluation(alpha);

        assert_eq!(partial_eval.arity(), 2);
        assert_eq!(partial_eval.eval(&rest), expected);
    }

    #[test]
    #[should_panic]
    fn test_mle_partial_evaluation_of_constant_panics() {
        let mle = Mle::new(vec![BabyBear::ONE]);
        let _ = mle.partial_evaluation(BabyBear::ONE);
    }

    #[test]
    fn test_fold_mle_evals() {
        let eval0 = BabyBear::from_canonical_u32(3);
        let eval1 = BabyBear::from_canonical_u32(10);

        assert_eq!(fold_mle_evals(BabyBear::ZERO, eval0, eval1), eval0);
        assert_eq!(fold_mle_evals(BabyBear::ONE, eval0, eval1), eval1);
        // (1 - 2) * 3 + 2 * 10 = 17
        assert_eq!(
            fold_mle_evals(BabyBear::from_canonical_u32(2), eval0, eval1),
            BabyBear::from_canonical_u32(17)
        );
    }

    #[test]
    fn eq_identical_hypercube_points_returns_one() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;
        let a = &[one, zero, one];

        let eq_eval = hypercube_eq(a, a);

        assert_eq!(eq_eval, one);
    }

    #[test]
    fn eq_different_hypercube_points_returns_zero() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;
        let a = &[one, zero, one];
        let b = &[one, zero, zero];

        let eq_eval = hypercube_eq(a, b);

        assert_eq!(eq_eval, zero);
    }

    #[test]
    #[should_panic]
    fn eq_different_size_points() {
        let zero = BabyBear::ZERO;
        let one = BabyBear::ONE;

        hypercube_eq(&[zero, one], &[zero]);
    }
}
266}