// openvm_native_recursion/halo2/wrapper.rs

1use itertools::Itertools;
2use openvm_stark_backend::p3_util::log2_ceil_usize;
3use serde::{Deserialize, Serialize};
4use serde_with::serde_as;
5use snark_verifier_sdk::{
6    evm::{evm_verify, gen_evm_proof_shplonk, gen_evm_verifier_sol_code},
7    halo2::aggregation::{AggregationCircuit, AggregationConfigParams, VerifierUniversality},
8    snark_verifier::{
9        halo2_base::{
10            gates::circuit::{
11                CircuitBuilderStage,
12                CircuitBuilderStage::{Keygen, Prover},
13            },
14            halo2_proofs::{
15                halo2curves::bn256::G1Affine,
16                plonk::{keygen_pk2, VerifyingKey},
17                poly::commitment::Params,
18            },
19        },
20        loader::evm::compile_solidity,
21    },
22    CircuitExt, Snark, SHPLONK,
23};
24
25use crate::halo2::{
26    utils::{Halo2ParamsReader, KZG_PARAMS_FOR_SVK},
27    Halo2Params, Halo2ProvingMetadata, Halo2ProvingPinning, RawEvmProof,
28};
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct EvmVerifier {
32    pub sol_code: String,
33    pub artifact: EvmVerifierByteCode,
34}
35
36#[serde_as]
37#[derive(Clone, Debug, Deserialize, Serialize)]
38pub struct EvmVerifierByteCode {
39    pub sol_compiler_version: String,
40    pub sol_compiler_options: String,
41    #[serde_as(as = "serde_with::hex::Hex")]
42    pub bytecode: Vec<u8>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Halo2WrapperProvingKey {
47    pub pinning: Halo2ProvingPinning,
48}
49
/// Row count passed as `minimum_rows` to `calculate_params`, and added to the advice-cell total
/// when estimating the circuit degree in `select_k` (rows presumably unavailable for advice
/// cells — confirm against halo2-base).
const MIN_ROWS: usize = 20;
51
52impl Halo2WrapperProvingKey {
53    /// Auto select k to let Wrapper circuit only have 1 advice column.
54    pub fn keygen_auto_tune(reader: &impl Halo2ParamsReader, dummy_snark: Snark) -> Self {
55        let k = Self::select_k(dummy_snark.clone());
56        tracing::info!("Selected k: {}", k);
57        let params = reader.read_params(k);
58        Self::keygen(&params, dummy_snark)
59    }
60    pub fn keygen(params: &Halo2Params, dummy_snark: Snark) -> Self {
61        let k = params.k();
62        #[cfg(feature = "bench-metrics")]
63        let start = std::time::Instant::now();
64        let mut circuit = generate_wrapper_circuit_object(Keygen, k as usize, dummy_snark);
65        circuit.calculate_params(Some(MIN_ROWS));
66        let config_params = circuit.builder.config_params.clone();
67        tracing::info!(
68            "Wrapper circuit num advice: {:?}",
69            config_params.num_advice_per_phase
70        );
71        #[cfg(feature = "bench-metrics")]
72        emit_wrapper_circuit_metrics(&circuit);
73        let pk = keygen_pk2(params, &circuit, false).unwrap();
74        let num_pvs = circuit.instances().iter().map(|x| x.len()).collect_vec();
75        #[cfg(feature = "bench-metrics")]
76        metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
77        Self {
78            pinning: Halo2ProvingPinning {
79                pk,
80                metadata: Halo2ProvingMetadata {
81                    config_params,
82                    break_points: circuit.break_points(),
83                    num_pvs,
84                },
85            },
86        }
87    }
88    /// A helper function for testing to verify the proof of this circuit with evm verifier.
89    pub fn evm_verify(evm_verifier: &EvmVerifier, evm_proof: &RawEvmProof) -> Result<u64, String> {
90        evm_verify(
91            evm_verifier.artifact.bytecode.clone(),
92            vec![evm_proof.instances.clone()],
93            evm_proof.proof.clone(),
94        )
95    }
96    /// Return deployment code for EVM verifier which can verify the snark of this circuit.
97    pub fn generate_evm_verifier(&self, params: &Halo2Params) -> EvmVerifier {
98        assert_eq!(
99            self.pinning.metadata.config_params.k as u32,
100            params.k(),
101            "Provided params don't match circuit config"
102        );
103        gen_evm_verifier(
104            params,
105            self.pinning.pk.get_vk(),
106            self.pinning.metadata.num_pvs.clone(),
107        )
108    }
109    pub fn prove_for_evm(&self, params: &Halo2Params, snark_to_verify: Snark) -> RawEvmProof {
110        #[cfg(feature = "bench-metrics")]
111        let start = std::time::Instant::now();
112        let k = self.pinning.metadata.config_params.k;
113        let prover_circuit = self.generate_circuit_object_for_proving(k, snark_to_verify);
114        let mut pvs = prover_circuit.instances();
115        assert_eq!(pvs.len(), 1);
116        let proof = gen_evm_proof_shplonk(params, &self.pinning.pk, prover_circuit, pvs.clone());
117        #[cfg(feature = "bench-metrics")]
118        metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
119
120        RawEvmProof {
121            instances: pvs.pop().unwrap(),
122            proof,
123        }
124    }
125    fn generate_circuit_object_for_proving(
126        &self,
127        k: usize,
128        snark_to_verify: Snark,
129    ) -> AggregationCircuit {
130        assert_eq!(
131            snark_to_verify.instances.len(),
132            1,
133            "Snark should only have 1 instance column"
134        );
135        assert_eq!(
136            self.pinning.metadata.num_pvs[0],
137            snark_to_verify.instances[0].len() + 12,
138        );
139        generate_wrapper_circuit_object(Prover, k, snark_to_verify)
140            .use_params(
141                self.pinning
142                    .metadata
143                    .config_params
144                    .clone()
145                    .try_into()
146                    .unwrap(),
147            )
148            .use_break_points(self.pinning.metadata.break_points.clone())
149    }
150
151    pub(crate) fn select_k(dummy_snark: Snark) -> usize {
152        let mut k = 20;
153        let mut first_run = true;
154        loop {
155            let mut circuit = generate_wrapper_circuit_object(Keygen, k, dummy_snark.clone());
156            circuit.calculate_params(Some(MIN_ROWS));
157            assert_eq!(
158                circuit.builder.config_params.num_advice_per_phase.len(),
159                1,
160                "Snark has multiple phases"
161            );
162            if circuit.builder.config_params.num_advice_per_phase[0] == 1 {
163                break;
164            }
165            if first_run {
166                k = log2_ceil_usize(
167                    circuit.builder.statistics().gate.total_advice_per_phase[0] + MIN_ROWS,
168                );
169            } else {
170                k += 1;
171            }
172            first_run = false;
173        }
174        k
175    }
176}
177
178fn generate_wrapper_circuit_object(
179    stage: CircuitBuilderStage,
180    k: usize,
181    snark: Snark,
182) -> AggregationCircuit {
183    let config_params = AggregationConfigParams {
184        degree: k as u32,
185        lookup_bits: k - 1,
186        ..Default::default()
187    };
188    let mut circuit = AggregationCircuit::new::<SHPLONK>(
189        stage,
190        config_params,
191        &KZG_PARAMS_FOR_SVK,
192        [snark],
193        VerifierUniversality::None,
194    );
195    circuit.expose_previous_instances(false);
196    circuit
197}
198
/// Reports the wrapper circuit's total cell count to the `halo2_total_cells` gauge.
///
/// The total is the advice cells plus the lookup advice cells (each summed over all phases)
/// plus the fixed cells.
#[cfg(feature = "bench-metrics")]
fn emit_wrapper_circuit_metrics(agg_circuit: &AggregationCircuit) {
    let stats = agg_circuit.builder.statistics();
    let advice_cells = stats.gate.total_advice_per_phase.iter().sum::<usize>();
    let lookup_cells = stats.total_lookup_advice_per_phase.iter().sum::<usize>();
    let fixed_cells = stats.gate.total_fixed;
    let total_cells = advice_cells + lookup_cells + fixed_cells;
    metrics::gauge!("halo2_total_cells").set(total_cells as f64);
}
207
208fn gen_evm_verifier(
209    params: &Halo2Params,
210    vk: &VerifyingKey<G1Affine>,
211    num_instance: Vec<usize>,
212) -> EvmVerifier {
213    let sol_code =
214        gen_evm_verifier_sol_code::<AggregationCircuit, SHPLONK>(params, vk, num_instance);
215    let byte_code = compile_solidity(&sol_code);
216    EvmVerifier {
217        sol_code,
218        artifact: EvmVerifierByteCode {
219            sol_compiler_version: "0.8.19".to_string(),
220            sol_compiler_options: "".to_string(),
221            bytecode: byte_code,
222        },
223    }
224}