1use std::{
4 iter::{successors, zip},
5 ops::Deref,
6};
7
8use itertools::Itertools;
9use p3_challenger::FieldChallenger;
10use p3_field::Field;
11use thiserror::Error;
12
13use crate::{
14 gkr::types::{GkrArtifact, GkrBatchProof, GkrMask, Layer},
15 poly::{
16 multi::{hypercube_eq, Mle, MultivariatePolyOracle},
17 uni::{random_linear_combination, UnivariatePolynomial},
18 },
19 sumcheck,
20 sumcheck::SumcheckArtifacts,
21};
22
23#[derive(Debug, Clone)]
29struct HypercubeEqEvals<F> {
30 y: Vec<F>,
31 evals: Vec<F>,
32}
33
34impl<F: Field> HypercubeEqEvals<F> {
35 pub fn eval(y: &[F]) -> Self {
36 let y = y.to_vec();
37
38 if y.is_empty() {
39 let evals = vec![F::ONE];
40 return Self { evals, y };
41 }
42
43 let evals = Self::gen(&y[1..], F::ONE - y[0]);
45 assert_eq!(evals.len(), 1 << (y.len() - 1));
46 Self { evals, y }
47 }
48
49 fn gen(y: &[F], v: F) -> Vec<F> {
51 let mut evals = Vec::with_capacity(1 << y.len());
52 evals.push(v);
53
54 for &y_i in y.iter().rev() {
55 for j in 0..evals.len() {
56 let tmp = evals[j] * y_i;
59 evals.push(tmp);
60 evals[j] -= tmp;
61 }
62 }
63
64 evals
65 }
66}
67
68impl<F> Deref for HypercubeEqEvals<F> {
69 type Target = [F];
70
71 fn deref(&self) -> &Self::Target {
72 self.evals.deref()
73 }
74}
75
76struct GkrMultivariatePolyOracle<'a, F: Clone> {
95 pub eq_evals: &'a HypercubeEqEvals<F>,
96 pub input_layer: Layer<F>,
97 pub eq_fixed_var_correction: F,
98 pub lambda: F,
100}
101
102impl<F: Field> MultivariatePolyOracle<F> for GkrMultivariatePolyOracle<'_, F> {
103 fn arity(&self) -> usize {
104 self.input_layer.n_variables() - 1
105 }
106
107 fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F> {
108 let n_variables = self.arity();
109 assert_ne!(n_variables, 0);
110 let n_terms = 1 << (n_variables - 1);
111 let y = &self.eq_evals.y;
113 let lambda = self.lambda;
114
115 let (mut eval_at_0, mut eval_at_2) = match &self.input_layer {
116 Layer::GrandProduct(col) => eval_grand_product_sum(self.eq_evals, col, n_terms),
117 Layer::LogUpGeneric {
118 numerators,
119 denominators,
120 }
121 | Layer::LogUpMultiplicities {
122 numerators,
123 denominators,
124 } => eval_logup_sum(self.eq_evals, numerators, denominators, n_terms, lambda),
125 Layer::LogUpSingles { denominators } => {
126 eval_logup_singles_sum(self.eq_evals, denominators, n_terms, lambda)
127 }
128 };
129
130 eval_at_0 *= self.eq_fixed_var_correction;
131 eval_at_2 *= self.eq_fixed_var_correction;
132 correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
133 }
134
135 fn partial_evaluation(self, alpha: F) -> Self {
136 if self.is_constant() {
137 return self;
138 }
139
140 let z0 = self.eq_evals.y[self.eq_evals.y.len() - self.arity()];
141 let eq_fixed_var_correction = self.eq_fixed_var_correction * hypercube_eq(&[alpha], &[z0]);
142
143 Self {
144 eq_evals: self.eq_evals,
145 eq_fixed_var_correction,
146 input_layer: self.input_layer.fix_first_variable(alpha),
147 lambda: self.lambda,
148 }
149 }
150}
151
152fn eval_grand_product_sum<F: Field>(
156 eq_evals: &HypercubeEqEvals<F>,
157 input_layer: &Mle<F>,
158 n_terms: usize,
159) -> (F, F) {
160 let mut eval_at_0 = F::ZERO;
161 let mut eval_at_2 = F::ZERO;
162
163 for i in 0..n_terms {
164 let (inp_r0_0, inp_r0_1) = (input_layer[i * 2], input_layer[i * 2 + 1]);
166 let (inp_r1_0, inp_r1_1) = (
167 input_layer[(n_terms + i) * 2],
168 input_layer[(n_terms + i) * 2 + 1],
169 );
170
171 let inp_r2_0 = inp_r1_0.double() - inp_r0_0;
173 let inp_r2_1 = inp_r1_1.double() - inp_r0_1;
174
175 let prod_at_r0i = inp_r0_0 * inp_r0_1;
177 let prod_at_r2i = inp_r2_0 * inp_r2_1;
178
179 let eq_eval_at_0i = eq_evals[i];
181 eval_at_0 += eq_eval_at_0i * prod_at_r0i;
182 eval_at_2 += eq_eval_at_0i * prod_at_r2i;
183 }
184
185 (eval_at_0, eval_at_2)
186}
187
188fn eval_logup_sum<F: Field>(
194 eq_evals: &HypercubeEqEvals<F>,
195 input_numerators: &Mle<F>,
196 input_denominators: &Mle<F>,
197 n_terms: usize,
198 lambda: F,
199) -> (F, F) {
200 let mut eval_at_0 = F::ZERO;
201 let mut eval_at_2 = F::ZERO;
202
203 for i in 0..n_terms {
204 let (numer_r0_0, denom_r0_0) = (input_numerators[i * 2], input_denominators[i * 2]);
206 let (numer_r0_1, denom_r0_1) = (input_numerators[i * 2 + 1], input_denominators[i * 2 + 1]);
207 let (numer_r1_0, denom_r1_0) = (
208 input_numerators[(n_terms + i) * 2],
209 input_denominators[(n_terms + i) * 2],
210 );
211 let (numer_r1_1, denom_r1_1) = (
212 input_numerators[(n_terms + i) * 2 + 1],
213 input_denominators[(n_terms + i) * 2 + 1],
214 );
215
216 let numer_r2_0 = numer_r1_0.double() - numer_r0_0;
218 let denom_r2_0 = denom_r1_0.double() - denom_r0_0;
219 let numer_r2_1 = numer_r1_1.double() - numer_r0_1;
220 let denom_r2_1 = denom_r1_1.double() - denom_r0_1;
221
222 let numer_at_r0i = numer_r0_0 * denom_r0_1 + numer_r0_1 * denom_r0_0;
224 let denom_at_r0i = denom_r0_1 * denom_r0_0;
225 let numer_at_r2i = numer_r2_0 * denom_r2_1 + numer_r2_1 * denom_r2_0;
226 let denom_at_r2i = denom_r2_1 * denom_r2_0;
227
228 let eq_eval_at_0i = eq_evals[i];
230 eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
231 eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
232 }
233
234 (eval_at_0, eval_at_2)
235}
236
237fn eval_logup_singles_sum<F: Field>(
242 eq_evals: &HypercubeEqEvals<F>,
243 input_denominators: &Mle<F>,
244 n_terms: usize,
245 lambda: F,
246) -> (F, F) {
247 let mut eval_at_0 = F::ZERO;
248 let mut eval_at_2 = F::ZERO;
249
250 for i in 0..n_terms {
251 let (inp_denom_r0_0, inp_denom_r0_1) =
253 (input_denominators[i * 2], input_denominators[i * 2 + 1]);
254 let (inp_denom_r1_0, inp_denom_r1_1) = (
255 input_denominators[(n_terms + i) * 2],
256 input_denominators[(n_terms + i) * 2 + 1],
257 );
258
259 let inp_denom_r2_0 = inp_denom_r1_0.double() - inp_denom_r0_0;
261 let inp_denom_r2_1 = inp_denom_r1_1.double() - inp_denom_r0_1;
262
263 let numer_at_r0i = inp_denom_r0_0 + inp_denom_r0_1;
265 let denom_at_r0i = inp_denom_r0_0 * inp_denom_r0_1;
266 let numer_at_r2i = inp_denom_r2_0 + inp_denom_r2_1;
267 let denom_at_r2i = inp_denom_r2_0 * inp_denom_r2_1;
268
269 let eq_eval_at_0i = eq_evals[i];
271 eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
272 eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
273 }
274
275 (eval_at_0, eval_at_2)
276}
277
278impl<F: Field> GkrMultivariatePolyOracle<'_, F> {
279 fn is_constant(&self) -> bool {
280 self.arity() == 0
281 }
282
283 fn try_into_mask(self) -> Result<GkrMask<F>, NotConstantPolyError> {
294 if !self.is_constant() {
295 return Err(NotConstantPolyError);
296 }
297
298 let columns = match self.input_layer {
299 Layer::GrandProduct(mle) => vec![mle.as_ref().try_into().unwrap()],
300 Layer::LogUpGeneric {
301 numerators,
302 denominators,
303 } => {
304 let numerators = numerators.as_ref().try_into().unwrap();
305 let denominators = denominators.as_ref().try_into().unwrap();
306 vec![numerators, denominators]
307 }
308 Layer::LogUpMultiplicities { .. } => unimplemented!(),
310 Layer::LogUpSingles { denominators } => {
311 let numerators = [F::ONE; 2];
312 let denominators = denominators.as_ref().try_into().unwrap();
313 vec![numerators, denominators]
314 }
315 };
316
317 Ok(GkrMask::new(columns))
318 }
319}
320
321#[derive(Debug, Error)]
323#[error("polynomial is not constant")]
324pub struct NotConstantPolyError;
325
326pub fn prove_batch<F: Field>(
331 challenger: &mut impl FieldChallenger<F>,
332 input_layer_by_instance: Vec<Layer<F>>,
333) -> (GkrBatchProof<F>, GkrArtifact<F>) {
334 let n_instances = input_layer_by_instance.len();
335 let n_layers_by_instance = input_layer_by_instance
336 .iter()
337 .map(|l| l.n_variables())
338 .collect_vec();
339 let n_layers = *n_layers_by_instance.iter().max().unwrap();
340
341 let mut layers_by_instance = input_layer_by_instance
343 .into_iter()
344 .map(|input_layer| gen_layers(input_layer).into_iter().rev())
345 .collect_vec();
346
347 let mut output_claims_by_instance = vec![None; n_instances];
348 let mut layer_masks_by_instance = (0..n_instances).map(|_| Vec::new()).collect_vec();
349 let mut sumcheck_proofs = Vec::new();
350
351 let mut ood_point = Vec::new();
352 let mut claims_to_verify_by_instance = vec![None; n_instances];
353
354 for layer in 0..n_layers {
355 let n_remaining_layers = n_layers - layer;
356
357 for (instance, layers) in layers_by_instance.iter_mut().enumerate() {
359 if n_layers_by_instance[instance] == n_remaining_layers {
360 let output_layer = layers.next().unwrap();
361 let output_layer_values = output_layer.try_into_output_layer_values().unwrap();
362 claims_to_verify_by_instance[instance] = Some(output_layer_values.clone());
363 output_claims_by_instance[instance] = Some(output_layer_values);
364 }
365 }
366
367 for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
369 challenger.observe_slice(claims_to_verify);
370 }
371
372 let eq_evals = HypercubeEqEvals::eval(&ood_point);
373 let sumcheck_alpha = challenger.sample();
374 let instance_lambda = challenger.sample();
375
376 let mut sumcheck_oracles = Vec::new();
377 let mut sumcheck_claims = Vec::new();
378 let mut sumcheck_instances = Vec::new();
379
380 for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
382 if let Some(claims_to_verify) = claims_to_verify {
383 let layer = layers_by_instance[instance].next().unwrap();
384
385 sumcheck_oracles.push(GkrMultivariatePolyOracle {
386 eq_evals: &eq_evals,
387 input_layer: layer,
388 eq_fixed_var_correction: F::ONE,
389 lambda: instance_lambda,
390 });
391 sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda));
392 sumcheck_instances.push(instance);
393 }
394 }
395
396 let (
397 sumcheck_proof,
398 SumcheckArtifacts {
399 evaluation_point: sumcheck_ood_point,
400 constant_poly_oracles,
401 ..
402 },
403 ) = sumcheck::prove_batch(
404 sumcheck_claims,
405 sumcheck_oracles,
406 sumcheck_alpha,
407 challenger,
408 );
409
410 sumcheck_proofs.push(sumcheck_proof);
411
412 let masks = constant_poly_oracles
413 .into_iter()
414 .map(|oracle| oracle.try_into_mask().unwrap())
415 .collect_vec();
416
417 for (&instance, mask) in zip(&sumcheck_instances, &masks) {
419 for column in mask.columns() {
420 challenger.observe_slice(column);
421 }
422 layer_masks_by_instance[instance].push(mask.clone());
423 }
424
425 let challenge = challenger.sample();
426 ood_point = sumcheck_ood_point;
427 ood_point.push(challenge);
428
429 for (instance, mask) in zip(sumcheck_instances, masks) {
431 claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
432 }
433 }
434
435 let output_claims_by_instance = output_claims_by_instance
436 .into_iter()
437 .map(Option::unwrap)
438 .collect();
439
440 let claims_to_verify_by_instance = claims_to_verify_by_instance
441 .into_iter()
442 .map(Option::unwrap)
443 .collect();
444
445 let proof = GkrBatchProof {
446 sumcheck_proofs,
447 layer_masks_by_instance,
448 output_claims_by_instance,
449 };
450
451 let artifact = GkrArtifact {
452 ood_point,
453 claims_to_verify_by_instance,
454 n_variables_by_instance: n_layers_by_instance,
455 };
456
457 (proof, artifact)
458}
459
460fn gen_layers<F: Field>(input_layer: Layer<F>) -> Vec<Layer<F>> {
462 let n_variables = input_layer.n_variables();
463 let layers = successors(Some(input_layer), |layer| layer.next_layer()).collect_vec();
464 assert_eq!(layers.len(), n_variables + 1);
465 layers
466}
467
468pub fn correct_sum_as_poly_in_first_variable<F: Field>(
476 f_at_0: F,
477 f_at_2: F,
478 claim: F,
479 y: &[F],
480 k: usize,
481) -> UnivariatePolynomial<F> {
482 assert_ne!(k, 0);
483 let n = y.len();
484 assert!(k <= n);
485
486 let a_const = hypercube_eq(&vec![F::ZERO; n - k + 1], &y[..n - k + 1]).inverse();
489
490 let b_const = (F::ONE - y[n - k]) / (F::ONE - y[n - k].double());
497
498 let r_at_0 = f_at_0 * hypercube_eq(&[F::ZERO], &[y[n - k]]) * a_const;
500 let r_at_1 = claim - r_at_0;
501 let r_at_2 = f_at_2 * hypercube_eq(&[F::TWO], &[y[n - k]]) * a_const;
502
503 UnivariatePolynomial::from_interpolation(&[
505 (F::ZERO, r_at_0),
506 (F::ONE, r_at_1),
507 (F::TWO, r_at_2),
508 (b_const, F::ZERO),
509 ])
510}
511
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;
    use rand::Rng;

    use crate::{gkr::prover::HypercubeEqEvals, poly::multi::hypercube_eq};

    /// `HypercubeEqEvals::gen(y, v)` must list `v * eq(x, y)` for every boolean `x`,
    /// in lexicographic order with `x[0]` as the most significant bit.
    #[test]
    fn test_gen_eq_evals() {
        type F = BabyBear;

        let mut rng = rand::thread_rng();
        let v: F = rng.gen();
        let y: Vec<F> = (0..3).map(|_| rng.gen()).collect();

        let n = y.len();
        let expected: Vec<F> = (0..1usize << n)
            .map(|index| {
                // Coordinate `j` of `x` is bit `n - 1 - j` of `index`.
                let x: Vec<F> = (0..n)
                    .map(|j| {
                        if (index >> (n - 1 - j)) & 1 == 1 {
                            F::ONE
                        } else {
                            F::ZERO
                        }
                    })
                    .collect();
                hypercube_eq(&x, &y) * v
            })
            .collect();

        assert_eq!(HypercubeEqEvals::gen(&y, v), expected);
    }
}