1use std::iter::zip;
10
11use itertools::Itertools;
12use p3_challenger::FieldChallenger;
13use p3_field::Field;
14use thiserror::Error;
15
16use crate::poly::{multi::MultivariatePolyOracle, uni::UnivariatePolynomial};
17
/// Prover-side outputs of [`prove_batch`] returned alongside the proof itself.
pub struct SumcheckArtifacts<F, O> {
    /// The challenge sampled in each round, in round order (one entry per round).
    pub evaluation_point: Vec<F>,
    /// The input oracles, in input order, after `prove_batch` has applied
    /// `partial_evaluation` in every round where each one was active. Presumably each is
    /// fully fixed (constant) at the end, as the name suggests — TODO confirm against
    /// `MultivariatePolyOracle::partial_evaluation`.
    pub constant_poly_oracles: Vec<O>,
    /// Per-polynomial claims after the final round, in input order: each polynomial's last
    /// round polynomial evaluated at the last sampled challenge.
    pub claimed_evals: Vec<F>,
}
23
24pub fn prove_batch<F: Field, O: MultivariatePolyOracle<F>>(
47 mut claims: Vec<F>,
48 mut polys: Vec<O>,
49 lambda: F,
50 challenger: &mut impl FieldChallenger<F>,
51) -> (SumcheckProof<F>, SumcheckArtifacts<F, O>) {
52 let n_variables = polys.iter().map(O::arity).max().unwrap();
53 assert_eq!(claims.len(), polys.len());
54
55 let mut round_polys = vec![];
56 let mut evaluation_point = vec![];
57
58 for (claim, multivariate_poly) in zip(&mut claims, &polys) {
60 let n_unused_variables = n_variables - multivariate_poly.arity();
61 *claim *= F::from_canonical_u32(1 << n_unused_variables);
62 }
63
64 for round in 0..n_variables {
66 let n_remaining_rounds = n_variables - round;
67
68 let this_round_polys = zip(&polys, &claims)
69 .enumerate()
70 .map(|(i, (multivariate_poly, &claim))| {
71 let round_poly = if n_remaining_rounds == multivariate_poly.arity() {
72 multivariate_poly.marginalize_first(claim)
73 } else {
74 claim.halve().into()
75 };
76
77 let eval_at_0 = round_poly.evaluate(F::ZERO);
78 let eval_at_1 = round_poly.evaluate(F::ONE);
79
80 assert_eq!(
81 eval_at_0 + eval_at_1,
82 claim,
83 "Round {round}, poly {i}: eval(0) + eval(1) != claim ({} != {claim})",
84 eval_at_0 + eval_at_1,
85 );
86 assert!(
87 round_poly.degree() <= MAX_DEGREE,
88 "Round {round}, poly {i}: degree {} > max {MAX_DEGREE}",
89 round_poly.degree(),
90 );
91
92 round_poly
93 })
94 .collect_vec();
95
96 let round_poly = random_linear_combination(&this_round_polys, lambda);
97
98 challenger.observe_slice(&round_poly);
99
100 let challenge = challenger.sample_ext_element();
101
102 claims = this_round_polys
103 .iter()
104 .map(|round_poly| round_poly.evaluate(challenge))
105 .collect();
106
107 polys = polys
108 .into_iter()
109 .map(|multivariate_poly| {
110 if n_remaining_rounds != multivariate_poly.arity() {
111 multivariate_poly
112 } else {
113 multivariate_poly.partial_evaluation(challenge)
114 }
115 })
116 .collect();
117
118 round_polys.push(round_poly);
119 evaluation_point.push(challenge);
120 }
121
122 let proof = SumcheckProof { round_polys };
123 let artifacts = SumcheckArtifacts {
124 evaluation_point,
125 constant_poly_oracles: polys,
126 claimed_evals: claims,
127 };
128
129 (proof, artifacts)
130}
131
132#[allow(dead_code)]
134fn random_linear_combination<F: Field>(
135 polys: &[UnivariatePolynomial<F>],
136 alpha: F,
137) -> UnivariatePolynomial<F> {
138 polys
139 .iter()
140 .rfold(UnivariatePolynomial::<F>::zero(), |acc, poly| {
141 acc * alpha + poly.clone()
142 })
143}
144
145pub fn partially_verify<F: Field>(
154 mut claim: F,
155 proof: &SumcheckProof<F>,
156 challenger: &mut impl FieldChallenger<F>,
157) -> Result<(Vec<F>, F), SumcheckError<F>> {
158 let mut assignment = Vec::new();
159
160 for (round, round_poly) in proof.round_polys.iter().enumerate() {
161 if round_poly.degree() > MAX_DEGREE {
162 return Err(SumcheckError::DegreeInvalid { round });
163 }
164
165 let sum = round_poly.evaluate(F::ZERO) + round_poly.evaluate(F::ONE);
168
169 if claim != sum {
170 return Err(SumcheckError::SumInvalid { claim, sum, round });
171 }
172
173 challenger.observe_slice(round_poly);
174 let challenge = challenger.sample_ext_element();
175
176 claim = round_poly.evaluate(challenge);
177 assignment.push(challenge);
178 }
179
180 Ok((assignment, claim))
181}
182
/// A sum-check proof: one round polynomial per round, in round order.
///
/// [`prove_batch`] fills this with the batched (lambda-combined) round polynomial of each
/// round. [`partially_verify`] consumes it.
#[derive(Debug, Clone)]
pub struct SumcheckProof<F> {
    /// Round polynomials, indexed by round.
    pub round_polys: Vec<UnivariatePolynomial<F>>,
}
187
/// Maximum degree allowed for any round polynomial.
///
/// The prover asserts this bound and the verifier rejects any proof that exceeds it.
pub const MAX_DEGREE: usize = 3;
190
/// Reasons [`partially_verify`] can reject a sum-check proof.
#[derive(Error, Debug)]
pub enum SumcheckError<F> {
    /// The round polynomial's degree exceeds [`MAX_DEGREE`].
    #[error("degree of the polynomial in round {round} is too high")]
    DegreeInvalid { round: RoundIndex },
    /// `p(0) + p(1)` of the round polynomial (`sum`) does not equal the running `claim`.
    #[error("sum does not match the claim in round {round} (sum {sum}, claim {claim})")]
    SumInvalid { claim: F, sum: F, round: RoundIndex },
}
199
/// Zero-based index of a sum-check round (the position of the round polynomial in the proof).
pub type RoundIndex = usize;
202
#[cfg(test)]
mod tests {
    use openvm_stark_sdk::{
        config::baby_bear_blake3::default_engine, engine::StarkEngine, utils::create_seeded_rng,
    };
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;
    use rand::Rng;

    use super::*;
    use crate::poly::multi::Mle;

    // NOTE: every test draws its inputs from a seeded RNG, so the order of `rng.gen()` calls
    // determines the data. Do not reorder them casually.

    /// Single polynomial: an honest proof verifies. The verifier's final claim equals the
    /// MLE evaluated at the returned assignment.
    #[test]
    fn sumcheck_works() {
        type F = BabyBear;

        let engine = default_engine();

        let mut rng = create_seeded_rng();
        // 32 values; presumably `Mle::new` treats them as evaluations over a 5-variable
        // hypercube — TODO confirm.
        let values: Vec<F> = (0..32).map(|_| rng.gen()).collect();
        let claim = values.iter().copied().sum();

        let mle = Mle::new(values);

        // With a single polynomial the batching coefficient has no effect on the proof.
        let lambda = F::ONE;

        // Prover and verifier each use a fresh challenger, presumably starting from the same
        // state so that their transcripts line up.
        let (proof, _) = prove_batch(
            vec![claim],
            vec![mle.clone()],
            lambda,
            &mut engine.new_challenger(),
        );
        let (assignment, eval) =
            partially_verify(claim, &proof, &mut engine.new_challenger()).unwrap();

        assert_eq!(eval, mle.eval(&assignment));
    }

    /// Two polynomials with the same number of variables, batched with a random `lambda`.
    #[test]
    fn batch_sumcheck_works() {
        type F = BabyBear;

        let engine = default_engine();
        let mut rng = create_seeded_rng();

        let values0: Vec<F> = (0..32).map(|_| rng.gen()).collect();
        let values1: Vec<F> = (0..32).map(|_| rng.gen()).collect();
        let claim0 = values0.iter().copied().sum();
        let claim1 = values1.iter().copied().sum();

        let mle0 = Mle::new(values0.clone());
        let mle1 = Mle::new(values1.clone());

        let lambda: F = rng.gen();

        let claims = vec![claim0, claim1];
        let mles = vec![mle0.clone(), mle1.clone()];
        let (proof, _) = prove_batch(claims, mles, lambda, &mut engine.new_challenger());

        // The verifier checks the lambda-combined claim, matching `random_linear_combination`.
        let claim = claim0 + lambda * claim1;
        let (assignment, eval) =
            partially_verify(claim, &proof, &mut engine.new_challenger()).unwrap();

        let eval0 = mle0.eval(&assignment);
        let eval1 = mle1.eval(&assignment);
        assert_eq!(eval, eval0 + lambda * eval1);
    }

    /// Two polynomials whose variable counts differ by one. The smaller one is padded with a
    /// leading dummy variable.
    #[test]
    fn batch_sumcheck_with_different_n_variables() {
        type F = BabyBear;

        let engine = default_engine();
        let mut rng = create_seeded_rng();

        let values0: Vec<F> = (0..64).map(|_| rng.gen()).collect();
        let values1: Vec<F> = (0..32).map(|_| rng.gen()).collect();

        let claim0 = values0.iter().copied().sum();
        let claim1 = values1.iter().copied().sum();

        let mle0 = Mle::new(values0.clone());
        let mle1 = Mle::new(values1.clone());

        let lambda: F = rng.gen();

        let claims = vec![claim0, claim1];
        let mles = vec![mle0.clone(), mle1.clone()];
        let (proof, _) = prove_batch(claims, mles, lambda, &mut engine.new_challenger());

        // `prove_batch` doubles the smaller polynomial's claim once per missing variable, so
        // the verifier must start from the same scaled claim.
        let claim = claim0 + lambda * claim1.double();
        let (assignment, eval) =
            partially_verify(claim, &proof, &mut engine.new_challenger()).unwrap();

        let eval0 = mle0.eval(&assignment);
        // The dummy variable is the first one, so the smaller MLE only sees the later
        // challenges.
        let eval1 = mle1.eval(&assignment[1..]);
        assert_eq!(eval, eval0 + lambda * eval1);
    }

    /// A proof for one claim must be rejected when checked against a different claim.
    #[test]
    fn invalid_sumcheck_proof_fails() {
        type F = BabyBear;

        let engine = default_engine();
        let mut rng = create_seeded_rng();

        let values: Vec<F> = (0..8).map(|_| rng.gen()).collect();
        let claim = values.iter().copied().sum();

        let lambda = F::ONE;

        // Perturb one value and honestly prove the resulting (different) sum. Verifying that
        // proof against the original `claim` must fail the first round's sum check.
        let mut invalid_values = values;
        invalid_values[0] += F::ONE;
        let invalid_claim = claim + F::ONE;
        let invalid_mle = Mle::new(invalid_values.clone());
        let (invalid_proof, _) = prove_batch(
            vec![invalid_claim],
            vec![invalid_mle],
            lambda,
            &mut engine.new_challenger(),
        );

        assert!(partially_verify(claim, &invalid_proof, &mut engine.new_challenger()).is_err());
    }
}