1use crate::{
2 loader::{native::NativeLoader, LoadedScalar, Loader},
3 util::{
4 arithmetic::{CurveAffine, Domain, Field, Fraction, PrimeField, Rotation},
5 Itertools,
6 },
7};
8use num_integer::Integer;
9use num_traits::One;
10use serde::{Deserialize, Serialize};
11use std::{
12 cmp::{max, Ordering},
13 collections::{BTreeMap, BTreeSet},
14 fmt::Debug,
15 iter::{self, Sum},
16 ops::{Add, Mul, Neg, Sub},
17};
18
/// Evaluation-domain parameters held as loaded scalars rather than as
/// compile-time constants of a [`Domain`].
///
/// When present on a protocol, `CommonPolynomialEvaluation::new` takes `n`,
/// `1/n` and rotated generators from these values instead of from the constant
/// domain.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DomainAsWitness<C, L>
where
    C: CurveAffine,
    L: Loader<C>,
{
    /// Base-2 logarithm of the domain size.
    pub k: L::LoadedScalar,
    /// Number of rows in the domain, `2^k`.
    pub n: L::LoadedScalar,
    /// Generator of the order-`n` multiplicative subgroup (as derived in
    /// `loaded_preprocessed_as_witness`).
    pub gen: L::LoadedScalar,
    /// Multiplicative inverse of `gen`.
    pub gen_inv: L::LoadedScalar,
}
35
36impl<C, L> DomainAsWitness<C, L>
37where
38 C: CurveAffine,
39 L: Loader<C>,
40{
41 pub fn rotate_one(&self, rotation: Rotation) -> L::LoadedScalar {
43 let loader = self.gen.loader();
44 match rotation.0.cmp(&0) {
45 Ordering::Equal => loader.load_one(),
46 Ordering::Greater => self.gen.pow_const(rotation.0 as u64),
47 Ordering::Less => self.gen_inv.pow_const(-rotation.0 as u64),
48 }
49 }
50}
51
/// Verifier-side description of a PLONK-style protocol: evaluation domain,
/// preprocessed commitments, instance/witness/challenge counts, queries and
/// the quotient numerator.
///
/// `L` is the loader that holds commitments and scalars. The default
/// [`NativeLoader`] holds plain values, and [`PlonkProtocol::loaded`] converts
/// to any other loader.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PlonkProtocol<C, L = NativeLoader>
where
    C: CurveAffine,
    L: Loader<C>,
{
    /// Evaluation domain, as constants.
    #[serde(bound(
        serialize = "C::Scalar: Serialize",
        deserialize = "C::Scalar: Deserialize<'de>"
    ))]
    pub domain: Domain<C::Scalar>,

    /// Optional loaded copy of the domain parameters. [`PlonkProtocol::loaded`]
    /// sets it to `None`; the halo2 witness loader fills it when asked to load
    /// `k` as a witness.
    #[serde(bound(
        serialize = "L::LoadedScalar: Serialize",
        deserialize = "L::LoadedScalar: Deserialize<'de>"
    ))]
    pub domain_as_witness: Option<DomainAsWitness<C, L>>,

    /// Commitments to preprocessed polynomials. In query indices they occupy
    /// `0..preprocessed.len()`, and instance columns follow them.
    #[serde(bound(
        serialize = "L::LoadedEcPoint: Serialize",
        deserialize = "L::LoadedEcPoint: Deserialize<'de>"
    ))]
    pub preprocessed: Vec<L::LoadedEcPoint>,
    /// Length of each instance column. Column `i` is queried as polynomial
    /// `preprocessed.len() + i`.
    pub num_instance: Vec<usize>,
    /// Number of witness polynomials, presumably per prover phase — confirm.
    pub num_witness: Vec<usize>,
    /// Number of challenges, presumably per prover phase — confirm.
    pub num_challenge: Vec<usize>,
    /// Queries whose evaluations presumably come from the proof — confirm.
    pub evaluations: Vec<Query>,
    /// Queries presumably opened by the commitment scheme — confirm.
    pub queries: Vec<Query>,
    /// Quotient polynomial description (numerator expression and chunk degree).
    pub quotient: QuotientPolynomial<C::Scalar>,
    /// Optional scalar the transcript presumably starts from — confirm.
    #[serde(bound(
        serialize = "L::LoadedScalar: Serialize",
        deserialize = "L::LoadedScalar: Deserialize<'de>"
    ))]
    pub transcript_initial_state: Option<L::LoadedScalar>,
    /// Key for committing instances. When `None`, `langranges` adds the
    /// Lagrange indices needed to evaluate instance columns directly.
    pub instance_committing_key: Option<InstanceCommittingKey<C>>,
    /// Optional linearization strategy.
    pub linearization: Option<LinearizationStrategy>,
    /// Accumulator locations, presumably `(instance column, row)` pairs per
    /// accumulator — confirm.
    pub accumulator_indices: Vec<Vec<(usize, usize)>>,
}
105
106impl<C, L> PlonkProtocol<C, L>
107where
108 C: CurveAffine,
109 L: Loader<C>,
110{
111 pub(super) fn langranges(&self) -> impl IntoIterator<Item = i32> {
112 let instance_eval_lagrange = self.instance_committing_key.is_none().then(|| {
113 let queries = {
114 let offset = self.preprocessed.len();
115 let range = offset..offset + self.num_instance.len();
116 self.quotient
117 .numerator
118 .used_query()
119 .into_iter()
120 .filter(move |query| range.contains(&query.poly))
121 };
122 let (min_rotation, max_rotation) = queries.fold((0, 0), |(min, max), query| {
123 if query.rotation.0 < min {
124 (query.rotation.0, max)
125 } else if query.rotation.0 > max {
126 (min, query.rotation.0)
127 } else {
128 (min, max)
129 }
130 });
131 let max_instance_len = self.num_instance.iter().max().copied().unwrap_or_default();
132 -max_rotation..max_instance_len as i32 + min_rotation.abs()
133 });
134 self.quotient
135 .numerator
136 .used_langrange()
137 .into_iter()
138 .chain(instance_eval_lagrange.into_iter().flatten())
139 }
140}
141impl<C> PlonkProtocol<C>
142where
143 C: CurveAffine,
144{
145 pub fn loaded<L: Loader<C>>(&self, loader: &L) -> PlonkProtocol<C, L> {
148 let preprocessed = self
149 .preprocessed
150 .iter()
151 .map(|preprocessed| loader.ec_point_load_const(preprocessed))
152 .collect();
153 let transcript_initial_state = self
154 .transcript_initial_state
155 .as_ref()
156 .map(|transcript_initial_state| loader.load_const(transcript_initial_state));
157 PlonkProtocol {
158 domain: self.domain.clone(),
159 domain_as_witness: None,
160 preprocessed,
161 num_instance: self.num_instance.clone(),
162 num_witness: self.num_witness.clone(),
163 num_challenge: self.num_challenge.clone(),
164 evaluations: self.evaluations.clone(),
165 queries: self.queries.clone(),
166 quotient: self.quotient.clone(),
167 transcript_initial_state,
168 instance_committing_key: self.instance_committing_key.clone(),
169 linearization: self.linearization,
170 accumulator_indices: self.accumulator_indices.clone(),
171 }
172 }
173}
174
#[cfg(feature = "loader_halo2")]
mod halo2 {
    use crate::{
        loader::{
            halo2::{EccInstructions, Halo2Loader},
            LoadedScalar, ScalarLoader,
        },
        util::arithmetic::CurveAffine,
        verifier::plonk::PlonkProtocol,
    };
    use halo2_base::utils::bit_length;
    use std::rc::Rc;

    use super::{DomainAsWitness, PrimeField};

    impl<C> PlonkProtocol<C>
    where
        C: CurveAffine,
    {
        /// Converts this native protocol into one over a [`Halo2Loader`].
        ///
        /// The preprocessed commitments and the transcript initial state are
        /// assigned as witnesses, whereas [`PlonkProtocol::loaded`] loads them as
        /// constants.
        ///
        /// When `load_k_as_witness` is `true`, `k` is also assigned as a witness.
        /// `n = 2^k`, the domain generator and its inverse are then derived from
        /// it in-circuit and stored in `domain_as_witness`. Otherwise that field
        /// is `None`.
        ///
        /// # Panics
        /// Panics if the loader reports the derived generator as non-invertible.
        pub fn loaded_preprocessed_as_witness<EccChip: EccInstructions<C>>(
            &self,
            loader: &Rc<Halo2Loader<C, EccChip>>,
            load_k_as_witness: bool,
        ) -> PlonkProtocol<C, Rc<Halo2Loader<C, EccChip>>> {
            // NOTE(review): loader calls presumably lay out circuit cells in call
            // order, so the statement order below is deliberately unchanged.
            let domain_as_witness = load_k_as_witness.then(|| {
                let k = loader.assign_scalar(C::Scalar::from(self.domain.k as u64));
                let two = loader.load_const(&C::Scalar::from(2));
                // n = 2^k. The second `pow_var` argument is presumably the
                // exponent's maximum bit length.
                let n = two.pow_var(&k, bit_length(C::Scalar::S as u64) + 1);

                // ROOT_OF_UNITY is a 2^S-th root of unity, so
                // gen = ROOT_OF_UNITY^(2^(S - k)) generates the subgroup of order 2^k.
                let root_of_unity = loader.load_const(&C::Scalar::ROOT_OF_UNITY);
                let s = loader.load_const(&C::Scalar::from(C::Scalar::S as u64));
                // If k > S, `s - k` wraps to a large field element and presumably
                // violates the bit-length bound of `pow_var` — confirm in the loader.
                let exp = two.pow_var(&(s - &k), bit_length(C::Scalar::S as u64));
                // 2^(S - k) fits in S bits only when k > 0.
                let gen = root_of_unity.pow_var(&exp, C::Scalar::S as usize);
                let gen_inv = gen.invert().expect("subgroup generation is invertible");

                DomainAsWitness { k, n, gen, gen_inv }
            });

            let preprocessed = self
                .preprocessed
                .iter()
                .map(|preprocessed| loader.assign_ec_point(*preprocessed))
                .collect();
            let transcript_initial_state = self
                .transcript_initial_state
                .as_ref()
                .map(|transcript_initial_state| loader.assign_scalar(*transcript_initial_state));
            PlonkProtocol {
                domain: self.domain.clone(),
                domain_as_witness,
                preprocessed,
                num_instance: self.num_instance.clone(),
                num_witness: self.num_witness.clone(),
                num_challenge: self.num_challenge.clone(),
                evaluations: self.evaluations.clone(),
                queries: self.queries.clone(),
                quotient: self.quotient.clone(),
                transcript_initial_state,
                instance_committing_key: self.instance_committing_key.clone(),
                linearization: self.linearization,
                accumulator_indices: self.accumulator_indices.clone(),
            }
        }
    }
}
244
/// Polynomials common to every proof, whose evaluations at the challenge point
/// `z` are computed by [`CommonPolynomialEvaluation`].
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum CommonPolynomial {
    /// The identity polynomial; evaluates to `z` itself.
    Identity,
    /// The Lagrange basis polynomial for the domain point at the given rotation.
    Lagrange(i32),
}
250
/// Evaluations at a point `z` of the polynomials every proof shares: `z^n`,
/// `z^n - 1` and its reciprocal, the identity, and the requested Lagrange basis
/// polynomials.
///
/// The reciprocal and the Lagrange values are stored as [`Fraction`]s. Their
/// denominators are exposed by `denoms` and resolved by `evaluate`.
#[derive(Clone, Debug)]
pub struct CommonPolynomialEvaluation<C, L>
where
    C: CurveAffine,
    L: Loader<C>,
{
    /// `z^n`.
    zn: L::LoadedScalar,
    /// `z^n - 1`, the domain's vanishing polynomial evaluated at `z`.
    zn_minus_one: L::LoadedScalar,
    /// `1 / (z^n - 1)`, stored as a fraction until evaluated.
    zn_minus_one_inv: Fraction<L::LoadedScalar>,
    /// The identity polynomial at `z`, i.e. `z`.
    identity: L::LoadedScalar,
    /// Lagrange evaluations keyed by rotation, stored as fractions until evaluated.
    lagrange: BTreeMap<i32, Fraction<L::LoadedScalar>>,
}
263
264impl<C, L> CommonPolynomialEvaluation<C, L>
265where
266 C: CurveAffine,
267 L: Loader<C>,
268{
269 pub fn new(
273 domain: &Domain<C::Scalar>,
274 lagranges: impl IntoIterator<Item = i32>,
275 z: &L::LoadedScalar,
276 domain_as_witness: &Option<DomainAsWitness<C, L>>,
277 ) -> Self {
278 let loader = z.loader();
279
280 let lagranges = lagranges.into_iter().sorted().dedup().collect_vec();
281 let one = loader.load_one();
282
283 let (zn, n_inv, omegas) = if let Some(domain) = domain_as_witness.as_ref() {
284 let zn = z.pow_var(&domain.n, C::Scalar::S as usize + 1);
285 let n_inv = domain.n.invert().expect("n is not zero");
286 let omegas = lagranges.iter().map(|&i| domain.rotate_one(Rotation(i))).collect_vec();
287 (zn, n_inv, omegas)
288 } else {
289 let zn = z.pow_const(domain.n as u64);
290 let n_inv = loader.load_const(&domain.n_inv);
291 let omegas = lagranges
292 .iter()
293 .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::ONE, Rotation(i))))
294 .collect_vec();
295 (zn, n_inv, omegas)
296 };
297
298 let zn_minus_one = zn.clone() - &one;
299 let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone());
300
301 let numer = zn_minus_one.clone() * &n_inv;
302 let lagrange_evals = omegas
303 .iter()
304 .map(|omega| Fraction::new(numer.clone() * omega, z.clone() - omega))
305 .collect_vec();
306
307 Self {
308 zn,
309 zn_minus_one,
310 zn_minus_one_inv,
311 identity: z.clone(),
312 lagrange: lagranges.into_iter().zip(lagrange_evals).collect(),
313 }
314 }
315
316 pub fn zn(&self) -> &L::LoadedScalar {
317 &self.zn
318 }
319
320 pub fn zn_minus_one(&self) -> &L::LoadedScalar {
321 &self.zn_minus_one
322 }
323
324 pub fn zn_minus_one_inv(&self) -> &L::LoadedScalar {
325 self.zn_minus_one_inv.evaluated()
326 }
327
328 pub fn get(&self, poly: CommonPolynomial) -> &L::LoadedScalar {
329 match poly {
330 CommonPolynomial::Identity => &self.identity,
331 CommonPolynomial::Lagrange(i) => self.lagrange.get(&i).unwrap().evaluated(),
332 }
333 }
334
335 pub fn denoms(&mut self) -> impl IntoIterator<Item = &'_ mut L::LoadedScalar> {
336 self.lagrange
337 .iter_mut()
338 .map(|(_, value)| value.denom_mut())
339 .chain(iter::once(self.zn_minus_one_inv.denom_mut()))
340 .flatten()
341 }
342
343 pub fn evaluate(&mut self) {
344 self.lagrange
345 .iter_mut()
346 .map(|(_, value)| value)
347 .chain(iter::once(&mut self.zn_minus_one_inv))
348 .for_each(Fraction::evaluate)
349 }
350}
351
/// The quotient polynomial as the verifier sees it: a numerator expression and
/// the degree of each chunk it is split into (see `num_chunk`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuotientPolynomial<F: Clone> {
    /// Degree of each chunk, in the same units as [`Expression::degree`].
    pub chunk_degree: usize,
    /// Numerator expression; its `degree()` determines the chunk count.
    pub numerator: Expression<F>,
}
357
358impl<F: Clone> QuotientPolynomial<F> {
359 pub fn num_chunk(&self) -> usize {
360 Integer::div_ceil(
361 &(self.numerator.degree().checked_sub(1).unwrap_or_default()),
362 &self.chunk_degree,
363 )
364 }
365}
366
/// A query of polynomial `poly` at the evaluation point rotated by `rotation`.
///
/// Field order matters: the derived `Ord` compares `poly` first, then `rotation`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct Query {
    /// Index of the queried polynomial.
    pub poly: usize,
    /// Rotation applied to the evaluation point.
    pub rotation: Rotation,
}
372
373impl Query {
374 pub fn new<R: Into<Rotation>>(poly: usize, rotation: R) -> Self {
375 Self { poly, rotation: rotation.into() }
376 }
377}
378
/// Symbolic arithmetic expression over queried polynomials, common polynomials,
/// challenges and constants of type `F`, interpreted by [`Expression::evaluate`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Expression<F> {
    /// A constant.
    Constant(F),
    /// Evaluation of a [`CommonPolynomial`].
    CommonPolynomial(CommonPolynomial),
    /// Evaluation of the polynomial identified by a [`Query`].
    Polynomial(Query),
    /// The challenge with the given index.
    Challenge(usize),
    /// Negation of the inner expression.
    Negated(Box<Expression<F>>),
    /// Sum of two expressions.
    Sum(Box<Expression<F>>, Box<Expression<F>>),
    /// Product of two expressions.
    Product(Box<Expression<F>>, Box<Expression<F>>),
    /// Expression multiplied by a constant.
    Scaled(Box<Expression<F>>, F),
    /// Horner combination `(..((e_0 * s + e_1) * s + e_2)..) * s + e_last` of the
    /// expressions `e_i` with scalar `s`; the list must be non-empty.
    DistributePowers(Vec<Expression<F>>, Box<Expression<F>>),
}
391
392impl<F: Clone> Expression<F> {
393 pub fn evaluate<T: Clone>(
394 &self,
395 constant: &impl Fn(F) -> T,
396 common_poly: &impl Fn(CommonPolynomial) -> T,
397 poly: &impl Fn(Query) -> T,
398 challenge: &impl Fn(usize) -> T,
399 negated: &impl Fn(T) -> T,
400 sum: &impl Fn(T, T) -> T,
401 product: &impl Fn(T, T) -> T,
402 scaled: &impl Fn(T, F) -> T,
403 ) -> T {
404 let evaluate = |expr: &Expression<F>| {
405 expr.evaluate(constant, common_poly, poly, challenge, negated, sum, product, scaled)
406 };
407 match self {
408 Expression::Constant(scalar) => constant(scalar.clone()),
409 Expression::CommonPolynomial(poly) => common_poly(*poly),
410 Expression::Polynomial(query) => poly(*query),
411 Expression::Challenge(index) => challenge(*index),
412 Expression::Negated(a) => {
413 let a = evaluate(a);
414 negated(a)
415 }
416 Expression::Sum(a, b) => {
417 let a = evaluate(a);
418 let b = evaluate(b);
419 sum(a, b)
420 }
421 Expression::Product(a, b) => {
422 let a = evaluate(a);
423 let b = evaluate(b);
424 product(a, b)
425 }
426 Expression::Scaled(a, scalar) => {
427 let a = evaluate(a);
428 scaled(a, scalar.clone())
429 }
430 Expression::DistributePowers(exprs, scalar) => {
431 assert!(!exprs.is_empty());
432 if exprs.len() == 1 {
433 return evaluate(exprs.first().unwrap());
434 }
435 let mut exprs = exprs.iter();
436 let first = evaluate(exprs.next().unwrap());
437 let scalar = evaluate(scalar);
438 exprs.fold(first, |acc, expr| sum(product(acc, scalar.clone()), evaluate(expr)))
439 }
440 }
441 }
442
443 pub fn degree(&self) -> usize {
444 match self {
445 Expression::Constant(_) => 0,
446 Expression::CommonPolynomial(_) => 1,
447 Expression::Polynomial { .. } => 1,
448 Expression::Challenge { .. } => 0,
449 Expression::Negated(a) => a.degree(),
450 Expression::Sum(a, b) => max(a.degree(), b.degree()),
451 Expression::Product(a, b) => a.degree() + b.degree(),
452 Expression::Scaled(a, _) => a.degree(),
453 Expression::DistributePowers(a, b) => {
454 a.iter().chain(Some(b.as_ref())).map(Self::degree).max().unwrap_or_default()
455 }
456 }
457 }
458
459 pub fn used_langrange(&self) -> BTreeSet<i32> {
460 self.evaluate(
461 &|_| None,
462 &|poly| match poly {
463 CommonPolynomial::Lagrange(i) => Some(BTreeSet::from_iter([i])),
464 _ => None,
465 },
466 &|_| None,
467 &|_| None,
468 &|a| a,
469 &merge_left_right,
470 &merge_left_right,
471 &|a, _| a,
472 )
473 .unwrap_or_default()
474 }
475
476 pub fn used_query(&self) -> BTreeSet<Query> {
477 self.evaluate(
478 &|_| None,
479 &|_| None,
480 &|query| Some(BTreeSet::from_iter([query])),
481 &|_| None,
482 &|a| a,
483 &merge_left_right,
484 &merge_left_right,
485 &|a, _| a,
486 )
487 .unwrap_or_default()
488 }
489}
490
491impl<F: Clone> From<Query> for Expression<F> {
492 fn from(query: Query) -> Self {
493 Self::Polynomial(query)
494 }
495}
496
497impl<F: Clone> From<CommonPolynomial> for Expression<F> {
498 fn from(common_poly: CommonPolynomial) -> Self {
499 Self::CommonPolynomial(common_poly)
500 }
501}
502
/// Implements a binary operator on `Expression<F>` for every owned/borrowed
/// combination of operands.
///
/// The right operand is passed through `$rhs_expr` (e.g. negated for `Sub`)
/// before being stored as the second field of `Expression::$variant`.
macro_rules! impl_expression_ops {
    ($trait:ident, $op:ident, $variant:ident, $rhs:ty, $rhs_expr:expr) => {
        // Owned `lhs op rhs`: the only impl that builds the node.
        impl<F: Clone> $trait<$rhs> for Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: $rhs) -> Self::Output {
                let rhs = $rhs_expr(rhs);
                Expression::$variant(Box::new(self), rhs.into())
            }
        }
        // The borrowed variants clone what they borrow and delegate.
        impl<F: Clone> $trait<$rhs> for &Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: $rhs) -> Self::Output {
                <Expression<F> as $trait<$rhs>>::$op(self.clone(), rhs)
            }
        }
        impl<F: Clone> $trait<&$rhs> for Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: &$rhs) -> Self::Output {
                <Expression<F> as $trait<$rhs>>::$op(self, rhs.clone())
            }
        }
        impl<F: Clone> $trait<&$rhs> for &Expression<F> {
            type Output = Expression<F>;
            fn $op(self, rhs: &$rhs) -> Self::Output {
                <Expression<F> as $trait<$rhs>>::$op(self.clone(), rhs.clone())
            }
        }
    };
}
531
// `a * b` builds a `Product` node, `a * scalar` a `Scaled` node, `a + b` a `Sum`.
// Subtraction has no dedicated node: `a - b` is built as `Sum(a, Negated(b))`.
impl_expression_ops!(Mul, mul, Product, Expression<F>, std::convert::identity);
impl_expression_ops!(Mul, mul, Scaled, F, std::convert::identity);
impl_expression_ops!(Add, add, Sum, Expression<F>, std::convert::identity);
impl_expression_ops!(Sub, sub, Sum, Expression<F>, Neg::neg);
536
537impl<F: Clone> Neg for Expression<F> {
538 type Output = Expression<F>;
539 fn neg(self) -> Self::Output {
540 Expression::Negated(Box::new(self))
541 }
542}
543
544impl<F: Clone> Neg for &Expression<F> {
545 type Output = Expression<F>;
546 fn neg(self) -> Self::Output {
547 Expression::Negated(Box::new(self.clone()))
548 }
549}
550
551impl<F: Clone + Default> Sum for Expression<F> {
552 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
553 iter.reduce(|acc, item| acc + item).unwrap_or_else(|| Expression::Constant(F::default()))
554 }
555}
556
557impl<F: Field> One for Expression<F> {
558 fn one() -> Self {
559 Expression::Constant(F::ONE)
560 }
561}
562
/// Unions two optional sets. `None` means "nothing collected" and acts as the
/// identity, so the result is `None` only when both inputs are `None`.
fn merge_left_right<T: Ord>(
    a: Option<BTreeSet<T>>,
    b: Option<BTreeSet<T>>,
) -> Option<BTreeSet<T>> {
    match (a, b) {
        (Some(mut left), Some(right)) => {
            left.extend(right);
            Some(left)
        }
        (only, None) | (None, only) => only,
    }
}
573
/// Linearization strategy recorded on a protocol; it is not interpreted in this file.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum LinearizationStrategy {
    /// Presumably a linearization polynomial without its constant term, whose
    /// evaluation is supplied separately — confirm against the consuming verifier.
    WithoutConstant,
    /// Presumably a linearization that subtracts `vanishing * quotient` so the
    /// polynomial evaluates to zero — confirm against the consuming verifier.
    MinusVanishingTimesQuotient,
}
585
/// Key for committing instance columns. When a protocol carries one,
/// `PlonkProtocol::langranges` does not add the Lagrange indices needed to
/// evaluate instances directly.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct InstanceCommittingKey<C> {
    /// Curve points combined with the instance values, presumably one per
    /// instance row — confirm.
    pub bases: Vec<C>,
    /// Optional extra point, presumably a constant offset added to the
    /// commitment — confirm.
    pub constant: Option<C>,
}