openvm_stark_backend/
sumcheck.rs

1// Copied from starkware-libs/stwo under Apache-2.0 license.
2//
3//! Sum-check protocol that proves and verifies claims about `sum_x g(x)` for all x in `{0, 1}^n`.
4//!
5//! [`MultivariatePolyOracle`] provides methods for evaluating sums and making transformations on
6//! `g` in the context of the protocol. It is intended to be used in conjunction with
7//! [`prove_batch`] to generate proofs.
8
9use std::iter::zip;
10
11use itertools::Itertools;
12use p3_challenger::FieldChallenger;
13use p3_field::Field;
14use thiserror::Error;
15
16use crate::poly::{multi::MultivariatePolyOracle, uni::UnivariatePolynomial};
17
/// Prover-side by-products of [`prove_batch`], returned alongside the [`SumcheckProof`].
pub struct SumcheckArtifacts<F, O> {
    /// The challenge sampled in each sum-check round, in round order (one entry per variable of
    /// the batched polynomial `h`). A polynomial with fewer variables than `h` is bound to the
    /// trailing entries of this point.
    pub evaluation_point: Vec<F>,
    /// The input oracles, in input order, after each of their variables has been fixed to the
    /// corresponding challenge.
    pub constant_poly_oracles: Vec<O>,
    /// The final per-polynomial claims, in input order: each polynomial's last round polynomial
    /// evaluated at the last challenge.
    pub claimed_evals: Vec<F>,
}
23
24/// Performs sum-check on a random linear combinations of multiple multivariate polynomials.
25///
26/// Let the multivariate polynomials be `g_0, ..., g_{n-1}`. A single sum-check is performed on
27/// multivariate polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) * g_{n-1}`. The `g_i`s do
28/// not need to have the same number of variables. `g_i`s with less variables are folded in the
29/// latest possible round of the protocol. For instance with `g_0(x, y, z)` and `g_1(x, y)`
30/// sum-check is performed on `h(x, y, z) = g_0(x, y, z) + lambda * g_1(y, z)`. Claim `c_i` should
31/// equal the claimed sum of `g_i(x_0, ..., x_{j-1})` over all `(x_0, ..., x_{j-1})` in `{0, 1}^j`.
32///
33/// The degree of each `g_i` should not exceed [`MAX_DEGREE`] in any variable.  The sum-check proof
34/// of `h`, list of challenges (variable assignment) and the constant oracles (i.e. the `g_i` with
35/// all variables fixed to their corresponding challenges) are returned.
36///
37/// Output is of the form: `(proof, artifacts)`.
38///
39/// # Panics
40///
41/// Panics if:
42/// - No multivariate polynomials are provided.
43/// - There aren't the same number of multivariate polynomials and claims.
44/// - The degree of any multivariate polynomial exceeds [`MAX_DEGREE`] in any variable.
45/// - The round polynomials are inconsistent with their corresponding claimed sum on `0` and `1`.
46pub fn prove_batch<F: Field, O: MultivariatePolyOracle<F>>(
47    mut claims: Vec<F>,
48    mut polys: Vec<O>,
49    lambda: F,
50    challenger: &mut impl FieldChallenger<F>,
51) -> (SumcheckProof<F>, SumcheckArtifacts<F, O>) {
52    let n_variables = polys.iter().map(O::arity).max().unwrap();
53    assert_eq!(claims.len(), polys.len());
54
55    let mut round_polys = vec![];
56    let mut evaluation_point = vec![];
57
58    // Update the claims for the sum over `h`'s hypercube.
59    for (claim, multivariate_poly) in zip(&mut claims, &polys) {
60        let n_unused_variables = n_variables - multivariate_poly.arity();
61        *claim *= F::from_canonical_u32(1 << n_unused_variables);
62    }
63
64    // Prove sum-check rounds
65    for round in 0..n_variables {
66        let n_remaining_rounds = n_variables - round;
67
68        let this_round_polys = zip(&polys, &claims)
69            .enumerate()
70            .map(|(i, (multivariate_poly, &claim))| {
71                let round_poly = if n_remaining_rounds == multivariate_poly.arity() {
72                    multivariate_poly.marginalize_first(claim)
73                } else {
74                    claim.halve().into()
75                };
76
77                let eval_at_0 = round_poly.evaluate(F::ZERO);
78                let eval_at_1 = round_poly.evaluate(F::ONE);
79
80                assert_eq!(
81                    eval_at_0 + eval_at_1,
82                    claim,
83                    "Round {round}, poly {i}: eval(0) + eval(1) != claim ({} != {claim})",
84                    eval_at_0 + eval_at_1,
85                );
86                assert!(
87                    round_poly.degree() <= MAX_DEGREE,
88                    "Round {round}, poly {i}: degree {} > max {MAX_DEGREE}",
89                    round_poly.degree(),
90                );
91
92                round_poly
93            })
94            .collect_vec();
95
96        let round_poly = random_linear_combination(&this_round_polys, lambda);
97
98        challenger.observe_slice(&round_poly);
99
100        let challenge = challenger.sample_ext_element();
101
102        claims = this_round_polys
103            .iter()
104            .map(|round_poly| round_poly.evaluate(challenge))
105            .collect();
106
107        polys = polys
108            .into_iter()
109            .map(|multivariate_poly| {
110                if n_remaining_rounds != multivariate_poly.arity() {
111                    multivariate_poly
112                } else {
113                    multivariate_poly.partial_evaluation(challenge)
114                }
115            })
116            .collect();
117
118        round_polys.push(round_poly);
119        evaluation_point.push(challenge);
120    }
121
122    let proof = SumcheckProof { round_polys };
123    let artifacts = SumcheckArtifacts {
124        evaluation_point,
125        constant_poly_oracles: polys,
126        claimed_evals: claims,
127    };
128
129    (proof, artifacts)
130}
131
132/// Returns `p_0 + alpha * p_1 + ... + alpha^(n-1) * p_{n-1}`.
133#[allow(dead_code)]
134fn random_linear_combination<F: Field>(
135    polys: &[UnivariatePolynomial<F>],
136    alpha: F,
137) -> UnivariatePolynomial<F> {
138    polys
139        .iter()
140        .rfold(UnivariatePolynomial::<F>::zero(), |acc, poly| {
141            acc * alpha + poly.clone()
142        })
143}
144
145/// Partially verifies a sum-check proof.
146///
147/// Only "partial" since it does not fully verify the prover's claimed evaluation on the variable
148/// assignment but checks if the sum of the round polynomials evaluated on `0` and `1` matches the
149/// claim for each round. If the proof passes these checks, the variable assignment and the prover's
150/// claimed evaluation are returned for the caller to validate otherwise an [`Err`] is returned.
151///
152/// Output is of the form `(variable_assignment, claimed_eval)`.
153pub fn partially_verify<F: Field>(
154    mut claim: F,
155    proof: &SumcheckProof<F>,
156    challenger: &mut impl FieldChallenger<F>,
157) -> Result<(Vec<F>, F), SumcheckError<F>> {
158    let mut assignment = Vec::new();
159
160    for (round, round_poly) in proof.round_polys.iter().enumerate() {
161        if round_poly.degree() > MAX_DEGREE {
162            return Err(SumcheckError::DegreeInvalid { round });
163        }
164
165        // TODO: optimize this by sending one less coefficient, and computing it from the
166        // claim, instead of checking the claim. (Can also be done by quotienting).
167        let sum = round_poly.evaluate(F::ZERO) + round_poly.evaluate(F::ONE);
168
169        if claim != sum {
170            return Err(SumcheckError::SumInvalid { claim, sum, round });
171        }
172
173        challenger.observe_slice(round_poly);
174        let challenge = challenger.sample_ext_element();
175
176        claim = round_poly.evaluate(challenge);
177        assignment.push(challenge);
178    }
179
180    Ok((assignment, claim))
181}
182
/// A sum-check proof, as produced by [`prove_batch`] and checked by [`partially_verify`].
#[derive(Debug, Clone)]
pub struct SumcheckProof<F> {
    /// The batched round polynomial of each round, in round order.
    pub round_polys: Vec<UnivariatePolynomial<F>>,
}
187
/// Max degree of polynomials the verifier accepts in each round of the protocol.
///
/// [`prove_batch`] asserts that every round polynomial respects this bound. [`partially_verify`]
/// rejects a proof that violates it with [`SumcheckError::DegreeInvalid`].
pub const MAX_DEGREE: usize = 3;
190
/// Sum-check protocol verification error, returned by [`partially_verify`].
#[derive(Error, Debug)]
pub enum SumcheckError<F> {
    /// The round polynomial of `round` has degree greater than [`MAX_DEGREE`].
    #[error("degree of the polynomial in round {round} is too high")]
    DegreeInvalid { round: RoundIndex },
    /// The round polynomial of `round` evaluated at `0` and `1` sums to `sum`, which differs from
    /// the running `claim`.
    #[error("sum does not match the claim in round {round} (sum {sum}, claim {claim})")]
    SumInvalid { claim: F, sum: F, round: RoundIndex },
}
199
/// Sum-check round index where 0 corresponds to the first round.
///
/// Used by [`SumcheckError`] to report which round failed verification.
pub type RoundIndex = usize;
202
#[cfg(test)]
mod tests {
    use openvm_stark_sdk::{
        config::baby_bear_blake3::default_engine, engine::StarkEngine, utils::create_seeded_rng,
    };
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;
    use rand::Rng;

    use super::*;
    use crate::poly::multi::Mle;

    /// A single MLE over 32 random hypercube values. The verifier's final claim must equal the
    /// MLE evaluated at the returned assignment.
    #[test]
    fn sumcheck_works() {
        type F = BabyBear;

        let engine = default_engine();

        let mut rng = create_seeded_rng();
        let values: Vec<F> = (0..32).map(|_| rng.gen()).collect();
        let claim = values.iter().copied().sum();

        let mle = Mle::new(values);

        let lambda = F::ONE;

        let (proof, _) = prove_batch(
            vec![claim],
            vec![mle.clone()],
            lambda,
            &mut engine.new_challenger(),
        );
        let (assignment, eval) =
            partially_verify(claim, &proof, &mut engine.new_challenger()).unwrap();

        assert_eq!(eval, mle.eval(&assignment));
    }

    /// Two MLEs with the same number of variables, batched with a random `lambda`.
    #[test]
    fn batch_sumcheck_works() {
        type F = BabyBear;

        let engine = default_engine();
        let mut rng = create_seeded_rng();

        let values0: Vec<F> = (0..32).map(|_| rng.gen()).collect();
        let values1: Vec<F> = (0..32).map(|_| rng.gen()).collect();
        let claim0 = values0.iter().copied().sum();
        let claim1 = values1.iter().copied().sum();

        let mle0 = Mle::new(values0);
        let mle1 = Mle::new(values1);

        let lambda: F = rng.gen();

        let claims = vec![claim0, claim1];
        let mles = vec![mle0.clone(), mle1.clone()];
        let (proof, _) = prove_batch(claims, mles, lambda, &mut engine.new_challenger());

        let claim = claim0 + lambda * claim1;
        let (assignment, eval) =
            partially_verify(claim, &proof, &mut engine.new_challenger()).unwrap();

        let eval0 = mle0.eval(&assignment);
        let eval1 = mle1.eval(&assignment);
        assert_eq!(eval, eval0 + lambda * eval1);
    }

    /// Two MLEs whose variable counts differ by one. The smaller one is folded in the last
    /// rounds, so its sum is doubled in the batched claim and it is evaluated on the assignment
    /// suffix.
    #[test]
    fn batch_sumcheck_with_different_n_variables() {
        type F = BabyBear;

        let engine = default_engine();
        let mut rng = create_seeded_rng();

        let values0: Vec<F> = (0..64).map(|_| rng.gen()).collect();
        let values1: Vec<F> = (0..32).map(|_| rng.gen()).collect();

        let claim0 = values0.iter().copied().sum();
        let claim1 = values1.iter().copied().sum();

        let mle0 = Mle::new(values0);
        let mle1 = Mle::new(values1);

        let lambda: F = rng.gen();

        let claims = vec![claim0, claim1];
        let mles = vec![mle0.clone(), mle1.clone()];
        let (proof, _) = prove_batch(claims, mles, lambda, &mut engine.new_challenger());

        let claim = claim0 + lambda * claim1.double();
        let (assignment, eval) =
            partially_verify(claim, &proof, &mut engine.new_challenger()).unwrap();

        let eval0 = mle0.eval(&assignment);
        let eval1 = mle1.eval(&assignment[1..]);
        assert_eq!(eval, eval0 + lambda * eval1);
    }

    /// A proof for a tampered MLE must be rejected when checked against the original claim.
    #[test]
    fn invalid_sumcheck_proof_fails() {
        type F = BabyBear;

        let engine = default_engine();
        let mut rng = create_seeded_rng();

        let values: Vec<F> = (0..8).map(|_| rng.gen()).collect();
        let claim = values.iter().copied().sum();

        let lambda = F::ONE;

        // Compromise the first value.
        let mut invalid_values = values;
        invalid_values[0] += F::ONE;
        let invalid_claim = claim + F::ONE;
        let invalid_mle = Mle::new(invalid_values);
        let (invalid_proof, _) = prove_batch(
            vec![invalid_claim],
            vec![invalid_mle],
            lambda,
            &mut engine.new_challenger(),
        );

        // The first round polynomial sums to `claim + 1`. Its degree was already checked by the
        // prover, so verification must fail the sum check in round 0.
        let result = partially_verify(claim, &invalid_proof, &mut engine.new_challenger());
        assert!(matches!(
            result,
            Err(SumcheckError::SumInvalid { round: 0, .. })
        ));
    }
}