openvm_stark_backend/gkr/types.rs
1use std::ops::Index;
2
3use p3_field::Field;
4use thiserror::Error;
5
6use crate::{
7 poly::{
8 multi::{fold_mle_evals, Mle, MultivariatePolyOracle},
9 uni::Fraction,
10 },
11 sumcheck::{SumcheckError, SumcheckProof},
12};
13
14pub struct GkrBatchProof<F> {
16 pub sumcheck_proofs: Vec<SumcheckProof<F>>,
18 pub layer_masks_by_instance: Vec<Vec<GkrMask<F>>>,
20 pub output_claims_by_instance: Vec<Vec<F>>,
22}
23
/// Values produced by running the GKR protocol that remain to be checked by the caller.
pub struct GkrArtifact<F> {
    /// Evaluation point reached by the protocol
    /// (the `ood` prefix presumably means "out of domain" — confirm).
    pub ood_point: Vec<F>,
    /// Claims still to be verified, grouped per instance.
    pub claims_to_verify_by_instance: Vec<Vec<F>>,
    /// Number of variables of each instance.
    pub n_variables_by_instance: Vec<usize>,
}
33
/// A pair of values `[v0, v1]` for each column of a GKR layer.
///
/// NOTE(review): the pairs are presumably the column evaluations with the first variable set
/// to 0 and 1 respectively — confirm against the prover.
#[derive(Debug, Clone)]
pub struct GkrMask<F> {
    /// One `[v0, v1]` pair per column.
    columns: Vec<[F; 2]>,
}

impl<F> GkrMask<F> {
    /// Builds a mask from per-column value pairs, keeping them in the given order.
    pub fn new(columns: Vec<[F; 2]>) -> Self {
        GkrMask { columns }
    }

    /// Borrows the per-column value pairs in the order they were supplied to `new`.
    pub fn columns(&self) -> &[[F; 2]] {
        self.columns.as_slice()
    }
}
49
50impl<F: Field> GkrMask<F> {
51 pub fn to_rows(&self) -> [Vec<F>; 2] {
52 self.columns.iter().map(|[a, b]| (a, b)).unzip().into()
53 }
54
55 pub fn reduce_at_point(&self, x: F) -> Vec<F> {
57 self.columns
58 .iter()
59 .map(|&[v0, v1]| fold_mle_evals(x, v0, v1))
60 .collect()
61 }
62}
63
64#[derive(Error, Debug)]
66pub enum GkrError<F> {
67 #[error("proof data is invalid")]
69 MalformedProof,
70 #[error("mask in layer {instance_layer} of instance {instance} is invalid")]
72 InvalidMask {
73 instance: usize,
74 instance_layer: LayerIndex,
76 },
77 #[error("provided an invalid number of instances (given {given}, proof expects {proof})")]
80 NumInstancesMismatch { given: usize, proof: usize },
81 #[error("sum-check invalid in layer {layer}: {source}")]
83 InvalidSumcheck {
84 layer: LayerIndex,
85 source: SumcheckError<F>,
86 },
87 #[error("circuit check failed in layer {layer} (calculated {output}, claim {claim})")]
89 CircuitCheckFailure {
90 claim: F,
91 output: F,
92 layer: LayerIndex,
93 },
94}
95
/// Index of a layer, as reported in [`GkrError`] variants.
pub type LayerIndex = usize;
98
99#[derive(Debug, Clone)]
106pub enum Layer<F> {
107 GrandProduct(Mle<F>),
108 LogUpGeneric {
109 numerators: Mle<F>,
110 denominators: Mle<F>,
111 },
112 LogUpMultiplicities {
113 numerators: Mle<F>,
114 denominators: Mle<F>,
115 },
116 LogUpSingles {
118 denominators: Mle<F>,
119 },
120}
121
122impl<F: Field> Layer<F> {
123 pub fn n_variables(&self) -> usize {
125 match self {
126 Self::GrandProduct(mle)
127 | Self::LogUpSingles { denominators: mle }
128 | Self::LogUpMultiplicities {
129 denominators: mle, ..
130 }
131 | Self::LogUpGeneric {
132 denominators: mle, ..
133 } => mle.arity(),
134 }
135 }
136
137 fn is_output_layer(&self) -> bool {
138 self.n_variables() == 0
139 }
140
141 pub fn next_layer(&self) -> Option<Self> {
146 if self.is_output_layer() {
147 return None;
148 }
149
150 let next_layer = match self {
151 Layer::GrandProduct(layer) => Self::next_grand_product_layer(layer),
152 Layer::LogUpGeneric {
153 numerators,
154 denominators,
155 }
156 | Layer::LogUpMultiplicities {
157 numerators,
158 denominators,
159 } => Self::next_logup_layer(MleExpr::Mle(numerators), denominators),
160 Layer::LogUpSingles { denominators } => {
161 Self::next_logup_layer(MleExpr::Constant(F::ONE), denominators)
162 }
163 };
164 Some(next_layer)
165 }
166
167 fn next_grand_product_layer(layer: &Mle<F>) -> Layer<F> {
168 let res = layer
169 .chunks_exact(2) .map(|chunk| chunk[0] * chunk[1]) .collect();
172 Layer::GrandProduct(Mle::new(res))
173 }
174
175 fn next_logup_layer(numerators: MleExpr<'_, F>, denominators: &Mle<F>) -> Layer<F> {
176 let half_n = 1 << (denominators.arity() - 1);
177 let mut next_numerators = Vec::with_capacity(half_n);
178 let mut next_denominators = Vec::with_capacity(half_n);
179
180 for i in 0..half_n {
181 let a = Fraction::new(numerators[i * 2], denominators[i * 2]);
182 let b = Fraction::new(numerators[i * 2 + 1], denominators[i * 2 + 1]);
183 let res = a + b;
184 next_numerators.push(res.numerator);
185 next_denominators.push(res.denominator);
186 }
187
188 Layer::LogUpGeneric {
189 numerators: Mle::new(next_numerators),
190 denominators: Mle::new(next_denominators),
191 }
192 }
193
194 pub fn try_into_output_layer_values(self) -> Result<Vec<F>, NotOutputLayerError> {
196 if !self.is_output_layer() {
197 return Err(NotOutputLayerError);
198 }
199
200 Ok(match self {
201 Layer::LogUpSingles { denominators } => {
202 let numerator = F::ONE;
203 let denominator = denominators[0];
204 vec![numerator, denominator]
205 }
206 Layer::LogUpGeneric {
207 numerators,
208 denominators,
209 }
210 | Layer::LogUpMultiplicities {
211 numerators,
212 denominators,
213 } => {
214 let numerator = numerators[0];
215 let denominator = denominators[0];
216 vec![numerator, denominator]
217 }
218 Layer::GrandProduct(col) => {
219 vec![col[0]]
220 }
221 })
222 }
223
224 pub fn fix_first_variable(self, x0: F) -> Self {
226 if self.n_variables() == 0 {
227 return self;
228 }
229
230 match self {
231 Self::GrandProduct(mle) => Self::GrandProduct(mle.partial_evaluation(x0)),
232 Self::LogUpGeneric {
233 numerators,
234 denominators,
235 }
236 | Self::LogUpMultiplicities {
237 numerators,
238 denominators,
239 } => Self::LogUpGeneric {
240 numerators: numerators.partial_evaluation(x0),
241 denominators: denominators.partial_evaluation(x0),
242 },
243 Self::LogUpSingles { denominators } => Self::LogUpSingles {
244 denominators: denominators.partial_evaluation(x0),
245 },
246 }
247 }
248}
249
/// Error returned by `Layer::try_into_output_layer_values` when the layer still has variables
/// (i.e. it is not the output layer).
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`]. This lets it be propagated with
/// `?` into boxed/dynamic errors like any other standard error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NotOutputLayerError;

impl std::fmt::Display for NotOutputLayerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("layer is not an output layer")
    }
}

impl std::error::Error for NotOutputLayerError {}
252
253enum MleExpr<'a, F: Field> {
254 Constant(F),
255 Mle(&'a Mle<F>),
256}
257
258impl<F: Field> Index<usize> for MleExpr<'_, F> {
259 type Output = F;
260
261 fn index(&self, index: usize) -> &F {
262 match self {
263 Self::Constant(v) => v,
264 Self::Mle(mle) => &mle[index],
265 }
266 }
267}