1use super::PlonkSuccinctVerifier;
2use crate::{BITS, LIMBS};
3use getset::Getters;
4use halo2_base::{
5 gates::{
6 circuit::{
7 builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig, CircuitBuilderStage,
8 },
9 flex_gate::{threads::SinglePhaseCoreManager, MultiPhaseThreadBreakPoints},
10 RangeChip,
11 },
12 halo2_proofs::{
13 circuit::{Layouter, SimpleFloorPlanner},
14 halo2curves::bn256::{Bn256, Fr, G1Affine},
15 plonk::{self, Circuit, ConstraintSystem, Selector},
16 poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG},
17 },
18 utils::ScalarField,
19 AssignedValue,
20};
21use itertools::Itertools;
22use rand::{rngs::StdRng, SeedableRng};
23use serde::{Deserialize, Serialize};
24use snark_verifier::{
27 loader::{
28 self,
29 halo2::halo2_ecc::{self, bigint::ProperCrtUint, bn254::FpChip},
30 native::NativeLoader,
31 },
32 pcs::{
33 kzg::{KzgAccumulator, KzgAsProvingKey, KzgAsVerifyingKey, KzgSuccinctVerifyingKey},
34 AccumulationScheme, AccumulationSchemeProver, PolynomialCommitmentScheme,
35 },
36 system::halo2::transcript::halo2::TranscriptObject,
37 verifier::SnarkVerifier,
38};
39use std::{fs::File, mem, path::Path, rc::Rc};
40
41use super::{CircuitExt, PoseidonTranscript, Snark, POSEIDON_SPEC};
42
/// KZG succinct verifying key over `G1Affine`, used as the PCS verifying key in this file.
pub type Svk = KzgSuccinctVerifyingKey<G1Affine>;
/// `halo2_ecc` base-field ECC chip specialised to `G1Affine`.
pub type BaseFieldEccChip<'chip> = halo2_ecc::ecc::BaseFieldEccChip<'chip, G1Affine>;
/// `snark_verifier` halo2 loader specialised to `G1Affine` and [`BaseFieldEccChip`].
pub type Halo2Loader<'chip> = loader::halo2::Halo2Loader<G1Affine, BaseFieldEccChip<'chip>>;
46
/// Assigned (in-circuit) view of one snark's preprocessed data and domain size, as
/// collected by [`aggregate`].
#[derive(Clone, Debug)]
pub struct PreprocessedAndDomainAsWitness {
    /// Limbs of the x and y coordinates of every preprocessed commitment, followed by the
    /// protocol's transcript initial state, if any.
    pub preprocessed: Vec<AssignedValue<Fr>>,
    /// Assigned value of the protocol's domain `k`: taken from `domain_as_witness` when the
    /// loaded protocol has one, otherwise loaded as a constant from `protocol.domain.k`.
    pub k: AssignedValue<Fr>,
}
53
/// Everything [`aggregate`] produces while verifying a batch of snarks with a halo2 loader.
#[derive(Clone, Debug)]
pub struct SnarkAggregationWitness<'a> {
    /// For each snark, in input order, all of its instances assigned and flattened into a
    /// single vector.
    pub previous_instances: Vec<Vec<AssignedValue<Fr>>>,
    /// The single KZG accumulator that results from verifying (and, for more than one
    /// accumulator, accumulating) the snarks.
    pub accumulator: KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>,
    /// For each snark, in input order, its preprocessed data and domain size as assigned
    /// values.
    pub preprocessed: Vec<PreprocessedAndDomainAsWitness>,
    /// For each snark, in input order, a copy of the transcript's loaded stream taken after
    /// its proof was read and verified.
    pub proof_transcripts: Vec<Vec<TranscriptObject<G1Affine, Rc<Halo2Loader<'a>>>>>,
}
68
/// Selects which parts of a verified snark's protocol the aggregation verifier loads as
/// witnesses instead of fixing them.
#[derive(PartialEq, Eq, Clone, Copy, Debug, Default)]
pub enum VerifierUniversality {
    /// Neither the preprocessed data nor `k` is loaded as a witness. This is the default.
    #[default]
    None,
    /// The preprocessed data is loaded as a witness; `k` is not.
    PreprocessedAsWitness,
    /// Both the preprocessed data and `k` are loaded as witnesses.
    Full,
}

impl VerifierUniversality {
    /// Returns `true` for every variant except [`VerifierUniversality::None`].
    pub fn preprocessed_as_witness(&self) -> bool {
        !matches!(self, Self::None)
    }

    /// Returns `true` only for [`VerifierUniversality::Full`].
    pub fn k_as_witness(&self) -> bool {
        matches!(self, Self::Full)
    }
}
90
/// Verifies `snarks` through `loader` and folds the resulting KZG accumulators into one.
///
/// For every snark this loads its protocol (with the preprocessed data, and optionally `k`,
/// as witnesses according to `universality`), assigns its instances, reads its proof through
/// a Poseidon transcript shared by all snarks, and runs the succinct PLONK verifier. When
/// more than one accumulator results, `as_proof` is read as the accumulation-scheme proof
/// that combines them; otherwise the single accumulator is returned directly and `as_proof`
/// is not read.
///
/// Returns the assigned instances, the final accumulator, the assigned preprocessed data and
/// the loaded proof transcripts, each ordered like `snarks`.
///
/// # Panics
///
/// Panics if `snarks` is empty, or if reading or verifying any proof (including the
/// accumulation proof) fails.
// The `AS` bound below spells out long associated-type constraints, hence the lint allow.
#[allow(clippy::type_complexity)]
pub fn aggregate<'a, AS>(
    svk: &Svk,
    loader: &Rc<Halo2Loader<'a>>,
    snarks: &[Snark],
    as_proof: &[u8],
    universality: VerifierUniversality,
) -> SnarkAggregationWitness<'a>
where
    AS: PolynomialCommitmentScheme<
            G1Affine,
            Rc<Halo2Loader<'a>>,
            VerifyingKey = Svk,
            Output = KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>,
        > + AccumulationScheme<
            G1Affine,
            Rc<Halo2Loader<'a>>,
            Accumulator = KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>,
            VerifyingKey = KzgAsVerifyingKey,
        >,
{
    assert!(!snarks.is_empty(), "trying to aggregate 0 snarks");
    // Assigns every native instance value as a scalar in `loader`, keeping the nesting.
    let assign_instances = |instances: &[Vec<Fr>]| {
        instances
            .iter()
            .map(|instances| {
                instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec()
            })
            .collect_vec()
    };

    let mut previous_instances = Vec::with_capacity(snarks.len());
    let mut preprocessed_witnesses = Vec::with_capacity(snarks.len());
    // One transcript is reused for every proof; it starts on an empty stream and is pointed
    // at each proof with `new_stream` below.
    let mut transcript = PoseidonTranscript::<Rc<Halo2Loader<'a>>, &[u8]>::from_spec(
        loader,
        &[],
        POSEIDON_SPEC.clone(),
    );

    let preprocessed_as_witness = universality.preprocessed_as_witness();
    let (proof_transcripts, accumulators): (Vec<_>, Vec<_>) = snarks
        .iter()
        .map(|snark: &Snark| {
            // Load the protocol either with its preprocessed data as witnesses or as is.
            let protocol = if preprocessed_as_witness {
                snark.protocol.loaded_preprocessed_as_witness(loader, universality.k_as_witness())
            } else {
                snark.protocol.loaded(loader)
            };
            // Flatten every preprocessed point into the limbs of x then y, and append the
            // transcript initial state (if any).
            let preprocessed = protocol
                .preprocessed
                .iter()
                .flat_map(|preprocessed| {
                    let assigned = preprocessed.assigned();
                    [assigned.x(), assigned.y()]
                        .into_iter()
                        .flat_map(|coordinate| coordinate.limbs().to_vec())
                        .collect_vec()
                })
                .chain(
                    protocol.transcript_initial_state.clone().map(|scalar| scalar.into_assigned()),
                )
                .collect_vec();
            // Use the witnessed `k` when the protocol carries one; otherwise load the known
            // `domain.k` as a constant so the output always has an assigned `k`.
            let k = protocol
                .domain_as_witness
                .as_ref()
                .map(|domain| domain.k.clone().into_assigned())
                .unwrap_or_else(|| {
                    loader.ctx_mut().main().load_constant(Fr::from(protocol.domain.k as u64))
                });
            let preprocessed_and_k = PreprocessedAndDomainAsWitness { preprocessed, k };
            preprocessed_witnesses.push(preprocessed_and_k);

            let instances = assign_instances(&snark.instances);

            // Point the shared transcript at this snark's proof bytes before reading.
            transcript.new_stream(snark.proof());
            let proof = PlonkSuccinctVerifier::<AS>::read_proof(
                svk,
                &protocol,
                &instances,
                &mut transcript,
            )
            .unwrap();
            let accumulator =
                PlonkSuccinctVerifier::<AS>::verify(svk, &protocol, &instances, &proof).unwrap();

            previous_instances.push(
                instances.into_iter().flatten().map(|scalar| scalar.into_assigned()).collect(),
            );
            // Snapshot what the transcript loaded for this proof.
            let proof_transcript = transcript.loaded_stream.clone();
            // Debug-only sanity check: every loaded transcript object is counted as 32 bytes,
            // and together they must account for the whole proof.
            debug_assert_eq!(
                snark.proof().len(),
                proof_transcript
                    .iter()
                    .map(|t| match t {
                        TranscriptObject::Scalar(_) => 32,
                        TranscriptObject::EcPoint(_) => 32,
                    })
                    .sum::<usize>()
            );
            (proof_transcript, accumulator)
        })
        .unzip();
    // Each verification yields a collection of accumulators; merge them into one list.
    let mut accumulators = accumulators.into_iter().flatten().collect_vec();

    let accumulator = if accumulators.len() > 1 {
        // Several accumulators: read the accumulation proof and combine them in-circuit.
        transcript.new_stream(as_proof);
        let proof = <AS as AccumulationScheme<_, _>>::read_proof(
            &Default::default(),
            &accumulators,
            &mut transcript,
        )
        .unwrap();
        <AS as AccumulationScheme<_, _>>::verify(&Default::default(), &accumulators, &proof)
            .unwrap()
    } else {
        // A single accumulator needs no accumulation step; `as_proof` is left unread.
        accumulators.pop().unwrap()
    };

    SnarkAggregationWitness {
        previous_instances,
        accumulator,
        preprocessed: preprocessed_witnesses,
        proof_transcripts,
    }
}
232
/// Serializable, single-phase configuration of an aggregation circuit.
///
/// Converts to and from `BaseCircuitParams`; the field notes below describe that mapping.
#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize)]
pub struct AggregationConfigParams {
    /// Circuit degree; becomes `BaseCircuitParams::k`.
    pub degree: u32,
    /// Number of advice columns; becomes the single entry of `num_advice_per_phase`.
    pub num_advice: usize,
    /// Number of lookup advice columns; becomes the single entry of
    /// `num_lookup_advice_per_phase`.
    pub num_lookup_advice: usize,
    /// Number of fixed columns; becomes `BaseCircuitParams::num_fixed`.
    pub num_fixed: usize,
    /// Lookup bits; becomes `Some(lookup_bits)` in `BaseCircuitParams::lookup_bits`.
    pub lookup_bits: usize,
}
243
244impl AggregationConfigParams {
245 pub fn from_path(path: impl AsRef<Path>) -> Self {
246 serde_json::from_reader(File::open(path).expect("Aggregation config path does not exist"))
247 .unwrap()
248 }
249}
250
251impl From<AggregationConfigParams> for BaseCircuitParams {
252 fn from(params: AggregationConfigParams) -> Self {
253 BaseCircuitParams {
254 k: params.degree as usize,
255 num_advice_per_phase: vec![params.num_advice],
256 num_lookup_advice_per_phase: vec![params.num_lookup_advice],
257 num_fixed: params.num_fixed,
258 lookup_bits: Some(params.lookup_bits),
259 num_instance_columns: 1,
260 }
261 }
262}
263
264impl TryFrom<&BaseCircuitParams> for AggregationConfigParams {
265 type Error = &'static str;
266
267 fn try_from(params: &BaseCircuitParams) -> Result<Self, Self::Error> {
268 if params.num_advice_per_phase.iter().skip(1).any(|&n| n != 0) {
269 return Err("AggregationConfigParams only supports 1 phase");
270 }
271 if params.num_lookup_advice_per_phase.iter().skip(1).any(|&n| n != 0) {
272 return Err("AggregationConfigParams only supports 1 phase");
273 }
274 if params.lookup_bits.is_none() {
275 return Err("AggregationConfigParams requires lookup_bits");
276 }
277 if params.num_instance_columns != 1 {
278 return Err("AggregationConfigParams only supports 1 instance column");
279 }
280 Ok(Self {
281 degree: params.k as u32,
282 num_advice: params.num_advice_per_phase[0],
283 num_lookup_advice: params.num_lookup_advice_per_phase[0],
284 num_fixed: params.num_fixed,
285 lookup_bits: params.lookup_bits.unwrap(),
286 })
287 }
288}
289
290impl TryFrom<BaseCircuitParams> for AggregationConfigParams {
291 type Error = &'static str;
292
293 fn try_from(value: BaseCircuitParams) -> Result<Self, Self::Error> {
294 Self::try_from(&value)
295 }
296}
297
/// Circuit that verifies a batch of snarks and exposes the resulting accumulator as its
/// public instances.
///
/// The `Circuit` implementation in this file delegates configuration and synthesis to
/// `builder`.
#[derive(Clone, Debug, Getters)]
pub struct AggregationCircuit {
    /// Underlying circuit builder holding the virtual cells and the assigned instances.
    pub builder: BaseCircuitBuilder<Fr>,
    /// Flattened assigned instances of each aggregated snark, in input order.
    /// Read-only access is generated by `getset`.
    #[getset(get = "pub")]
    previous_instances: Vec<Vec<AssignedValue<Fr>>>,
    /// Assigned preprocessed data and domain size of each aggregated snark, in input order.
    /// Read-only access is generated by `getset`.
    #[getset(get = "pub")]
    preprocessed: Vec<PreprocessedAndDomainAsWitness>,
}
315
/// Marker trait bundling every bound a scheme needs to be used with [`aggregate_snarks`].
///
/// An implementor must be a polynomial commitment scheme and an accumulation scheme both
/// over [`Halo2Loader`] (for any lifetime) and over `NativeLoader`, and must additionally be
/// an accumulation-scheme prover. The trait itself declares no methods.
pub trait AccumulationSchemeSDK:
    for<'a> PolynomialCommitmentScheme<
        G1Affine,
        Rc<Halo2Loader<'a>>,
        VerifyingKey = Svk,
        Output = KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>,
    > + for<'a> AccumulationScheme<
        G1Affine,
        Rc<Halo2Loader<'a>>,
        Accumulator = KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>,
        VerifyingKey = KzgAsVerifyingKey,
    > + PolynomialCommitmentScheme<
        G1Affine,
        NativeLoader,
        VerifyingKey = Svk,
        Output = KzgAccumulator<G1Affine, NativeLoader>,
    > + AccumulationScheme<
        G1Affine,
        NativeLoader,
        Accumulator = KzgAccumulator<G1Affine, NativeLoader>,
        VerifyingKey = KzgAsVerifyingKey,
    > + AccumulationSchemeProver<G1Affine, ProvingKey = KzgAsProvingKey<G1Affine>>
{
}

// The two crate-level schemes opted in to the SDK.
impl AccumulationSchemeSDK for crate::GWC {}
impl AccumulationSchemeSDK for crate::SHPLONK {}
344
/// Result of [`aggregate_snarks`], with everything reduced to assigned values.
#[derive(Clone, Debug)]
pub struct SnarkAggregationOutput {
    /// Flattened assigned instances of each snark, in input order.
    pub previous_instances: Vec<Vec<AssignedValue<Fr>>>,
    /// Limbs of the final accumulator, ordered `lhs.x`, `lhs.y`, `rhs.x`, `rhs.y`.
    pub accumulator: Vec<AssignedValue<Fr>>,
    /// Assigned preprocessed data and domain size of each snark, in input order.
    pub preprocessed: Vec<PreprocessedAndDomainAsWitness>,
    /// For each snark, in input order, the objects its proof transcript loaded, converted to
    /// assigned form.
    pub proof_transcripts: Vec<Vec<AssignedTranscriptObject>>,
}
357
/// A transcript object after conversion to assigned form: a scalar or an elliptic-curve
/// point.
// Presumably the point variant is much larger than the scalar one; the variants are
// deliberately left unboxed, hence the lint allow.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum AssignedTranscriptObject {
    /// An assigned scalar.
    Scalar(AssignedValue<Fr>),
    /// An assigned elliptic-curve point with CRT-represented coordinates.
    EcPoint(halo2_ecc::ecc::EcPoint<Fr, ProperCrtUint<Fr>>),
}
364
/// Aggregates `snarks` inside the circuit backed by `pool`, returning assigned values only.
///
/// The snarks are first verified natively to obtain their accumulators and to create the
/// accumulation-scheme proof. That proof and the snarks are then handed to [`aggregate`],
/// which repeats the verification through a halo2 loader built from `range` and `pool`.
/// The final accumulator is returned as the limbs of `lhs.x`, `lhs.y`, `rhs.x`, `rhs.y`.
///
/// `pool` is temporarily moved into the loader and restored before returning.
///
/// # Panics
///
/// Panics if reading or verifying any proof natively fails, if creating the accumulation
/// proof fails, or whenever [`aggregate`] panics (for example on an empty `snarks`).
pub fn aggregate_snarks<AS>(
    pool: &mut SinglePhaseCoreManager<Fr>,
    range: &RangeChip<Fr>,
    svk: Svk, snarks: impl IntoIterator<Item = Snark>,
    universality: VerifierUniversality,
) -> SnarkAggregationOutput
where
    AS: AccumulationSchemeSDK,
{
    let snarks = snarks.into_iter().collect_vec();

    // Native pass: read and verify every proof outside the circuit to get its accumulators.
    let mut transcript_read =
        PoseidonTranscript::<NativeLoader, &[u8]>::from_spec(&[], POSEIDON_SPEC.clone());
    let accumulators = snarks
        .iter()
        .flat_map(|snark| {
            transcript_read.new_stream(snark.proof());
            let proof = PlonkSuccinctVerifier::<AS>::read_proof(
                &svk,
                &snark.protocol,
                &snark.instances,
                &mut transcript_read,
            )
            .unwrap();
            PlonkSuccinctVerifier::<AS>::verify(&svk, &snark.protocol, &snark.instances, &proof)
                .unwrap()
        })
        .collect_vec();

    // Create the accumulation proof natively. Only the proof bytes are used afterwards; the
    // native accumulator itself is discarded.
    let (_accumulator, as_proof) = {
        let mut transcript_write =
            PoseidonTranscript::<NativeLoader, Vec<u8>>::from_spec(vec![], POSEIDON_SPEC.clone());
        let rng = StdRng::from_entropy();
        let accumulator =
            AS::create_proof(&Default::default(), &accumulators, &mut transcript_write, rng)
                .unwrap();
        (accumulator, transcript_write.finalize())
    };

    let fp_chip = FpChip::<Fr>::new(range, BITS, LIMBS);
    let ecc_chip = BaseFieldEccChip::new(&fp_chip);
    // The loader takes ownership of a core manager, so move the caller's one out of `pool`
    // (leaving a default value behind) and give it back at the end of this function.
    let tmp_pool = mem::take(pool);
    let loader = Halo2Loader::new(ecc_chip, tmp_pool);

    // In-circuit pass: verify the same snarks through the halo2 loader.
    let SnarkAggregationWitness {
        previous_instances,
        accumulator,
        preprocessed,
        proof_transcripts,
    } = aggregate::<AS>(&svk, &loader, &snarks, as_proof.as_slice(), universality);
    let lhs = accumulator.lhs.assigned();
    let rhs = accumulator.rhs.assigned();
    // Flatten the accumulator into limbs in the fixed order lhs.x, lhs.y, rhs.x, rhs.y.
    let accumulator = lhs
        .x()
        .limbs()
        .iter()
        .chain(lhs.y().limbs().iter())
        .chain(rhs.x().limbs().iter())
        .chain(rhs.y().limbs().iter())
        .copied()
        .collect_vec();
    // Convert the loader-bound transcript objects into plain assigned ones.
    let proof_transcripts = proof_transcripts
        .into_iter()
        .map(|transcript| {
            transcript
                .into_iter()
                .map(|obj| match obj {
                    TranscriptObject::Scalar(scalar) => {
                        AssignedTranscriptObject::Scalar(scalar.into_assigned())
                    }
                    TranscriptObject::EcPoint(point) => {
                        AssignedTranscriptObject::EcPoint(point.into_assigned())
                    }
                })
                .collect()
        })
        .collect();

    // Hand the core manager (now holding the generated witnesses) back to the caller.
    *pool = loader.take_ctx();
    SnarkAggregationOutput { previous_instances, accumulator, preprocessed, proof_transcripts }
}
482
483impl AggregationCircuit {
484 pub fn new<AS>(
498 stage: CircuitBuilderStage,
499 config_params: AggregationConfigParams,
500 params: &ParamsKZG<Bn256>,
501 snarks: impl IntoIterator<Item = Snark>,
502 universality: VerifierUniversality,
503 ) -> Self
504 where
505 AS: AccumulationSchemeSDK,
506 {
507 let svk: Svk = params.get_g()[0].into();
508 let mut builder = BaseCircuitBuilder::from_stage(stage).use_params(config_params.into());
509 let range = builder.range_chip();
510 let SnarkAggregationOutput { previous_instances, accumulator, preprocessed, .. } =
511 aggregate_snarks::<AS>(builder.pool(0), &range, svk, snarks, universality);
512 assert_eq!(
513 builder.assigned_instances.len(),
514 1,
515 "AggregationCircuit must have exactly 1 instance column"
516 );
517 builder.assigned_instances[0] = accumulator;
519 Self { builder, previous_instances, preprocessed }
520 }
521
522 pub fn expose_previous_instances(&mut self, has_prev_accumulator: bool) {
526 let start = (has_prev_accumulator as usize) * 4 * LIMBS;
527 for prev in self.previous_instances.iter() {
528 self.builder.assigned_instances[0].extend_from_slice(&prev[start..]);
529 }
530 }
531
532 pub fn lookup_bits(&self) -> usize {
534 self.builder.config_params.lookup_bits.unwrap()
535 }
536
537 pub fn set_params(&mut self, params: AggregationConfigParams) {
539 self.builder.set_params(params.into());
540 }
541
542 pub fn use_params(mut self, params: AggregationConfigParams) -> Self {
544 self.set_params(params);
545 self
546 }
547
548 pub fn break_points(&self) -> MultiPhaseThreadBreakPoints {
550 self.builder.break_points()
551 }
552
553 pub fn set_break_points(&mut self, break_points: MultiPhaseThreadBreakPoints) {
555 self.builder.set_break_points(break_points);
556 }
557
558 pub fn use_break_points(mut self, break_points: MultiPhaseThreadBreakPoints) -> Self {
560 self.set_break_points(break_points);
561 self
562 }
563
564 pub fn calculate_params(&mut self, minimum_rows: Option<usize>) -> AggregationConfigParams {
566 self.builder.calculate_params(minimum_rows).try_into().unwrap()
567 }
568}
569
570impl<F: ScalarField> CircuitExt<F> for BaseCircuitBuilder<F> {
571 fn num_instance(&self) -> Vec<usize> {
572 self.assigned_instances.iter().map(|instances| instances.len()).collect()
573 }
574
575 fn instances(&self) -> Vec<Vec<F>> {
576 self.assigned_instances
577 .iter()
578 .map(|instances| instances.iter().map(|v| *v.value()).collect())
579 .collect()
580 }
581
582 fn selectors(config: &Self::Config) -> Vec<Selector> {
583 config.gate().basic_gates[0].iter().map(|gate| gate.q_enable).collect()
584 }
585}
586
587impl Circuit<Fr> for AggregationCircuit {
588 type Config = BaseConfig<Fr>;
589 type FloorPlanner = SimpleFloorPlanner;
590 type Params = AggregationConfigParams;
591
592 fn params(&self) -> Self::Params {
593 (&self.builder.config_params).try_into().unwrap()
594 }
595
596 fn without_witnesses(&self) -> Self {
597 unimplemented!()
598 }
599
600 fn configure_with_params(
601 meta: &mut ConstraintSystem<Fr>,
602 params: Self::Params,
603 ) -> Self::Config {
604 BaseCircuitBuilder::configure_with_params(meta, params.into())
605 }
606
607 fn configure(_: &mut ConstraintSystem<Fr>) -> Self::Config {
608 unreachable!()
609 }
610
611 fn synthesize(
612 &self,
613 config: Self::Config,
614 layouter: impl Layouter<Fr>,
615 ) -> Result<(), plonk::Error> {
616 self.builder.synthesize(config, layouter)
617 }
618}
619
620impl CircuitExt<Fr> for AggregationCircuit {
621 fn num_instance(&self) -> Vec<usize> {
622 self.builder.num_instance()
623 }
624
625 fn instances(&self) -> Vec<Vec<Fr>> {
626 self.builder.instances()
627 }
628
629 fn accumulator_indices() -> Option<Vec<(usize, usize)>> {
630 Some((0..4 * LIMBS).map(|idx| (0, idx)).collect())
631 }
632
633 fn selectors(config: &Self::Config) -> Vec<Selector> {
634 BaseCircuitBuilder::selectors(config)
635 }
636}
637
638pub fn load_verify_circuit_degree() -> u32 {
639 let path = std::env::var("VERIFY_CONFIG")
640 .unwrap_or_else(|_| "./configs/verify_circuit.config".to_string());
641 let params: AggregationConfigParams = serde_json::from_reader(
642 File::open(path.as_str()).unwrap_or_else(|_| panic!("{path} does not exist")),
643 )
644 .unwrap();
645 params.degree
646}