openvm_stark_backend/gkr/
verifier.rs

1//! Copied from starkware-libs/stwo under Apache-2.0 license.
2//! GKR batch verifier for Grand Product and LogUp lookup arguments.
3
4use p3_challenger::FieldChallenger;
5use p3_field::Field;
6
7use crate::{
8    gkr::{
9        gate::Gate,
10        types::{GkrArtifact, GkrBatchProof, GkrError},
11    },
12    poly::{multi::hypercube_eq, uni::random_linear_combination},
13    sumcheck,
14};
15
16/// Partially verifies a batch GKR proof.
17///
18/// On successful verification the function returns a [`GkrArtifact`] which stores the out-of-domain
19/// point and claimed evaluations in the input layer columns for each instance at the OOD point.
20/// These claimed evaluations are not checked in this function - hence partial verification.
21pub fn partially_verify_batch<F: Field>(
22    gate_by_instance: Vec<Gate>,
23    proof: &GkrBatchProof<F>,
24    challenger: &mut impl FieldChallenger<F>,
25) -> Result<GkrArtifact<F>, GkrError<F>> {
26    let GkrBatchProof {
27        sumcheck_proofs,
28        layer_masks_by_instance,
29        output_claims_by_instance,
30    } = proof;
31
32    if layer_masks_by_instance.len() != output_claims_by_instance.len() {
33        return Err(GkrError::MalformedProof);
34    }
35
36    let n_instances = layer_masks_by_instance.len();
37    let instance_n_layers = |instance: usize| layer_masks_by_instance[instance].len();
38    let n_layers = (0..n_instances).map(instance_n_layers).max().unwrap();
39
40    if n_layers != sumcheck_proofs.len() {
41        return Err(GkrError::MalformedProof);
42    }
43
44    if gate_by_instance.len() != n_instances {
45        return Err(GkrError::NumInstancesMismatch {
46            given: gate_by_instance.len(),
47            proof: n_instances,
48        });
49    }
50
51    let mut ood_point = vec![];
52    let mut claims_to_verify_by_instance = vec![None; n_instances];
53
54    for (layer, sumcheck_proof) in sumcheck_proofs.iter().enumerate() {
55        let n_remaining_layers = n_layers - layer;
56
57        // Check for output layers.
58        for instance in 0..n_instances {
59            if instance_n_layers(instance) == n_remaining_layers {
60                let output_claims = output_claims_by_instance[instance].clone();
61                claims_to_verify_by_instance[instance] = Some(output_claims);
62            }
63        }
64
65        // Seed the channel with layer claims.
66        for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
67            challenger.observe_slice(claims_to_verify);
68        }
69
70        let sumcheck_alpha = challenger.sample();
71        let instance_lambda = challenger.sample();
72
73        let mut sumcheck_claims = Vec::new();
74        let mut sumcheck_instances = Vec::new();
75
76        // Prepare the sumcheck claim.
77        for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
78            if let Some(claims_to_verify) = claims_to_verify {
79                let n_unused_variables = n_layers - instance_n_layers(instance);
80                let doubling_factor = F::from_canonical_u32(1 << n_unused_variables);
81                let claim =
82                    random_linear_combination(claims_to_verify, instance_lambda) * doubling_factor;
83                sumcheck_claims.push(claim);
84                sumcheck_instances.push(instance);
85            }
86        }
87
88        let sumcheck_claim = random_linear_combination(&sumcheck_claims, sumcheck_alpha);
89        let (sumcheck_ood_point, sumcheck_eval) =
90            sumcheck::partially_verify(sumcheck_claim, sumcheck_proof, challenger)
91                .map_err(|source| GkrError::InvalidSumcheck { layer, source })?;
92
93        let mut layer_evals = Vec::new();
94
95        // Evaluate the circuit locally at sumcheck OOD point.
96        for &instance in &sumcheck_instances {
97            let n_unused = n_layers - instance_n_layers(instance);
98            let mask = &layer_masks_by_instance[instance][layer - n_unused];
99            let gate = &gate_by_instance[instance];
100            let gate_output = gate.eval(mask).map_err(|_| {
101                let instance_layer = instance_n_layers(layer) - n_remaining_layers;
102                GkrError::InvalidMask {
103                    instance,
104                    instance_layer,
105                }
106            })?;
107            // TODO: Consider simplifying the code by just using the same eq eval for all instances
108            // regardless of size.
109            let eq_eval = hypercube_eq(&ood_point[n_unused..], &sumcheck_ood_point[n_unused..]);
110            layer_evals.push(eq_eval * random_linear_combination(&gate_output, instance_lambda));
111        }
112
113        let layer_eval = random_linear_combination(&layer_evals, sumcheck_alpha);
114
115        if sumcheck_eval != layer_eval {
116            return Err(GkrError::CircuitCheckFailure {
117                claim: sumcheck_eval,
118                output: layer_eval,
119                layer,
120            });
121        }
122
123        // Seed the channel with the layer masks.
124        for &instance in &sumcheck_instances {
125            let n_unused = n_layers - instance_n_layers(instance);
126            let mask = &layer_masks_by_instance[instance][layer - n_unused];
127            for column in mask.columns() {
128                challenger.observe_slice(column);
129            }
130        }
131
132        // Set the OOD evaluation point for layer above.
133        let challenge = challenger.sample();
134        ood_point = sumcheck_ood_point;
135        ood_point.push(challenge);
136
137        // Set the claims to verify in the layer above.
138        for instance in sumcheck_instances {
139            let n_unused = n_layers - instance_n_layers(instance);
140            let mask = &layer_masks_by_instance[instance][layer - n_unused];
141            claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
142        }
143    }
144
145    let claims_to_verify_by_instance = claims_to_verify_by_instance
146        .into_iter()
147        .map(Option::unwrap)
148        .collect();
149
150    Ok(GkrArtifact {
151        ood_point,
152        claims_to_verify_by_instance,
153        n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(),
154    })
155}