openvm_stark_backend/gkr/types.rs

1use std::ops::Index;
2
3use p3_field::Field;
4use thiserror::Error;
5
6use crate::{
7    poly::{
8        multi::{fold_mle_evals, Mle, MultivariatePolyOracle},
9        uni::Fraction,
10    },
11    sumcheck::{SumcheckError, SumcheckProof},
12};
13
/// Batch GKR proof.
///
/// Groups, for a batch of GKR instances: the sum-check proof of every layer, the mask of every
/// layer of every instance, and the claimed circuit outputs of every instance.
pub struct GkrBatchProof<F> {
    /// Sum-check proof for each layer.
    pub sumcheck_proofs: Vec<SumcheckProof<F>>,
    /// Mask for each layer for each instance.
    ///
    /// NOTE(review): the field name suggests the outer `Vec` is indexed by instance and the
    /// inner one by layer — confirm against the prover/verifier.
    pub layer_masks_by_instance: Vec<Vec<GkrMask<F>>>,
    /// Column circuit outputs for each instance.
    pub output_claims_by_instance: Vec<Vec<F>>,
}
23
/// Values of interest obtained from the execution of the GKR protocol.
///
/// Derives `Debug` and `Clone` so artifacts can be logged and duplicated; the generated impls
/// only require the same trait on `F`.
#[derive(Debug, Clone)]
pub struct GkrArtifact<F> {
    /// Out-of-domain (OOD) point for evaluating columns in the input layer.
    pub ood_point: Vec<F>,
    /// The claimed evaluation at `ood_point` for each column in the input layer of each
    /// instance.
    pub claims_to_verify_by_instance: Vec<Vec<F>>,
    /// The number of variables that interpolate the input layer of each instance.
    pub n_variables_by_instance: Vec<usize>,
}
33
/// Two evaluations of every column in a GKR layer.
///
/// Each element of `columns` is the `[first, second]` pair of evaluations for one column.
#[derive(Debug, Clone)]
pub struct GkrMask<F> {
    columns: Vec<[F; 2]>,
}

impl<F> GkrMask<F> {
    /// Builds a mask from per-column evaluation pairs, kept in the order given.
    pub fn new(columns: Vec<[F; 2]>) -> Self {
        GkrMask { columns }
    }

    /// Borrows the per-column evaluation pairs, in the order they were supplied to [`Self::new`].
    pub fn columns(&self) -> &[[F; 2]] {
        self.columns.as_slice()
    }
}
49
50impl<F: Field> GkrMask<F> {
51    pub fn to_rows(&self) -> [Vec<F>; 2] {
52        self.columns.iter().map(|[a, b]| (a, b)).unzip().into()
53    }
54
55    /// Returns all `p_i(x)` where `p_i` interpolates column `i` of the mask on `{0, 1}`.
56    pub fn reduce_at_point(&self, x: F) -> Vec<F> {
57        self.columns
58            .iter()
59            .map(|&[v0, v1]| fold_mle_evals(x, v0, v1))
60            .collect()
61    }
62}
63
/// Error encountered during GKR protocol verification.
#[derive(Error, Debug)]
pub enum GkrError<F> {
    /// The proof is malformed.
    #[error("proof data is invalid")]
    MalformedProof,
    /// Mask has an invalid number of columns.
    #[error("mask in layer {instance_layer} of instance {instance} is invalid")]
    InvalidMask {
        /// Instance whose mask is invalid.
        instance: usize,
        /// Layer of the instance (but not necessarily the batch).
        instance_layer: LayerIndex,
    },
    /// There is a mismatch between the number of instances in the proof and the number of
    /// instances passed for verification.
    #[error("provided an invalid number of instances (given {given}, proof expects {proof})")]
    NumInstancesMismatch {
        /// Number of instances passed for verification.
        given: usize,
        /// Number of instances the proof expects.
        proof: usize
    },
    /// There was an error with one of the sumcheck proofs.
    ///
    /// NOTE(review): thiserror treats a field named `source` as the error source, so the inner
    /// error is exposed via `Error::source()` while the message below also embeds `{source}`;
    /// reporters that walk the source chain may print it twice.
    #[error("sum-check invalid in layer {layer}: {source}")]
    InvalidSumcheck {
        /// Layer whose sum-check proof failed.
        layer: LayerIndex,
        /// Error reported for that layer's sum-check.
        source: SumcheckError<F>,
    },
    /// The circuit polynomial the verifier evaluated doesn't match claim from sumcheck.
    #[error("circuit check failed in layer {layer} (calculated {output}, claim {claim})")]
    CircuitCheckFailure {
        /// Claimed value from the sum-check.
        claim: F,
        /// Value the verifier calculated from the circuit polynomial.
        output: F,
        /// Layer in which the check failed.
        layer: LayerIndex,
    },
}
95
/// GKR layer index where 0 corresponds to the output layer.
///
/// This is a plain alias for `usize`, not a newtype, so no conversion is needed either way.
pub type LayerIndex = usize;
98
/// Represents a layer in a binary tree structured GKR circuit.
///
/// Layers can contain multiple columns, for example [LogUp] which has separate columns for
/// numerators and denominators. For every LogUp variant the layer's variable count is taken
/// from the denominator column.
///
/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf
#[derive(Debug, Clone)]
pub enum Layer<F> {
    /// Single column; the next layer multiplies adjacent pairs of values.
    GrandProduct(Mle<F>),
    /// Numerator and denominator columns; the next layer adds adjacent pairs of fractions
    /// `numerator / denominator`.
    LogUpGeneric {
        numerators: Mle<F>,
        denominators: Mle<F>,
    },
    /// Same column layout and folding rule as `LogUpGeneric`. Both `next_layer` and
    /// `fix_first_variable` turn this variant into a `LogUpGeneric` layer.
    ///
    /// NOTE(review): the name suggests the numerators hold lookup multiplicities; this is not
    /// enforced here.
    LogUpMultiplicities {
        numerators: Mle<F>,
        denominators: Mle<F>,
    },
    /// All numerators implicitly equal "1".
    ///
    /// Only the denominator column is stored.
    LogUpSingles {
        denominators: Mle<F>,
    },
}
121
122impl<F: Field> Layer<F> {
123    /// Returns the number of variables used to interpolate the layer's gate values.
124    pub fn n_variables(&self) -> usize {
125        match self {
126            Self::GrandProduct(mle)
127            | Self::LogUpSingles { denominators: mle }
128            | Self::LogUpMultiplicities {
129                denominators: mle, ..
130            }
131            | Self::LogUpGeneric {
132                denominators: mle, ..
133            } => mle.arity(),
134        }
135    }
136
137    fn is_output_layer(&self) -> bool {
138        self.n_variables() == 0
139    }
140
141    /// Produces the next layer from the current layer.
142    ///
143    /// The next layer is strictly half the size of the current layer.
144    /// Returns [`None`] if called on an output layer.
145    pub fn next_layer(&self) -> Option<Self> {
146        if self.is_output_layer() {
147            return None;
148        }
149
150        let next_layer = match self {
151            Layer::GrandProduct(layer) => Self::next_grand_product_layer(layer),
152            Layer::LogUpGeneric {
153                numerators,
154                denominators,
155            }
156            | Layer::LogUpMultiplicities {
157                numerators,
158                denominators,
159            } => Self::next_logup_layer(MleExpr::Mle(numerators), denominators),
160            Layer::LogUpSingles { denominators } => {
161                Self::next_logup_layer(MleExpr::Constant(F::ONE), denominators)
162            }
163        };
164        Some(next_layer)
165    }
166
167    fn next_grand_product_layer(layer: &Mle<F>) -> Layer<F> {
168        let res = layer
169            .chunks_exact(2) // Process in chunks of 2 elements
170            .map(|chunk| chunk[0] * chunk[1]) // Multiply each pair
171            .collect();
172        Layer::GrandProduct(Mle::new(res))
173    }
174
175    fn next_logup_layer(numerators: MleExpr<'_, F>, denominators: &Mle<F>) -> Layer<F> {
176        let half_n = 1 << (denominators.arity() - 1);
177        let mut next_numerators = Vec::with_capacity(half_n);
178        let mut next_denominators = Vec::with_capacity(half_n);
179
180        for i in 0..half_n {
181            let a = Fraction::new(numerators[i * 2], denominators[i * 2]);
182            let b = Fraction::new(numerators[i * 2 + 1], denominators[i * 2 + 1]);
183            let res = a + b;
184            next_numerators.push(res.numerator);
185            next_denominators.push(res.denominator);
186        }
187
188        Layer::LogUpGeneric {
189            numerators: Mle::new(next_numerators),
190            denominators: Mle::new(next_denominators),
191        }
192    }
193
194    /// Returns each column output if the layer is an output layer, otherwise returns an `Err`.
195    pub fn try_into_output_layer_values(self) -> Result<Vec<F>, NotOutputLayerError> {
196        if !self.is_output_layer() {
197            return Err(NotOutputLayerError);
198        }
199
200        Ok(match self {
201            Layer::LogUpSingles { denominators } => {
202                let numerator = F::ONE;
203                let denominator = denominators[0];
204                vec![numerator, denominator]
205            }
206            Layer::LogUpGeneric {
207                numerators,
208                denominators,
209            }
210            | Layer::LogUpMultiplicities {
211                numerators,
212                denominators,
213            } => {
214                let numerator = numerators[0];
215                let denominator = denominators[0];
216                vec![numerator, denominator]
217            }
218            Layer::GrandProduct(col) => {
219                vec![col[0]]
220            }
221        })
222    }
223
224    /// Returns a transformed layer with the first variable of each column fixed to `assignment`.
225    pub fn fix_first_variable(self, x0: F) -> Self {
226        if self.n_variables() == 0 {
227            return self;
228        }
229
230        match self {
231            Self::GrandProduct(mle) => Self::GrandProduct(mle.partial_evaluation(x0)),
232            Self::LogUpGeneric {
233                numerators,
234                denominators,
235            }
236            | Self::LogUpMultiplicities {
237                numerators,
238                denominators,
239            } => Self::LogUpGeneric {
240                numerators: numerators.partial_evaluation(x0),
241                denominators: denominators.partial_evaluation(x0),
242            },
243            Self::LogUpSingles { denominators } => Self::LogUpSingles {
244                denominators: denominators.partial_evaluation(x0),
245            },
246        }
247    }
248}
249
/// Error returned by [`Layer::try_into_output_layer_values`] when the layer still has
/// variables, i.e. is not an output layer.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it composes with `?` and
/// boxed error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NotOutputLayerError;

impl std::fmt::Display for NotOutputLayerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("layer is not an output layer")
    }
}

impl std::error::Error for NotOutputLayerError {}
252
253enum MleExpr<'a, F: Field> {
254    Constant(F),
255    Mle(&'a Mle<F>),
256}
257
258impl<F: Field> Index<usize> for MleExpr<'_, F> {
259    type Output = F;
260
261    fn index(&self, index: usize) -> &F {
262        match self {
263            Self::Constant(v) => v,
264            Self::Mle(mle) => &mle[index],
265        }
266    }
267}