openvm_native_recursion/halo2/
wrapper.rs1use itertools::Itertools;
2use openvm_stark_backend::p3_util::log2_ceil_usize;
3use serde::{Deserialize, Serialize};
4use serde_with::serde_as;
5#[cfg(feature = "evm-prove")]
6use snark_verifier_sdk::snark_verifier::{
7 halo2_base::halo2_proofs::{halo2curves::bn256::G1Affine, plonk::VerifyingKey},
8 loader::evm::compile_solidity,
9};
10use snark_verifier_sdk::{
11 halo2::aggregation::{AggregationCircuit, AggregationConfigParams, VerifierUniversality},
12 snark_verifier::halo2_base::{
13 gates::circuit::CircuitBuilderStage,
14 halo2_proofs::{plonk::keygen_pk2, poly::commitment::Params},
15 },
16 CircuitExt, Snark, SHPLONK,
17};
18
19#[cfg(feature = "evm-prove")]
20use crate::halo2::RawEvmProof;
21use crate::halo2::{
22 utils::{Halo2ParamsReader, KZG_PARAMS_FOR_SVK},
23 Halo2Params, Halo2ProvingMetadata, Halo2ProvingPinning,
24};
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct FallbackEvmVerifier {
30 pub sol_code: String,
31 pub artifact: EvmVerifierByteCode,
32}
33
34#[serde_as]
35#[derive(Clone, Debug, Deserialize, Serialize)]
36pub struct EvmVerifierByteCode {
37 pub sol_compiler_version: String,
38 pub sol_compiler_options: String,
39 #[serde_as(as = "serde_with::hex::Hex")]
40 pub bytecode: Vec<u8>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct Halo2WrapperProvingKey {
45 pub pinning: Halo2ProvingPinning,
46}
47
/// Minimum row count handed to `calculate_params`. It is also added to the advice-cell
/// count when `select_k` estimates a degree.
/// NOTE(review): presumably rows reserved for blinding factors — confirm against halo2-base.
const MIN_ROWS: usize = 20;
49
50impl Halo2WrapperProvingKey {
51 pub fn keygen_auto_tune(reader: &impl Halo2ParamsReader, dummy_snark: Snark) -> Self {
53 let k = Self::select_k(dummy_snark.clone());
54 tracing::info!("Selected k: {}", k);
55 let params = reader.read_params(k);
56 Self::keygen(¶ms, dummy_snark)
57 }
58 pub fn keygen(params: &Halo2Params, dummy_snark: Snark) -> Self {
59 let k = params.k();
60 #[cfg(feature = "metrics")]
61 let start = std::time::Instant::now();
62 let mut circuit =
63 generate_wrapper_circuit_object(CircuitBuilderStage::Keygen, k as usize, dummy_snark);
64 circuit.calculate_params(Some(MIN_ROWS));
65 let config_params = circuit.builder.config_params.clone();
66 tracing::info!(
67 "Wrapper circuit num advice: {:?}",
68 config_params.num_advice_per_phase
69 );
70 #[cfg(feature = "metrics")]
71 emit_wrapper_circuit_metrics(&circuit);
72 let pk = keygen_pk2(params, &circuit, false).unwrap();
73 let num_pvs = circuit.instances().iter().map(|x| x.len()).collect_vec();
74 #[cfg(feature = "metrics")]
75 metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
76 Self {
77 pinning: Halo2ProvingPinning {
78 pk,
79 metadata: Halo2ProvingMetadata {
80 config_params,
81 break_points: circuit.break_points(),
82 num_pvs,
83 },
84 },
85 }
86 }
87 #[cfg(feature = "evm-verify")]
88 pub fn evm_verify(
90 evm_verifier: &FallbackEvmVerifier,
91 evm_proof: &RawEvmProof,
92 ) -> Result<u64, String> {
93 snark_verifier_sdk::evm::evm_verify(
94 evm_verifier.artifact.bytecode.clone(),
95 vec![evm_proof.instances.clone()],
96 evm_proof.proof.clone(),
97 )
98 }
99 #[cfg(feature = "evm-prove")]
100 pub fn generate_fallback_evm_verifier(&self, params: &Halo2Params) -> FallbackEvmVerifier {
102 assert_eq!(
103 self.pinning.metadata.config_params.k as u32,
104 params.k(),
105 "Provided params don't match circuit config"
106 );
107 gen_evm_verifier(
108 params,
109 self.pinning.pk.get_vk(),
110 self.pinning.metadata.num_pvs.clone(),
111 )
112 }
113 #[cfg(feature = "evm-prove")]
114 pub fn prove_for_evm(&self, params: &Halo2Params, snark_to_verify: Snark) -> RawEvmProof {
115 #[cfg(feature = "metrics")]
116 let start = std::time::Instant::now();
117 let k = self.pinning.metadata.config_params.k;
118 let prover_circuit = self.generate_circuit_object_for_proving(k, snark_to_verify);
119 let mut pvs = prover_circuit.instances();
120 assert_eq!(pvs.len(), 1);
121 let proof = snark_verifier_sdk::evm::gen_evm_proof_shplonk(
122 params,
123 &self.pinning.pk,
124 prover_circuit,
125 pvs.clone(),
126 );
127 #[cfg(feature = "metrics")]
128 metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
129
130 RawEvmProof {
131 instances: pvs.pop().unwrap(),
132 proof,
133 }
134 }
135
136 #[cfg(feature = "evm-prove")]
137 fn generate_circuit_object_for_proving(
138 &self,
139 k: usize,
140 snark_to_verify: Snark,
141 ) -> AggregationCircuit {
142 assert_eq!(
143 snark_to_verify.instances.len(),
144 1,
145 "Snark should only have 1 instance column"
146 );
147 assert_eq!(
148 self.pinning.metadata.num_pvs[0],
149 snark_to_verify.instances[0].len() + 12,
150 );
151 generate_wrapper_circuit_object(CircuitBuilderStage::Prover, k, snark_to_verify)
152 .use_params(
153 self.pinning
154 .metadata
155 .config_params
156 .clone()
157 .try_into()
158 .unwrap(),
159 )
160 .use_break_points(self.pinning.metadata.break_points.clone())
161 }
162
163 pub(crate) fn select_k(dummy_snark: Snark) -> usize {
164 let mut k = 20;
165 let mut first_run = true;
166 loop {
167 let mut circuit = generate_wrapper_circuit_object(
168 CircuitBuilderStage::Keygen,
169 k,
170 dummy_snark.clone(),
171 );
172 circuit.calculate_params(Some(MIN_ROWS));
173 assert_eq!(
174 circuit.builder.config_params.num_advice_per_phase.len(),
175 1,
176 "Snark has multiple phases"
177 );
178 if circuit.builder.config_params.num_advice_per_phase[0] == 1 {
179 circuit.builder.clear();
180 break;
181 }
182 if first_run {
183 k = log2_ceil_usize(
184 circuit.builder.statistics().gate.total_advice_per_phase[0] + MIN_ROWS,
185 );
186 } else {
187 k += 1;
188 }
189 first_run = false;
190 circuit.builder.clear();
192 }
193 k
194 }
195}
196
197fn generate_wrapper_circuit_object(
198 stage: CircuitBuilderStage,
199 k: usize,
200 snark: Snark,
201) -> AggregationCircuit {
202 let config_params = AggregationConfigParams {
203 degree: k as u32,
204 lookup_bits: k - 1,
205 ..Default::default()
206 };
207 let mut circuit = AggregationCircuit::new::<SHPLONK>(
208 stage,
209 config_params,
210 &KZG_PARAMS_FOR_SVK,
211 [snark],
212 VerifierUniversality::None,
213 );
214 circuit.expose_previous_instances(false);
215 circuit
216}
217
/// Reports the wrapper circuit's total cell count as the `halo2_total_cells` gauge.
///
/// The total is the sum of advice cells across phases, lookup advice cells across phases,
/// and fixed cells.
#[cfg(feature = "metrics")]
fn emit_wrapper_circuit_metrics(agg_circuit: &AggregationCircuit) {
    let stats = agg_circuit.builder.statistics();
    let advice_cells = stats.gate.total_advice_per_phase.into_iter().sum::<usize>();
    let lookup_cells = stats.total_lookup_advice_per_phase.into_iter().sum::<usize>();
    let fixed_cells = stats.gate.total_fixed;
    let cell_count = advice_cells + lookup_cells + fixed_cells;
    metrics::gauge!("halo2_total_cells").set(cell_count as f64);
}
226
/// Generates the SHPLONK Solidity verifier for the wrapper circuit's `vk` and compiles it.
///
/// `num_instance` is the number of public values per instance column.
#[cfg(feature = "evm-prove")]
fn gen_evm_verifier(
    params: &Halo2Params,
    vk: &VerifyingKey<G1Affine>,
    num_instance: Vec<usize>,
) -> FallbackEvmVerifier {
    // Recorded as artifact metadata only. These values are not passed to
    // `compile_solidity`.
    // NOTE(review): confirm they match the solc toolchain actually used.
    const SOL_COMPILER_VERSION: &str = "0.8.19";
    const SOL_COMPILER_OPTIONS: &str = "";

    let sol_code = snark_verifier_sdk::evm::gen_evm_verifier_sol_code::<AggregationCircuit, SHPLONK>(
        params,
        vk,
        num_instance,
    );
    let bytecode = compile_solidity(&sol_code);
    let artifact = EvmVerifierByteCode {
        sol_compiler_version: SOL_COMPILER_VERSION.to_string(),
        sol_compiler_options: SOL_COMPILER_OPTIONS.to_string(),
        bytecode,
    };
    FallbackEvmVerifier { sol_code, artifact }
}