// openvm_native_recursion/halo2/wrapper.rs
1use itertools::Itertools;
2use openvm_stark_backend::p3_util::log2_ceil_usize;
3use serde::{Deserialize, Serialize};
4use serde_with::serde_as;
5use snark_verifier_sdk::{
6 evm::{evm_verify, gen_evm_proof_shplonk, gen_evm_verifier_sol_code},
7 halo2::aggregation::{AggregationCircuit, AggregationConfigParams, VerifierUniversality},
8 snark_verifier::{
9 halo2_base::{
10 gates::circuit::{
11 CircuitBuilderStage,
12 CircuitBuilderStage::{Keygen, Prover},
13 },
14 halo2_proofs::{
15 halo2curves::bn256::G1Affine,
16 plonk::{keygen_pk2, VerifyingKey},
17 poly::commitment::Params,
18 },
19 },
20 loader::evm::compile_solidity,
21 },
22 CircuitExt, Snark, SHPLONK,
23};
24
25use crate::halo2::{
26 utils::{Halo2ParamsReader, KZG_PARAMS_FOR_SVK},
27 Halo2Params, Halo2ProvingMetadata, Halo2ProvingPinning, RawEvmProof,
28};
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct EvmVerifier {
32 pub sol_code: String,
33 pub artifact: EvmVerifierByteCode,
34}
35
36#[serde_as]
37#[derive(Clone, Debug, Deserialize, Serialize)]
38pub struct EvmVerifierByteCode {
39 pub sol_compiler_version: String,
40 pub sol_compiler_options: String,
41 #[serde_as(as = "serde_with::hex::Hex")]
42 pub bytecode: Vec<u8>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Halo2WrapperProvingKey {
47 pub pinning: Halo2ProvingPinning,
48}
49
/// Row count handed to `calculate_params(Some(MIN_ROWS))` when sizing the wrapper circuit.
///
/// `select_k` also adds it to the advice-cell total when estimating `k`.
/// NOTE(review): presumably rows reserved for blinding/unusable rows — confirm against
/// halo2-base's `calculate_params`.
const MIN_ROWS: usize = 20;
51
52impl Halo2WrapperProvingKey {
53 pub fn keygen_auto_tune(reader: &impl Halo2ParamsReader, dummy_snark: Snark) -> Self {
55 let k = Self::select_k(dummy_snark.clone());
56 tracing::info!("Selected k: {}", k);
57 let params = reader.read_params(k);
58 Self::keygen(¶ms, dummy_snark)
59 }
60 pub fn keygen(params: &Halo2Params, dummy_snark: Snark) -> Self {
61 let k = params.k();
62 #[cfg(feature = "bench-metrics")]
63 let start = std::time::Instant::now();
64 let mut circuit = generate_wrapper_circuit_object(Keygen, k as usize, dummy_snark);
65 circuit.calculate_params(Some(MIN_ROWS));
66 let config_params = circuit.builder.config_params.clone();
67 tracing::info!(
68 "Wrapper circuit num advice: {:?}",
69 config_params.num_advice_per_phase
70 );
71 #[cfg(feature = "bench-metrics")]
72 emit_wrapper_circuit_metrics(&circuit);
73 let pk = keygen_pk2(params, &circuit, false).unwrap();
74 let num_pvs = circuit.instances().iter().map(|x| x.len()).collect_vec();
75 #[cfg(feature = "bench-metrics")]
76 metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
77 Self {
78 pinning: Halo2ProvingPinning {
79 pk,
80 metadata: Halo2ProvingMetadata {
81 config_params,
82 break_points: circuit.break_points(),
83 num_pvs,
84 },
85 },
86 }
87 }
88 pub fn evm_verify(evm_verifier: &EvmVerifier, evm_proof: &RawEvmProof) -> Result<u64, String> {
90 evm_verify(
91 evm_verifier.artifact.bytecode.clone(),
92 vec![evm_proof.instances.clone()],
93 evm_proof.proof.clone(),
94 )
95 }
96 pub fn generate_evm_verifier(&self, params: &Halo2Params) -> EvmVerifier {
98 assert_eq!(
99 self.pinning.metadata.config_params.k as u32,
100 params.k(),
101 "Provided params don't match circuit config"
102 );
103 gen_evm_verifier(
104 params,
105 self.pinning.pk.get_vk(),
106 self.pinning.metadata.num_pvs.clone(),
107 )
108 }
109 pub fn prove_for_evm(&self, params: &Halo2Params, snark_to_verify: Snark) -> RawEvmProof {
110 #[cfg(feature = "bench-metrics")]
111 let start = std::time::Instant::now();
112 let k = self.pinning.metadata.config_params.k;
113 let prover_circuit = self.generate_circuit_object_for_proving(k, snark_to_verify);
114 let mut pvs = prover_circuit.instances();
115 assert_eq!(pvs.len(), 1);
116 let proof = gen_evm_proof_shplonk(params, &self.pinning.pk, prover_circuit, pvs.clone());
117 #[cfg(feature = "bench-metrics")]
118 metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
119
120 RawEvmProof {
121 instances: pvs.pop().unwrap(),
122 proof,
123 }
124 }
125 fn generate_circuit_object_for_proving(
126 &self,
127 k: usize,
128 snark_to_verify: Snark,
129 ) -> AggregationCircuit {
130 assert_eq!(
131 snark_to_verify.instances.len(),
132 1,
133 "Snark should only have 1 instance column"
134 );
135 assert_eq!(
136 self.pinning.metadata.num_pvs[0],
137 snark_to_verify.instances[0].len() + 12,
138 );
139 generate_wrapper_circuit_object(Prover, k, snark_to_verify)
140 .use_params(
141 self.pinning
142 .metadata
143 .config_params
144 .clone()
145 .try_into()
146 .unwrap(),
147 )
148 .use_break_points(self.pinning.metadata.break_points.clone())
149 }
150
151 pub(crate) fn select_k(dummy_snark: Snark) -> usize {
152 let mut k = 20;
153 let mut first_run = true;
154 loop {
155 let mut circuit = generate_wrapper_circuit_object(Keygen, k, dummy_snark.clone());
156 circuit.calculate_params(Some(MIN_ROWS));
157 assert_eq!(
158 circuit.builder.config_params.num_advice_per_phase.len(),
159 1,
160 "Snark has multiple phases"
161 );
162 if circuit.builder.config_params.num_advice_per_phase[0] == 1 {
163 break;
164 }
165 if first_run {
166 k = log2_ceil_usize(
167 circuit.builder.statistics().gate.total_advice_per_phase[0] + MIN_ROWS,
168 );
169 } else {
170 k += 1;
171 }
172 first_run = false;
173 }
174 k
175 }
176}
177
178fn generate_wrapper_circuit_object(
179 stage: CircuitBuilderStage,
180 k: usize,
181 snark: Snark,
182) -> AggregationCircuit {
183 let config_params = AggregationConfigParams {
184 degree: k as u32,
185 lookup_bits: k - 1,
186 ..Default::default()
187 };
188 let mut circuit = AggregationCircuit::new::<SHPLONK>(
189 stage,
190 config_params,
191 &KZG_PARAMS_FOR_SVK,
192 [snark],
193 VerifierUniversality::None,
194 );
195 circuit.expose_previous_instances(false);
196 circuit
197}
198
/// Publishes the `halo2_total_cells` gauge for the wrapper circuit.
///
/// The value is the sum of advice cells (all phases), lookup-advice cells (all phases) and
/// fixed cells, as reported by the circuit builder's statistics.
#[cfg(feature = "bench-metrics")]
fn emit_wrapper_circuit_metrics(agg_circuit: &AggregationCircuit) {
    let stats = agg_circuit.builder.statistics();
    let advice_cells = stats.gate.total_advice_per_phase.iter().sum::<usize>();
    let lookup_cells = stats.total_lookup_advice_per_phase.iter().sum::<usize>();
    let fixed_cells = stats.gate.total_fixed;
    let all_cells = advice_cells + lookup_cells + fixed_cells;
    metrics::gauge!("halo2_total_cells").set(all_cells as f64);
}
207
208fn gen_evm_verifier(
209 params: &Halo2Params,
210 vk: &VerifyingKey<G1Affine>,
211 num_instance: Vec<usize>,
212) -> EvmVerifier {
213 let sol_code =
214 gen_evm_verifier_sol_code::<AggregationCircuit, SHPLONK>(params, vk, num_instance);
215 let byte_code = compile_solidity(&sol_code);
216 EvmVerifier {
217 sol_code,
218 artifact: EvmVerifierByteCode {
219 sol_compiler_version: "0.8.19".to_string(),
220 sol_compiler_options: "".to_string(),
221 bytecode: byte_code,
222 },
223 }
224}