snark_verifier/verifier/plonk/
protocol.rs

1use crate::{
2    loader::{native::NativeLoader, LoadedScalar, Loader},
3    util::{
4        arithmetic::{CurveAffine, Domain, Field, Fraction, PrimeField, Rotation},
5        Itertools,
6    },
7};
8use num_integer::Integer;
9use num_traits::One;
10use serde::{Deserialize, Serialize};
11use std::{
12    cmp::{max, Ordering},
13    collections::{BTreeMap, BTreeSet},
14    fmt::Debug,
15    iter::{self, Sum},
16    ops::{Add, Mul, Neg, Sub},
17};
18
19/// Domain parameters to be optionally loaded as witnesses
20#[derive(Clone, Debug, Serialize, Deserialize)]
21pub struct DomainAsWitness<C, L>
22where
23    C: CurveAffine,
24    L: Loader<C>,
25{
26    /// 2<sup>k</sup> is the number of rows in the domain
27    pub k: L::LoadedScalar,
28    /// n = 2<sup>k</sup> is the number of rows in the domain
29    pub n: L::LoadedScalar,
30    /// Generator of the domain
31    pub gen: L::LoadedScalar,
32    /// Inverse generator of the domain
33    pub gen_inv: L::LoadedScalar,
34}
35
36impl<C, L> DomainAsWitness<C, L>
37where
38    C: CurveAffine,
39    L: Loader<C>,
40{
41    /// Rotate `F::one()` to given `rotation`.
42    pub fn rotate_one(&self, rotation: Rotation) -> L::LoadedScalar {
43        let loader = self.gen.loader();
44        match rotation.0.cmp(&0) {
45            Ordering::Equal => loader.load_one(),
46            Ordering::Greater => self.gen.pow_const(rotation.0 as u64),
47            Ordering::Less => self.gen_inv.pow_const(-rotation.0 as u64),
48        }
49    }
50}
51
52/// Protocol specifying configuration of a PLONK.
53#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct PlonkProtocol<C, L = NativeLoader>
55where
56    C: CurveAffine,
57    L: Loader<C>,
58{
59    #[serde(bound(
60        serialize = "C::Scalar: Serialize",
61        deserialize = "C::Scalar: Deserialize<'de>"
62    ))]
63    /// Working domain.
64    pub domain: Domain<C::Scalar>,
65
66    #[serde(bound(
67        serialize = "L::LoadedScalar: Serialize",
68        deserialize = "L::LoadedScalar: Deserialize<'de>"
69    ))]
70    /// Optional: load `domain.n` and `domain.gen` as a witness
71    pub domain_as_witness: Option<DomainAsWitness<C, L>>,
72
73    #[serde(bound(
74        serialize = "L::LoadedEcPoint: Serialize",
75        deserialize = "L::LoadedEcPoint: Deserialize<'de>"
76    ))]
77    /// Commitments of preprocessed polynomials.
78    pub preprocessed: Vec<L::LoadedEcPoint>,
79    /// Number of instances in each instance polynomial.
80    pub num_instance: Vec<usize>,
81    /// Number of witness polynomials in each phase.
82    pub num_witness: Vec<usize>,
83    /// Number of challenges to squeeze from transcript after each phase.
84    pub num_challenge: Vec<usize>,
85    /// Evaluations to read from transcript.
86    pub evaluations: Vec<Query>,
87    /// [`crate::pcs::PolynomialCommitmentScheme`] queries to verify.
88    pub queries: Vec<Query>,
89    /// Structure of quotient polynomial.
90    pub quotient: QuotientPolynomial<C::Scalar>,
91    #[serde(bound(
92        serialize = "L::LoadedScalar: Serialize",
93        deserialize = "L::LoadedScalar: Deserialize<'de>"
94    ))]
95    /// Prover and verifier common initial state to write to transcript if any.
96    pub transcript_initial_state: Option<L::LoadedScalar>,
97    /// Instance polynomials commiting key if any.
98    pub instance_committing_key: Option<InstanceCommittingKey<C>>,
99    /// Linearization strategy.
100    pub linearization: Option<LinearizationStrategy>,
101    /// Indices (instance polynomial index, row) of encoded
102    /// [`crate::pcs::AccumulationScheme::Accumulator`]s.
103    pub accumulator_indices: Vec<Vec<(usize, usize)>>,
104}
105
106impl<C, L> PlonkProtocol<C, L>
107where
108    C: CurveAffine,
109    L: Loader<C>,
110{
111    pub(super) fn langranges(&self) -> impl IntoIterator<Item = i32> {
112        let instance_eval_lagrange = self.instance_committing_key.is_none().then(|| {
113            let queries = {
114                let offset = self.preprocessed.len();
115                let range = offset..offset + self.num_instance.len();
116                self.quotient
117                    .numerator
118                    .used_query()
119                    .into_iter()
120                    .filter(move |query| range.contains(&query.poly))
121            };
122            let (min_rotation, max_rotation) = queries.fold((0, 0), |(min, max), query| {
123                if query.rotation.0 < min {
124                    (query.rotation.0, max)
125                } else if query.rotation.0 > max {
126                    (min, query.rotation.0)
127                } else {
128                    (min, max)
129                }
130            });
131            let max_instance_len = self.num_instance.iter().max().copied().unwrap_or_default();
132            -max_rotation..max_instance_len as i32 + min_rotation.abs()
133        });
134        self.quotient
135            .numerator
136            .used_langrange()
137            .into_iter()
138            .chain(instance_eval_lagrange.into_iter().flatten())
139    }
140}
141impl<C> PlonkProtocol<C>
142where
143    C: CurveAffine,
144{
145    /// Loaded `PlonkProtocol` with `preprocessed` and
146    /// `transcript_initial_state` loaded as constant.
147    pub fn loaded<L: Loader<C>>(&self, loader: &L) -> PlonkProtocol<C, L> {
148        let preprocessed = self
149            .preprocessed
150            .iter()
151            .map(|preprocessed| loader.ec_point_load_const(preprocessed))
152            .collect();
153        let transcript_initial_state = self
154            .transcript_initial_state
155            .as_ref()
156            .map(|transcript_initial_state| loader.load_const(transcript_initial_state));
157        PlonkProtocol {
158            domain: self.domain.clone(),
159            domain_as_witness: None,
160            preprocessed,
161            num_instance: self.num_instance.clone(),
162            num_witness: self.num_witness.clone(),
163            num_challenge: self.num_challenge.clone(),
164            evaluations: self.evaluations.clone(),
165            queries: self.queries.clone(),
166            quotient: self.quotient.clone(),
167            transcript_initial_state,
168            instance_committing_key: self.instance_committing_key.clone(),
169            linearization: self.linearization,
170            accumulator_indices: self.accumulator_indices.clone(),
171        }
172    }
173}
174
#[cfg(feature = "loader_halo2")]
mod halo2 {
    use crate::{
        loader::{
            halo2::{EccInstructions, Halo2Loader},
            LoadedScalar, ScalarLoader,
        },
        util::arithmetic::CurveAffine,
        verifier::plonk::PlonkProtocol,
    };
    use halo2_base::utils::bit_length;
    use std::rc::Rc;

    use super::{DomainAsWitness, PrimeField};

    impl<C> PlonkProtocol<C>
    where
        C: CurveAffine,
    {
        /// Loads this native protocol into a [`Halo2Loader`]. The `preprocessed`
        /// commitments and the optional `transcript_initial_state` are assigned as
        /// witnesses instead of constants, which is useful when doing recursion.
        ///
        /// When `load_k_as_witness` is `true`, `k` is also assigned as a witness.
        /// `n = 2^k`, the domain generator and its inverse are then derived from it
        /// in-circuit and stored in `domain_as_witness`. Otherwise
        /// `domain_as_witness` is `None`.
        ///
        /// # Panics
        ///
        /// Panics if the derived generator cannot be inverted.
        pub fn loaded_preprocessed_as_witness<EccChip: EccInstructions<C>>(
            &self,
            loader: &Rc<Halo2Loader<C, EccChip>>,
            load_k_as_witness: bool,
        ) -> PlonkProtocol<C, Rc<Halo2Loader<C, EccChip>>> {
            // NOTE(review): loader calls are kept in a fixed order; with an
            // in-circuit loader the order presumably determines cell layout.
            let domain_as_witness = if load_k_as_witness {
                let two_adicity = C::Scalar::S as u64;
                let k = loader.assign_scalar(C::Scalar::from(self.domain.k as u64));
                // n = 2^k
                let two = loader.load_const(&C::Scalar::from(2));
                let n = two.pow_var(&k, bit_length(two_adicity) + 1);
                // gen = omega = ROOT_OF_UNITY ^ {2^{S - k}}, where ROOT_OF_UNITY is a
                // primitive 2^S root of unity; this makes omega a 2^k root of unity.
                let root_of_unity = loader.load_const(&C::Scalar::ROOT_OF_UNITY);
                let s = loader.load_const(&C::Scalar::from(two_adicity));
                // If S - k < 0, the constraint on max bits will fail.
                let exp = two.pow_var(&(s - &k), bit_length(two_adicity));
                // 2^{S - k} < 2^S for k > 0.
                let gen = root_of_unity.pow_var(&exp, C::Scalar::S as usize);
                let gen_inv = gen.invert().expect("subgroup generation is invertible");
                Some(DomainAsWitness { k, n, gen, gen_inv })
            } else {
                None
            };

            let preprocessed = self
                .preprocessed
                .iter()
                .map(|&commitment| loader.assign_ec_point(commitment))
                .collect();
            let transcript_initial_state =
                self.transcript_initial_state.as_ref().map(|state| loader.assign_scalar(*state));

            PlonkProtocol {
                domain: self.domain.clone(),
                domain_as_witness,
                preprocessed,
                num_instance: self.num_instance.clone(),
                num_witness: self.num_witness.clone(),
                num_challenge: self.num_challenge.clone(),
                evaluations: self.evaluations.clone(),
                queries: self.queries.clone(),
                quotient: self.quotient.clone(),
                transcript_initial_state,
                instance_committing_key: self.instance_committing_key.clone(),
                linearization: self.linearization,
                accumulator_indices: self.accumulator_indices.clone(),
            }
        }
    }
}
244
245#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
246pub enum CommonPolynomial {
247    Identity,
248    Lagrange(i32),
249}
250
251#[derive(Clone, Debug)]
252pub struct CommonPolynomialEvaluation<C, L>
253where
254    C: CurveAffine,
255    L: Loader<C>,
256{
257    zn: L::LoadedScalar,
258    zn_minus_one: L::LoadedScalar,
259    zn_minus_one_inv: Fraction<L::LoadedScalar>,
260    identity: L::LoadedScalar,
261    lagrange: BTreeMap<i32, Fraction<L::LoadedScalar>>,
262}
263
264impl<C, L> CommonPolynomialEvaluation<C, L>
265where
266    C: CurveAffine,
267    L: Loader<C>,
268{
269    // if `n_as_witness` is Some, then we assume `n_as_witness` has value equal to `domain.n` (i.e., number of rows in the circuit)
270    // and is loaded as a witness instead of a constant.
271    // The generator of `domain` also depends on `n`.
272    pub fn new(
273        domain: &Domain<C::Scalar>,
274        lagranges: impl IntoIterator<Item = i32>,
275        z: &L::LoadedScalar,
276        domain_as_witness: &Option<DomainAsWitness<C, L>>,
277    ) -> Self {
278        let loader = z.loader();
279
280        let lagranges = lagranges.into_iter().sorted().dedup().collect_vec();
281        let one = loader.load_one();
282
283        let (zn, n_inv, omegas) = if let Some(domain) = domain_as_witness.as_ref() {
284            let zn = z.pow_var(&domain.n, C::Scalar::S as usize + 1);
285            let n_inv = domain.n.invert().expect("n is not zero");
286            let omegas = lagranges.iter().map(|&i| domain.rotate_one(Rotation(i))).collect_vec();
287            (zn, n_inv, omegas)
288        } else {
289            let zn = z.pow_const(domain.n as u64);
290            let n_inv = loader.load_const(&domain.n_inv);
291            let omegas = lagranges
292                .iter()
293                .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::ONE, Rotation(i))))
294                .collect_vec();
295            (zn, n_inv, omegas)
296        };
297
298        let zn_minus_one = zn.clone() - &one;
299        let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone());
300
301        let numer = zn_minus_one.clone() * &n_inv;
302        let lagrange_evals = omegas
303            .iter()
304            .map(|omega| Fraction::new(numer.clone() * omega, z.clone() - omega))
305            .collect_vec();
306
307        Self {
308            zn,
309            zn_minus_one,
310            zn_minus_one_inv,
311            identity: z.clone(),
312            lagrange: lagranges.into_iter().zip(lagrange_evals).collect(),
313        }
314    }
315
316    pub fn zn(&self) -> &L::LoadedScalar {
317        &self.zn
318    }
319
320    pub fn zn_minus_one(&self) -> &L::LoadedScalar {
321        &self.zn_minus_one
322    }
323
324    pub fn zn_minus_one_inv(&self) -> &L::LoadedScalar {
325        self.zn_minus_one_inv.evaluated()
326    }
327
328    pub fn get(&self, poly: CommonPolynomial) -> &L::LoadedScalar {
329        match poly {
330            CommonPolynomial::Identity => &self.identity,
331            CommonPolynomial::Lagrange(i) => self.lagrange.get(&i).unwrap().evaluated(),
332        }
333    }
334
335    pub fn denoms(&mut self) -> impl IntoIterator<Item = &'_ mut L::LoadedScalar> {
336        self.lagrange
337            .iter_mut()
338            .map(|(_, value)| value.denom_mut())
339            .chain(iter::once(self.zn_minus_one_inv.denom_mut()))
340            .flatten()
341    }
342
343    pub fn evaluate(&mut self) {
344        self.lagrange
345            .iter_mut()
346            .map(|(_, value)| value)
347            .chain(iter::once(&mut self.zn_minus_one_inv))
348            .for_each(Fraction::evaluate)
349    }
350}
351
352#[derive(Clone, Debug, Serialize, Deserialize)]
353pub struct QuotientPolynomial<F: Clone> {
354    pub chunk_degree: usize,
355    pub numerator: Expression<F>,
356}
357
358impl<F: Clone> QuotientPolynomial<F> {
359    pub fn num_chunk(&self) -> usize {
360        Integer::div_ceil(
361            &(self.numerator.degree().checked_sub(1).unwrap_or_default()),
362            &self.chunk_degree,
363        )
364    }
365}
366
367#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
368pub struct Query {
369    pub poly: usize,
370    pub rotation: Rotation,
371}
372
373impl Query {
374    pub fn new<R: Into<Rotation>>(poly: usize, rotation: R) -> Self {
375        Self { poly, rotation: rotation.into() }
376    }
377}
378
379#[derive(Clone, Debug, Serialize, Deserialize)]
380pub enum Expression<F> {
381    Constant(F),
382    CommonPolynomial(CommonPolynomial),
383    Polynomial(Query),
384    Challenge(usize),
385    Negated(Box<Expression<F>>),
386    Sum(Box<Expression<F>>, Box<Expression<F>>),
387    Product(Box<Expression<F>>, Box<Expression<F>>),
388    Scaled(Box<Expression<F>>, F),
389    DistributePowers(Vec<Expression<F>>, Box<Expression<F>>),
390}
391
392impl<F: Clone> Expression<F> {
393    pub fn evaluate<T: Clone>(
394        &self,
395        constant: &impl Fn(F) -> T,
396        common_poly: &impl Fn(CommonPolynomial) -> T,
397        poly: &impl Fn(Query) -> T,
398        challenge: &impl Fn(usize) -> T,
399        negated: &impl Fn(T) -> T,
400        sum: &impl Fn(T, T) -> T,
401        product: &impl Fn(T, T) -> T,
402        scaled: &impl Fn(T, F) -> T,
403    ) -> T {
404        let evaluate = |expr: &Expression<F>| {
405            expr.evaluate(constant, common_poly, poly, challenge, negated, sum, product, scaled)
406        };
407        match self {
408            Expression::Constant(scalar) => constant(scalar.clone()),
409            Expression::CommonPolynomial(poly) => common_poly(*poly),
410            Expression::Polynomial(query) => poly(*query),
411            Expression::Challenge(index) => challenge(*index),
412            Expression::Negated(a) => {
413                let a = evaluate(a);
414                negated(a)
415            }
416            Expression::Sum(a, b) => {
417                let a = evaluate(a);
418                let b = evaluate(b);
419                sum(a, b)
420            }
421            Expression::Product(a, b) => {
422                let a = evaluate(a);
423                let b = evaluate(b);
424                product(a, b)
425            }
426            Expression::Scaled(a, scalar) => {
427                let a = evaluate(a);
428                scaled(a, scalar.clone())
429            }
430            Expression::DistributePowers(exprs, scalar) => {
431                assert!(!exprs.is_empty());
432                if exprs.len() == 1 {
433                    return evaluate(exprs.first().unwrap());
434                }
435                let mut exprs = exprs.iter();
436                let first = evaluate(exprs.next().unwrap());
437                let scalar = evaluate(scalar);
438                exprs.fold(first, |acc, expr| sum(product(acc, scalar.clone()), evaluate(expr)))
439            }
440        }
441    }
442
443    pub fn degree(&self) -> usize {
444        match self {
445            Expression::Constant(_) => 0,
446            Expression::CommonPolynomial(_) => 1,
447            Expression::Polynomial { .. } => 1,
448            Expression::Challenge { .. } => 0,
449            Expression::Negated(a) => a.degree(),
450            Expression::Sum(a, b) => max(a.degree(), b.degree()),
451            Expression::Product(a, b) => a.degree() + b.degree(),
452            Expression::Scaled(a, _) => a.degree(),
453            Expression::DistributePowers(a, b) => {
454                a.iter().chain(Some(b.as_ref())).map(Self::degree).max().unwrap_or_default()
455            }
456        }
457    }
458
459    pub fn used_langrange(&self) -> BTreeSet<i32> {
460        self.evaluate(
461            &|_| None,
462            &|poly| match poly {
463                CommonPolynomial::Lagrange(i) => Some(BTreeSet::from_iter([i])),
464                _ => None,
465            },
466            &|_| None,
467            &|_| None,
468            &|a| a,
469            &merge_left_right,
470            &merge_left_right,
471            &|a, _| a,
472        )
473        .unwrap_or_default()
474    }
475
476    pub fn used_query(&self) -> BTreeSet<Query> {
477        self.evaluate(
478            &|_| None,
479            &|_| None,
480            &|query| Some(BTreeSet::from_iter([query])),
481            &|_| None,
482            &|a| a,
483            &merge_left_right,
484            &merge_left_right,
485            &|a, _| a,
486        )
487        .unwrap_or_default()
488    }
489}
490
491impl<F: Clone> From<Query> for Expression<F> {
492    fn from(query: Query) -> Self {
493        Self::Polynomial(query)
494    }
495}
496
497impl<F: Clone> From<CommonPolynomial> for Expression<F> {
498    fn from(common_poly: CommonPolynomial) -> Self {
499        Self::CommonPolynomial(common_poly)
500    }
501}
502
/// Implements the binary operator `$trait::$op` for every owned/borrowed
/// combination of an `Expression<F>` left operand and a `$rhs` right operand.
///
/// The owned/owned impl builds `Expression::$variant(lhs, $rhs_expr(rhs))`.
/// The three impls taking references clone the borrowed operand(s) and delegate
/// to it.
macro_rules! impl_expression_ops {
    ($trait:ident, $op:ident, $variant:ident, $rhs:ty, $rhs_expr:expr) => {
        impl<F: Clone> $trait<$rhs> for Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: $rhs) -> Self::Output {
                let lhs: Box<Expression<F>> = self.into();
                Expression::$variant(lhs, $rhs_expr(rhs).into())
            }
        }
        impl<F: Clone> $trait<$rhs> for &Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: $rhs) -> Self::Output {
                <Expression<F> as $trait<$rhs>>::$op(self.clone(), rhs)
            }
        }
        impl<F: Clone> $trait<&$rhs> for Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: &$rhs) -> Self::Output {
                <Expression<F> as $trait<$rhs>>::$op(self, rhs.clone())
            }
        }
        impl<F: Clone> $trait<&$rhs> for &Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: &$rhs) -> Self::Output {
                <Expression<F> as $trait<$rhs>>::$op(self.clone(), rhs.clone())
            }
        }
    };
}
531
532impl_expression_ops!(Mul, mul, Product, Expression<F>, std::convert::identity);
533impl_expression_ops!(Mul, mul, Scaled, F, std::convert::identity);
534impl_expression_ops!(Add, add, Sum, Expression<F>, std::convert::identity);
535impl_expression_ops!(Sub, sub, Sum, Expression<F>, Neg::neg);
536
537impl<F: Clone> Neg for Expression<F> {
538    type Output = Expression<F>;
539    fn neg(self) -> Self::Output {
540        Expression::Negated(Box::new(self))
541    }
542}
543
544impl<F: Clone> Neg for &Expression<F> {
545    type Output = Expression<F>;
546    fn neg(self) -> Self::Output {
547        Expression::Negated(Box::new(self.clone()))
548    }
549}
550
551impl<F: Clone + Default> Sum for Expression<F> {
552    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
553        iter.reduce(|acc, item| acc + item).unwrap_or_else(|| Expression::Constant(F::default()))
554    }
555}
556
557impl<F: Field> One for Expression<F> {
558    fn one() -> Self {
559        Expression::Constant(F::ONE)
560    }
561}
562
/// Union of two optional sets. Returns `None` only when both inputs are `None`.
///
/// Used as the `sum`/`product` callback of `Expression::evaluate` when
/// collecting used queries or Lagrange rotations.
fn merge_left_right<T: Ord>(a: Option<BTreeSet<T>>, b: Option<BTreeSet<T>>) -> Option<BTreeSet<T>> {
    match (a, b) {
        (Some(mut union), Some(other)) => {
            union.extend(other);
            Some(union)
        }
        // At most one side is present: keep whichever it is, or `None`.
        (left, right) => left.or(right),
    }
}
573
574#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
575pub enum LinearizationStrategy {
576    /// Older linearization strategy of GWC19, which has linearization
577    /// polynomial that doesn't evaluate to 0, and requires prover to send extra
578    /// evaluation of it to verifier.
579    WithoutConstant,
580    /// Current linearization strategy of GWC19, which has linearization
581    /// polynomial that evaluate to 0 by subtracting product of vanishing and
582    /// quotient polynomials.
583    MinusVanishingTimesQuotient,
584}
585
586#[derive(Clone, Debug, Default, Serialize, Deserialize)]
587pub struct InstanceCommittingKey<C> {
588    pub bases: Vec<C>,
589    pub constant: Option<C>,
590}