openvm_stark_backend/gkr/
verifier.rs
1use p3_challenger::FieldChallenger;
5use p3_field::Field;
6
7use crate::{
8 gkr::{
9 gate::Gate,
10 types::{GkrArtifact, GkrBatchProof, GkrError},
11 },
12 poly::{multi::hypercube_eq, uni::random_linear_combination},
13 sumcheck,
14};
15
16pub fn partially_verify_batch<F: Field>(
22 gate_by_instance: Vec<Gate>,
23 proof: &GkrBatchProof<F>,
24 challenger: &mut impl FieldChallenger<F>,
25) -> Result<GkrArtifact<F>, GkrError<F>> {
26 let GkrBatchProof {
27 sumcheck_proofs,
28 layer_masks_by_instance,
29 output_claims_by_instance,
30 } = proof;
31
32 if layer_masks_by_instance.len() != output_claims_by_instance.len() {
33 return Err(GkrError::MalformedProof);
34 }
35
36 let n_instances = layer_masks_by_instance.len();
37 let instance_n_layers = |instance: usize| layer_masks_by_instance[instance].len();
38 let n_layers = (0..n_instances).map(instance_n_layers).max().unwrap();
39
40 if n_layers != sumcheck_proofs.len() {
41 return Err(GkrError::MalformedProof);
42 }
43
44 if gate_by_instance.len() != n_instances {
45 return Err(GkrError::NumInstancesMismatch {
46 given: gate_by_instance.len(),
47 proof: n_instances,
48 });
49 }
50
51 let mut ood_point = vec![];
52 let mut claims_to_verify_by_instance = vec![None; n_instances];
53
54 for (layer, sumcheck_proof) in sumcheck_proofs.iter().enumerate() {
55 let n_remaining_layers = n_layers - layer;
56
57 for instance in 0..n_instances {
59 if instance_n_layers(instance) == n_remaining_layers {
60 let output_claims = output_claims_by_instance[instance].clone();
61 claims_to_verify_by_instance[instance] = Some(output_claims);
62 }
63 }
64
65 for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
67 challenger.observe_slice(claims_to_verify);
68 }
69
70 let sumcheck_alpha = challenger.sample();
71 let instance_lambda = challenger.sample();
72
73 let mut sumcheck_claims = Vec::new();
74 let mut sumcheck_instances = Vec::new();
75
76 for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
78 if let Some(claims_to_verify) = claims_to_verify {
79 let n_unused_variables = n_layers - instance_n_layers(instance);
80 let doubling_factor = F::from_canonical_u32(1 << n_unused_variables);
81 let claim =
82 random_linear_combination(claims_to_verify, instance_lambda) * doubling_factor;
83 sumcheck_claims.push(claim);
84 sumcheck_instances.push(instance);
85 }
86 }
87
88 let sumcheck_claim = random_linear_combination(&sumcheck_claims, sumcheck_alpha);
89 let (sumcheck_ood_point, sumcheck_eval) =
90 sumcheck::partially_verify(sumcheck_claim, sumcheck_proof, challenger)
91 .map_err(|source| GkrError::InvalidSumcheck { layer, source })?;
92
93 let mut layer_evals = Vec::new();
94
95 for &instance in &sumcheck_instances {
97 let n_unused = n_layers - instance_n_layers(instance);
98 let mask = &layer_masks_by_instance[instance][layer - n_unused];
99 let gate = &gate_by_instance[instance];
100 let gate_output = gate.eval(mask).map_err(|_| {
101 let instance_layer = instance_n_layers(layer) - n_remaining_layers;
102 GkrError::InvalidMask {
103 instance,
104 instance_layer,
105 }
106 })?;
107 let eq_eval = hypercube_eq(&ood_point[n_unused..], &sumcheck_ood_point[n_unused..]);
110 layer_evals.push(eq_eval * random_linear_combination(&gate_output, instance_lambda));
111 }
112
113 let layer_eval = random_linear_combination(&layer_evals, sumcheck_alpha);
114
115 if sumcheck_eval != layer_eval {
116 return Err(GkrError::CircuitCheckFailure {
117 claim: sumcheck_eval,
118 output: layer_eval,
119 layer,
120 });
121 }
122
123 for &instance in &sumcheck_instances {
125 let n_unused = n_layers - instance_n_layers(instance);
126 let mask = &layer_masks_by_instance[instance][layer - n_unused];
127 for column in mask.columns() {
128 challenger.observe_slice(column);
129 }
130 }
131
132 let challenge = challenger.sample();
134 ood_point = sumcheck_ood_point;
135 ood_point.push(challenge);
136
137 for instance in sumcheck_instances {
139 let n_unused = n_layers - instance_n_layers(instance);
140 let mask = &layer_masks_by_instance[instance][layer - n_unused];
141 claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
142 }
143 }
144
145 let claims_to_verify_by_instance = claims_to_verify_by_instance
146 .into_iter()
147 .map(Option::unwrap)
148 .collect();
149
150 Ok(GkrArtifact {
151 ood_point,
152 claims_to_verify_by_instance,
153 n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(),
154 })
155}