openvm_native_recursion/halo2/
wrapper.rs

1use itertools::Itertools;
2use openvm_stark_backend::p3_util::log2_ceil_usize;
3use serde::{Deserialize, Serialize};
4use serde_with::serde_as;
5#[cfg(feature = "evm-prove")]
6use snark_verifier_sdk::snark_verifier::{
7    halo2_base::halo2_proofs::{halo2curves::bn256::G1Affine, plonk::VerifyingKey},
8    loader::evm::compile_solidity,
9};
10use snark_verifier_sdk::{
11    halo2::aggregation::{AggregationCircuit, AggregationConfigParams, VerifierUniversality},
12    snark_verifier::halo2_base::{
13        gates::circuit::CircuitBuilderStage,
14        halo2_proofs::{plonk::keygen_pk2, poly::commitment::Params},
15    },
16    CircuitExt, Snark, SHPLONK,
17};
18
19#[cfg(feature = "evm-prove")]
20use crate::halo2::RawEvmProof;
21use crate::halo2::{
22    utils::{Halo2ParamsReader, KZG_PARAMS_FOR_SVK},
23    Halo2Params, Halo2ProvingMetadata, Halo2ProvingPinning,
24};
25
26/// `FallbackEvmVerifier` is for the raw verifier contract outputted by
27/// `snark-verifier`.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct FallbackEvmVerifier {
30    pub sol_code: String,
31    pub artifact: EvmVerifierByteCode,
32}
33
34#[serde_as]
35#[derive(Clone, Debug, Deserialize, Serialize)]
36pub struct EvmVerifierByteCode {
37    pub sol_compiler_version: String,
38    pub sol_compiler_options: String,
39    #[serde_as(as = "serde_with::hex::Hex")]
40    pub bytecode: Vec<u8>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct Halo2WrapperProvingKey {
45    pub pinning: Halo2ProvingPinning,
46}
47
/// Passed as `minimum_rows` to `calculate_params` for the wrapper circuit, and
/// added to the advice-cell count when estimating the circuit degree `k`.
// NOTE(review): presumably the number of rows reserved as unusable (e.g. for
// blinding) by halo2-base — confirm against `BaseCircuitBuilder::calculate_params`.
const MIN_ROWS: usize = 20;
49
50impl Halo2WrapperProvingKey {
51    /// Auto select k to let Wrapper circuit only have 1 advice column.
52    pub fn keygen_auto_tune(reader: &impl Halo2ParamsReader, dummy_snark: Snark) -> Self {
53        let k = Self::select_k(dummy_snark.clone());
54        tracing::info!("Selected k: {}", k);
55        let params = reader.read_params(k);
56        Self::keygen(&params, dummy_snark)
57    }
58    pub fn keygen(params: &Halo2Params, dummy_snark: Snark) -> Self {
59        let k = params.k();
60        #[cfg(feature = "metrics")]
61        let start = std::time::Instant::now();
62        let mut circuit =
63            generate_wrapper_circuit_object(CircuitBuilderStage::Keygen, k as usize, dummy_snark);
64        circuit.calculate_params(Some(MIN_ROWS));
65        let config_params = circuit.builder.config_params.clone();
66        tracing::info!(
67            "Wrapper circuit num advice: {:?}",
68            config_params.num_advice_per_phase
69        );
70        #[cfg(feature = "metrics")]
71        emit_wrapper_circuit_metrics(&circuit);
72        let pk = keygen_pk2(params, &circuit, false).unwrap();
73        let num_pvs = circuit.instances().iter().map(|x| x.len()).collect_vec();
74        #[cfg(feature = "metrics")]
75        metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
76        Self {
77            pinning: Halo2ProvingPinning {
78                pk,
79                metadata: Halo2ProvingMetadata {
80                    config_params,
81                    break_points: circuit.break_points(),
82                    num_pvs,
83                },
84            },
85        }
86    }
87    #[cfg(feature = "evm-verify")]
88    /// A helper function for testing to verify the proof of this circuit with evm verifier.
89    pub fn evm_verify(
90        evm_verifier: &FallbackEvmVerifier,
91        evm_proof: &RawEvmProof,
92    ) -> Result<u64, String> {
93        snark_verifier_sdk::evm::evm_verify(
94            evm_verifier.artifact.bytecode.clone(),
95            vec![evm_proof.instances.clone()],
96            evm_proof.proof.clone(),
97        )
98    }
99    #[cfg(feature = "evm-prove")]
100    /// Return deployment code for EVM verifier which can verify the snark of this circuit.
101    pub fn generate_fallback_evm_verifier(&self, params: &Halo2Params) -> FallbackEvmVerifier {
102        assert_eq!(
103            self.pinning.metadata.config_params.k as u32,
104            params.k(),
105            "Provided params don't match circuit config"
106        );
107        gen_evm_verifier(
108            params,
109            self.pinning.pk.get_vk(),
110            self.pinning.metadata.num_pvs.clone(),
111        )
112    }
113    #[cfg(feature = "evm-prove")]
114    pub fn prove_for_evm(&self, params: &Halo2Params, snark_to_verify: Snark) -> RawEvmProof {
115        #[cfg(feature = "metrics")]
116        let start = std::time::Instant::now();
117        let k = self.pinning.metadata.config_params.k;
118        let prover_circuit = self.generate_circuit_object_for_proving(k, snark_to_verify);
119        let mut pvs = prover_circuit.instances();
120        assert_eq!(pvs.len(), 1);
121        let proof = snark_verifier_sdk::evm::gen_evm_proof_shplonk(
122            params,
123            &self.pinning.pk,
124            prover_circuit,
125            pvs.clone(),
126        );
127        #[cfg(feature = "metrics")]
128        metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
129
130        RawEvmProof {
131            instances: pvs.pop().unwrap(),
132            proof,
133        }
134    }
135
136    #[cfg(feature = "evm-prove")]
137    fn generate_circuit_object_for_proving(
138        &self,
139        k: usize,
140        snark_to_verify: Snark,
141    ) -> AggregationCircuit {
142        assert_eq!(
143            snark_to_verify.instances.len(),
144            1,
145            "Snark should only have 1 instance column"
146        );
147        assert_eq!(
148            self.pinning.metadata.num_pvs[0],
149            snark_to_verify.instances[0].len() + 12,
150        );
151        generate_wrapper_circuit_object(CircuitBuilderStage::Prover, k, snark_to_verify)
152            .use_params(
153                self.pinning
154                    .metadata
155                    .config_params
156                    .clone()
157                    .try_into()
158                    .unwrap(),
159            )
160            .use_break_points(self.pinning.metadata.break_points.clone())
161    }
162
163    pub(crate) fn select_k(dummy_snark: Snark) -> usize {
164        let mut k = 20;
165        let mut first_run = true;
166        loop {
167            let mut circuit = generate_wrapper_circuit_object(
168                CircuitBuilderStage::Keygen,
169                k,
170                dummy_snark.clone(),
171            );
172            circuit.calculate_params(Some(MIN_ROWS));
173            assert_eq!(
174                circuit.builder.config_params.num_advice_per_phase.len(),
175                1,
176                "Snark has multiple phases"
177            );
178            if circuit.builder.config_params.num_advice_per_phase[0] == 1 {
179                circuit.builder.clear();
180                break;
181            }
182            if first_run {
183                k = log2_ceil_usize(
184                    circuit.builder.statistics().gate.total_advice_per_phase[0] + MIN_ROWS,
185                );
186            } else {
187                k += 1;
188            }
189            first_run = false;
190            // Prevent drop warnings
191            circuit.builder.clear();
192        }
193        k
194    }
195}
196
197fn generate_wrapper_circuit_object(
198    stage: CircuitBuilderStage,
199    k: usize,
200    snark: Snark,
201) -> AggregationCircuit {
202    let config_params = AggregationConfigParams {
203        degree: k as u32,
204        lookup_bits: k - 1,
205        ..Default::default()
206    };
207    let mut circuit = AggregationCircuit::new::<SHPLONK>(
208        stage,
209        config_params,
210        &KZG_PARAMS_FOR_SVK,
211        [snark],
212        VerifierUniversality::None,
213    );
214    circuit.expose_previous_instances(false);
215    circuit
216}
217
/// Records the wrapper circuit's total cell count (advice, lookup advice, and fixed
/// cells across all phases) as the `halo2_total_cells` gauge.
#[cfg(feature = "metrics")]
fn emit_wrapper_circuit_metrics(agg_circuit: &AggregationCircuit) {
    let stats = agg_circuit.builder.statistics();
    let advice_cells = stats.gate.total_advice_per_phase.iter().sum::<usize>();
    let lookup_cells = stats.total_lookup_advice_per_phase.iter().sum::<usize>();
    let all_cells = stats.gate.total_fixed + advice_cells + lookup_cells;
    metrics::gauge!("halo2_total_cells").set(all_cells as f64);
}
226
/// Generates the Solidity verifier for a SHPLONK `AggregationCircuit` with the given
/// verifying key and per-column instance counts, then compiles it to bytecode.
///
/// The artifact records compiler version "0.8.19" and empty compiler options.
#[cfg(feature = "evm-prove")]
fn gen_evm_verifier(
    params: &Halo2Params,
    vk: &VerifyingKey<G1Affine>,
    num_instance: Vec<usize>,
) -> FallbackEvmVerifier {
    let sol_code = snark_verifier_sdk::evm::gen_evm_verifier_sol_code::<
        AggregationCircuit,
        SHPLONK,
    >(params, vk, num_instance);
    // Compile before `sol_code` is moved into the returned struct.
    let artifact = EvmVerifierByteCode {
        sol_compiler_version: "0.8.19".to_string(),
        sol_compiler_options: String::new(),
        bytecode: compile_solidity(&sol_code),
    };
    FallbackEvmVerifier { sol_code, artifact }
}