openvm_stark_backend/gkr/
prover.rs

1//! Copied from starkware-libs/stwo under Apache-2.0 license.
2//! GKR batch prover for Grand Product and LogUp lookup arguments.
3use std::{
4    iter::{successors, zip},
5    ops::Deref,
6};
7
8use itertools::Itertools;
9use p3_challenger::FieldChallenger;
10use p3_field::Field;
11use thiserror::Error;
12
13use crate::{
14    gkr::types::{GkrArtifact, GkrBatchProof, GkrMask, Layer},
15    poly::{
16        multi::{hypercube_eq, Mle, MultivariatePolyOracle},
17        uni::{random_linear_combination, UnivariatePolynomial},
18    },
19    sumcheck,
20    sumcheck::SumcheckArtifacts,
21};
22
/// For a given `y`, stores evaluations of [hypercube_eq](x, y) on all 2^{n-1} boolean hypercube
/// points of the form `x = (0, x_2, ..., x_n)`.
///
/// Evaluations are stored in lexicographic order i.e. `evals[0] = eq((0, ..., 0, 0), y)`,
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
///
/// When `y` is empty the table holds the single value `1` (see `HypercubeEqEvals::eval`).
#[derive(Debug, Clone)]
struct HypercubeEqEvals<F> {
    /// The point `y` the table was built for (an owned copy of the slice passed to `eval`).
    y: Vec<F>,
    /// The `eq(x, y)` values with the first coordinate of `x` fixed to `0`.
    evals: Vec<F>,
}
33
34impl<F: Field> HypercubeEqEvals<F> {
35    pub fn eval(y: &[F]) -> Self {
36        let y = y.to_vec();
37
38        if y.is_empty() {
39            let evals = vec![F::ONE];
40            return Self { evals, y };
41        }
42
43        // Compute evaluations for when x_0 = 0.
44        let evals = Self::gen(&y[1..], F::ONE - y[0]);
45        assert_eq!(evals.len(), 1 << (y.len() - 1));
46        Self { evals, y }
47    }
48
49    /// Returns evaluations of the function `x -> eq(x, y) * v` for each `x` in `{0, 1}^n`.
50    fn gen(y: &[F], v: F) -> Vec<F> {
51        let mut evals = Vec::with_capacity(1 << y.len());
52        evals.push(v);
53
54        for &y_i in y.iter().rev() {
55            for j in 0..evals.len() {
56                // `lhs[j] = eq(0, y_i) * c[i]`
57                // `rhs[j] = eq(1, y_i) * c[i]`
58                let tmp = evals[j] * y_i;
59                evals.push(tmp);
60                evals[j] -= tmp;
61            }
62        }
63
64        evals
65    }
66}
67
68impl<F> Deref for HypercubeEqEvals<F> {
69    type Target = [F];
70
71    fn deref(&self) -> &Self::Target {
72        self.evals.deref()
73    }
74}
75
/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers.
///
/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`)
/// the polynomial represents:
///
/// ```text
/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1)
/// ```
///
/// When the input layer is LogUp (represented by multilinear columns `inp_numer` and
/// `inp_denom`) the polynomial represents:
///
/// ```text
/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)
/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1)
///
/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x))
/// ```
struct GkrMultivariatePolyOracle<'a, F: Clone> {
    /// Borrowed table of `eq` evaluations for the point `y`; it is reused unchanged when
    /// variables are fixed.
    pub eq_evals: &'a HypercubeEqEvals<F>,
    /// The layer whose columns enter `P`; its leading variable is fixed by `partial_evaluation`.
    pub input_layer: Layer<F>,
    /// Scalar multiplied into the sums computed by `marginalize_first`. Each call to
    /// `partial_evaluation` multiplies it by `eq(alpha, y_i)` for the variable being fixed.
    pub eq_fixed_var_correction: F,
    /// Used by LogUp to perform a random linear combination of the numerators and denominators.
    pub lambda: F,
}
101
102impl<F: Field> MultivariatePolyOracle<F> for GkrMultivariatePolyOracle<'_, F> {
103    fn arity(&self) -> usize {
104        self.input_layer.n_variables() - 1
105    }
106
107    fn marginalize_first(&self, claim: F) -> UnivariatePolynomial<F> {
108        let n_variables = self.arity();
109        assert_ne!(n_variables, 0);
110        let n_terms = 1 << (n_variables - 1);
111        // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
112        let y = &self.eq_evals.y;
113        let lambda = self.lambda;
114
115        let (mut eval_at_0, mut eval_at_2) = match &self.input_layer {
116            Layer::GrandProduct(col) => eval_grand_product_sum(self.eq_evals, col, n_terms),
117            Layer::LogUpGeneric {
118                numerators,
119                denominators,
120            }
121            | Layer::LogUpMultiplicities {
122                numerators,
123                denominators,
124            } => eval_logup_sum(self.eq_evals, numerators, denominators, n_terms, lambda),
125            Layer::LogUpSingles { denominators } => {
126                eval_logup_singles_sum(self.eq_evals, denominators, n_terms, lambda)
127            }
128        };
129
130        eval_at_0 *= self.eq_fixed_var_correction;
131        eval_at_2 *= self.eq_fixed_var_correction;
132        correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
133    }
134
135    fn partial_evaluation(self, alpha: F) -> Self {
136        if self.is_constant() {
137            return self;
138        }
139
140        let z0 = self.eq_evals.y[self.eq_evals.y.len() - self.arity()];
141        let eq_fixed_var_correction = self.eq_fixed_var_correction * hypercube_eq(&[alpha], &[z0]);
142
143        Self {
144            eq_evals: self.eq_evals,
145            eq_fixed_var_correction,
146            input_layer: self.input_layer.fix_first_variable(alpha),
147            lambda: self.lambda,
148        }
149    }
150}
151
152/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
153///
154/// Output of the form: `(eval_at_0, eval_at_2)`.
155fn eval_grand_product_sum<F: Field>(
156    eq_evals: &HypercubeEqEvals<F>,
157    input_layer: &Mle<F>,
158    n_terms: usize,
159) -> (F, F) {
160    let mut eval_at_0 = F::ZERO;
161    let mut eval_at_2 = F::ZERO;
162
163    for i in 0..n_terms {
164        // Input polynomial values at (r, {0, 1, 2}, bits(i), {0, 1})
165        let (inp_r0_0, inp_r0_1) = (input_layer[i * 2], input_layer[i * 2 + 1]);
166        let (inp_r1_0, inp_r1_1) = (
167            input_layer[(n_terms + i) * 2],
168            input_layer[(n_terms + i) * 2 + 1],
169        );
170
171        // Calculate values at t = 2
172        let inp_r2_0 = inp_r1_0.double() - inp_r0_0;
173        let inp_r2_1 = inp_r1_1.double() - inp_r0_1;
174
175        // Product polynomials at t = 0 and t = 2
176        let prod_at_r0i = inp_r0_0 * inp_r0_1;
177        let prod_at_r2i = inp_r2_0 * inp_r2_1;
178
179        // Accumulate evaluated terms
180        let eq_eval_at_0i = eq_evals[i];
181        eval_at_0 += eq_eval_at_0i * prod_at_r0i;
182        eval_at_2 += eq_eval_at_0i * prod_at_r2i;
183    }
184
185    (eval_at_0, eval_at_2)
186}
187
188/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_numer(r, t, x, 0) * inp_denom(r, t, x, 1) +
189/// inp_numer(r, t, x, 1) * inp_denom(r, t, x, 0) + lambda * inp_denom(r, t, x, 0) * inp_denom(r, t,
190/// x, 1))` at `t=0` and `t=2`.
191///
192/// Output of the form: `(eval_at_0, eval_at_2)`.
193fn eval_logup_sum<F: Field>(
194    eq_evals: &HypercubeEqEvals<F>,
195    input_numerators: &Mle<F>,
196    input_denominators: &Mle<F>,
197    n_terms: usize,
198    lambda: F,
199) -> (F, F) {
200    let mut eval_at_0 = F::ZERO;
201    let mut eval_at_2 = F::ZERO;
202
203    for i in 0..n_terms {
204        // Gather input values at (r, {0, 1, 2}, bits(i), {0, 1})
205        let (numer_r0_0, denom_r0_0) = (input_numerators[i * 2], input_denominators[i * 2]);
206        let (numer_r0_1, denom_r0_1) = (input_numerators[i * 2 + 1], input_denominators[i * 2 + 1]);
207        let (numer_r1_0, denom_r1_0) = (
208            input_numerators[(n_terms + i) * 2],
209            input_denominators[(n_terms + i) * 2],
210        );
211        let (numer_r1_1, denom_r1_1) = (
212            input_numerators[(n_terms + i) * 2 + 1],
213            input_denominators[(n_terms + i) * 2 + 1],
214        );
215
216        // Calculate values at r, t = 2
217        let numer_r2_0 = numer_r1_0.double() - numer_r0_0;
218        let denom_r2_0 = denom_r1_0.double() - denom_r0_0;
219        let numer_r2_1 = numer_r1_1.double() - numer_r0_1;
220        let denom_r2_1 = denom_r1_1.double() - denom_r0_1;
221
222        // Compute fractions at t = 0 and t = 2
223        let numer_at_r0i = numer_r0_0 * denom_r0_1 + numer_r0_1 * denom_r0_0;
224        let denom_at_r0i = denom_r0_1 * denom_r0_0;
225        let numer_at_r2i = numer_r2_0 * denom_r2_1 + numer_r2_1 * denom_r2_0;
226        let denom_at_r2i = denom_r2_1 * denom_r2_0;
227
228        // Accumulate the evaluated terms
229        let eq_eval_at_0i = eq_evals[i];
230        eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
231        eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
232    }
233
234    (eval_at_0, eval_at_2)
235}
236
237/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) +
238/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`.
239///
240/// Output of the form: `(eval_at_0, eval_at_2)`.
241fn eval_logup_singles_sum<F: Field>(
242    eq_evals: &HypercubeEqEvals<F>,
243    input_denominators: &Mle<F>,
244    n_terms: usize,
245    lambda: F,
246) -> (F, F) {
247    let mut eval_at_0 = F::ZERO;
248    let mut eval_at_2 = F::ZERO;
249
250    for i in 0..n_terms {
251        // Input denominator values at (r, {0, 1, 2}, bits(i), {0, 1})
252        let (inp_denom_r0_0, inp_denom_r0_1) =
253            (input_denominators[i * 2], input_denominators[i * 2 + 1]);
254        let (inp_denom_r1_0, inp_denom_r1_1) = (
255            input_denominators[(n_terms + i) * 2],
256            input_denominators[(n_terms + i) * 2 + 1],
257        );
258
259        // Calculate values at t = 2
260        let inp_denom_r2_0 = inp_denom_r1_0.double() - inp_denom_r0_0;
261        let inp_denom_r2_1 = inp_denom_r1_1.double() - inp_denom_r0_1;
262
263        // Fraction addition polynomials at t = 0 and t = 2
264        let numer_at_r0i = inp_denom_r0_0 + inp_denom_r0_1;
265        let denom_at_r0i = inp_denom_r0_0 * inp_denom_r0_1;
266        let numer_at_r2i = inp_denom_r2_0 + inp_denom_r2_1;
267        let denom_at_r2i = inp_denom_r2_0 * inp_denom_r2_1;
268
269        // Accumulate evaluated terms
270        let eq_eval_at_0i = eq_evals[i];
271        eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
272        eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
273    }
274
275    (eval_at_0, eval_at_2)
276}
277
278impl<F: Field> GkrMultivariatePolyOracle<'_, F> {
279    fn is_constant(&self) -> bool {
280        self.arity() == 0
281    }
282
283    /// Returns all input layer columns restricted to a line.
284    ///
285    /// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants
286    /// are expressed by values `c_i(b*)` and `c_i(c*)` where `c_i` represents the input GKR layer's
287    /// `i`th column (for binary tree GKR `b* = (r, 0)`, `c* = (r, 1)`).
288    ///
289    /// If this oracle represents a constant, then each `c_i` restricted to `l` is returned.
290    /// Otherwise, an [`Err`] is returned.
291    ///
292    /// For more context see <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> page 64.
293    fn try_into_mask(self) -> Result<GkrMask<F>, NotConstantPolyError> {
294        if !self.is_constant() {
295            return Err(NotConstantPolyError);
296        }
297
298        let columns = match self.input_layer {
299            Layer::GrandProduct(mle) => vec![mle.as_ref().try_into().unwrap()],
300            Layer::LogUpGeneric {
301                numerators,
302                denominators,
303            } => {
304                let numerators = numerators.as_ref().try_into().unwrap();
305                let denominators = denominators.as_ref().try_into().unwrap();
306                vec![numerators, denominators]
307            }
308            // Should never get called.
309            Layer::LogUpMultiplicities { .. } => unimplemented!(),
310            Layer::LogUpSingles { denominators } => {
311                let numerators = [F::ONE; 2];
312                let denominators = denominators.as_ref().try_into().unwrap();
313                vec![numerators, denominators]
314            }
315        };
316
317        Ok(GkrMask::new(columns))
318    }
319}
320
/// Error returned when a polynomial is expected to be constant but it is not.
///
/// Produced by `GkrMultivariatePolyOracle::try_into_mask` when the oracle still has free
/// variables (its arity is non-zero).
#[derive(Debug, Error)]
#[error("polynomial is not constant")]
pub struct NotConstantPolyError;
325
/// Batch proves lookup circuits with GKR.
///
/// The input layers should be committed to the channel before calling this function.
///
/// Every instance's circuit is first evaluated from its input layer to its output layer. The
/// layers are then processed from the output side back towards the inputs, running one batched
/// sumcheck per layer. Instances may have different numbers of variables: an instance joins the
/// batch in the iteration where the number of remaining layers equals its own variable count.
///
/// Returns the proof (sumcheck proofs, per-instance layer masks and output claims) together with
/// a [`GkrArtifact`] holding the final out-of-domain point, the remaining claim(s) of every
/// instance and the number of variables of every instance.
///
/// # Panics
///
/// Panics if `input_layer_by_instance` is empty.
// GKR algorithm: <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> (page 64)
pub fn prove_batch<F: Field>(
    challenger: &mut impl FieldChallenger<F>,
    input_layer_by_instance: Vec<Layer<F>>,
) -> (GkrBatchProof<F>, GkrArtifact<F>) {
    let n_instances = input_layer_by_instance.len();
    let n_layers_by_instance = input_layer_by_instance
        .iter()
        .map(|l| l.n_variables())
        .collect_vec();
    // The largest instance determines how many rounds the batch needs.
    let n_layers = *n_layers_by_instance.iter().max().unwrap();

    // Evaluate all instance circuits and collect the layer values.
    // Each per-instance iterator is reversed, so it yields the output layer first.
    let mut layers_by_instance = input_layer_by_instance
        .into_iter()
        .map(|input_layer| gen_layers(input_layer).into_iter().rev())
        .collect_vec();

    let mut output_claims_by_instance = vec![None; n_instances];
    let mut layer_masks_by_instance = (0..n_instances).map(|_| Vec::new()).collect_vec();
    let mut sumcheck_proofs = Vec::new();

    // Evaluation point carried from one layer to the next; empty before the first round.
    let mut ood_point = Vec::new();
    // `None` until the instance has joined the batch.
    let mut claims_to_verify_by_instance = vec![None; n_instances];

    for layer in 0..n_layers {
        let n_remaining_layers = n_layers - layer;

        // Check all the instances for output layers.
        for (instance, layers) in layers_by_instance.iter_mut().enumerate() {
            if n_layers_by_instance[instance] == n_remaining_layers {
                let output_layer = layers.next().unwrap();
                let output_layer_values = output_layer.try_into_output_layer_values().unwrap();
                claims_to_verify_by_instance[instance] = Some(output_layer_values.clone());
                output_claims_by_instance[instance] = Some(output_layer_values);
            }
        }

        // Seed the channel with layer claims.
        for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
            challenger.observe_slice(claims_to_verify);
        }

        // NOTE: the order of the `observe`/`sample` calls on the challenger in this loop defines
        // the transcript, so these statements must not be reordered.
        let eq_evals = HypercubeEqEvals::eval(&ood_point);
        let sumcheck_alpha = challenger.sample();
        let instance_lambda = challenger.sample();

        let mut sumcheck_oracles = Vec::new();
        let mut sumcheck_claims = Vec::new();
        // Maps the position of each sumcheck oracle back to its instance index.
        let mut sumcheck_instances = Vec::new();

        // Create the multivariate polynomial oracles used with sumcheck.
        for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
            if let Some(claims_to_verify) = claims_to_verify {
                let layer = layers_by_instance[instance].next().unwrap();

                sumcheck_oracles.push(GkrMultivariatePolyOracle {
                    eq_evals: &eq_evals,
                    input_layer: layer,
                    eq_fixed_var_correction: F::ONE,
                    lambda: instance_lambda,
                });
                // An instance's claims are combined into one value using `instance_lambda`.
                sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda));
                sumcheck_instances.push(instance);
            }
        }

        let (
            sumcheck_proof,
            SumcheckArtifacts {
                evaluation_point: sumcheck_ood_point,
                constant_poly_oracles,
                ..
            },
        ) = sumcheck::prove_batch(
            sumcheck_claims,
            sumcheck_oracles,
            sumcheck_alpha,
            challenger,
        );

        sumcheck_proofs.push(sumcheck_proof);

        // Convert each oracle returned by sumcheck into the mask of its layer's columns.
        let masks = constant_poly_oracles
            .into_iter()
            .map(|oracle| oracle.try_into_mask().unwrap())
            .collect_vec();

        // Seed the channel with the layer masks.
        for (&instance, mask) in zip(&sumcheck_instances, &masks) {
            for column in mask.columns() {
                challenger.observe_slice(column);
            }
            layer_masks_by_instance[instance].push(mask.clone());
        }

        // The next evaluation point is the sumcheck point extended by one fresh challenge.
        let challenge = challenger.sample();
        ood_point = sumcheck_ood_point;
        ood_point.push(challenge);

        // Set the claims to prove in the layer above.
        for (instance, mask) in zip(sumcheck_instances, masks) {
            claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
        }
    }

    // After the loop each instance's entry is unwrapped into the plain value.
    let output_claims_by_instance = output_claims_by_instance
        .into_iter()
        .map(Option::unwrap)
        .collect();

    let claims_to_verify_by_instance = claims_to_verify_by_instance
        .into_iter()
        .map(Option::unwrap)
        .collect();

    let proof = GkrBatchProof {
        sumcheck_proofs,
        layer_masks_by_instance,
        output_claims_by_instance,
    };

    let artifact = GkrArtifact {
        ood_point,
        claims_to_verify_by_instance,
        n_variables_by_instance: n_layers_by_instance,
    };

    (proof, artifact)
}
459
460/// Executes the GKR circuit on the input layer and returns all the circuit's layers.
461fn gen_layers<F: Field>(input_layer: Layer<F>) -> Vec<Layer<F>> {
462    let n_variables = input_layer.n_variables();
463    let layers = successors(Some(input_layer), |layer| layer.next_layer()).collect_vec();
464    assert_eq!(layers.len(), n_variables + 1);
465    layers
466}
467
468/// Computes `r(t) = sum_x eq((t, x), y[-k:]) * p(t, x)` from evaluations of
469/// `f(t) = sum_x eq(({0}^(n - k), 0, x), y) * p(t, x)`.
470///
471/// Note `claim` must equal `r(0) + r(1)` and `r` must have degree <= 3.
472///
473/// For more context see `Layer::into_multivariate_poly()` docs.
474/// See also <https://ia.cr/2024/108> (section 3.2).
475pub fn correct_sum_as_poly_in_first_variable<F: Field>(
476    f_at_0: F,
477    f_at_2: F,
478    claim: F,
479    y: &[F],
480    k: usize,
481) -> UnivariatePolynomial<F> {
482    assert_ne!(k, 0);
483    let n = y.len();
484    assert!(k <= n);
485
486    // We evaluated `f(0)` and `f(2)` - the inputs.
487    // We want to compute `r(t) = f(t) * eq(t, y[n - k]) / eq(0, y[:n - k + 1])`.
488    let a_const = hypercube_eq(&vec![F::ZERO; n - k + 1], &y[..n - k + 1]).inverse();
489
490    // Find the additional root of `r(t)`, by finding the root of `eq(t, y[n - k])`:
491    //    0 = eq(t, y[n - k])
492    //      = t * y[n - k] + (1 - t)(1 - y[n - k])
493    //      = 1 - y[n - k] - t(1 - 2 * y[n - k])
494    // => t = (1 - y[n - k]) / (1 - 2 * y[n - k])
495    //      = b
496    let b_const = (F::ONE - y[n - k]) / (F::ONE - y[n - k].double());
497
498    // We get that `r(t) = f(t) * eq(t, y[n - k]) * a`.
499    let r_at_0 = f_at_0 * hypercube_eq(&[F::ZERO], &[y[n - k]]) * a_const;
500    let r_at_1 = claim - r_at_0;
501    let r_at_2 = f_at_2 * hypercube_eq(&[F::TWO], &[y[n - k]]) * a_const;
502
503    // Interpolate.
504    UnivariatePolynomial::from_interpolation(&[
505        (F::ZERO, r_at_0),
506        (F::ONE, r_at_1),
507        (F::TWO, r_at_2),
508        (b_const, F::ZERO),
509    ])
510}
511
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;
    use rand::Rng;

    use crate::{gkr::prover::HypercubeEqEvals, poly::multi::hypercube_eq};

    /// `gen` must agree with `hypercube_eq(x, y) * v` at every point of the 3-dimensional
    /// boolean hypercube, listed in lexicographic order.
    #[test]
    fn test_gen_eq_evals() {
        type F = BabyBear;

        let mut rng = rand::thread_rng();

        let v: F = rng.gen();
        let y: Vec<F> = vec![rng.gen(), rng.gen(), rng.gen()];

        let eq_evals = HypercubeEqEvals::gen(&y, v);

        // Bit `pos` of `index` as a field element.
        let bit = |index: usize, pos: usize| {
            if (index >> pos) & 1 == 1 {
                F::ONE
            } else {
                F::ZERO
            }
        };
        // The most significant bit of the index is the first coordinate of `x`.
        let expected: Vec<F> = (0..8usize)
            .map(|index| {
                let x = [bit(index, 2), bit(index, 1), bit(index, 0)];
                hypercube_eq(&x, &y) * v
            })
            .collect();

        assert_eq!(eq_evals, expected);
    }
}