// openvm_native_recursion/halo2/mod.rs

1pub mod utils;
2pub mod verifier;
3
4pub mod testing_utils;
5#[cfg(test)]
6mod tests;
7pub mod wrapper;
8
9use std::fmt::Debug;
10
11use itertools::Itertools;
12use openvm_native_compiler::{
13    constraints::halo2::compiler::{Halo2ConstraintCompiler, Halo2State},
14    ir::{Config, DslIr, TracedVec, Witness},
15};
16use openvm_stark_backend::p3_field::extension::BinomialExtensionField;
17use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr};
18use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
19use snark_verifier_sdk::{
20    evm::encode_calldata,
21    halo2::{gen_dummy_snark_from_vk, gen_snark_shplonk},
22    snark_verifier::halo2_base::{
23        gates::{
24            circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage},
25            flex_gate::MultiPhaseThreadBreakPoints,
26        },
27        halo2_proofs::{
28            dev::MockProver,
29            halo2curves::bn256::{Bn256, G1Affine},
30            plonk::{keygen_pk2, ProvingKey},
31            poly::{commitment::Params, kzg::commitment::ParamsKZG},
32            SerdeFormat,
33        },
34    },
35    CircuitExt, Snark, SHPLONK,
36};
37
38use crate::halo2::utils::Halo2ParamsReader;
39
40pub type Halo2Params = ParamsKZG<Bn256>;
41pub use snark_verifier_sdk::snark_verifier::halo2_base::halo2_proofs::halo2curves::bn256::Fr;
42
/// Stateless prover that builds, key-generates, mock-checks and proves Halo2
/// circuits compiled from native DSL operations.
///
/// All functionality lives in associated functions; the type itself carries no data.
#[derive(Debug, Clone)]
pub struct Halo2Prover;
46
47#[derive(Clone, Deserialize, Serialize)]
48pub struct RawEvmProof {
49    pub instances: Vec<Fr>,
50    pub proof: Vec<u8>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct DslOperations<C: Config> {
55    pub operations: TracedVec<DslIr<C>>,
56    pub num_public_values: usize,
57}
58
59/// Necessary metadata to prove a Halo2 circuit
60/// Attention: Deserializer of this struct is not generic. It only works for verifier/wrapper circuit.
61#[derive(Debug, Clone)]
62pub struct Halo2ProvingPinning {
63    pub pk: ProvingKey<G1Affine>,
64    pub metadata: Halo2ProvingMetadata,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct Halo2ProvingMetadata {
69    pub config_params: BaseCircuitParams,
70    pub break_points: MultiPhaseThreadBreakPoints,
71    /// Number of public values per column in order.
72    pub num_pvs: Vec<usize>,
73}
74
75impl RawEvmProof {
76    /// Return bytes calldata to be passed to the verifier contract.
77    pub fn verifier_calldata(&self) -> Vec<u8> {
78        encode_calldata(&[self.instances.clone()], &self.proof)
79    }
80}
81
82impl Halo2ProvingPinning {
83    pub fn generate_dummy_snark(&self, reader: &impl Halo2ParamsReader) -> Snark {
84        let k = self.metadata.config_params.k;
85        let params = reader.read_params(k);
86        gen_dummy_snark_from_vk::<SHPLONK>(
87            &params,
88            self.pk.get_vk(),
89            self.metadata.num_pvs.clone(),
90            None,
91        )
92    }
93}
94
95impl Halo2Prover {
96    pub fn builder(stage: CircuitBuilderStage, k: usize) -> BaseCircuitBuilder<Fr> {
97        BaseCircuitBuilder::from_stage(stage)
98            .use_k(k)
99            .use_lookup_bits(k - 1)
100            .use_instance_columns(1)
101    }
102
103    pub fn populate<
104        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
105    >(
106        builder: BaseCircuitBuilder<Fr>,
107        dsl_operations: DslOperations<C>,
108        witness: Witness<C>,
109        #[allow(unused_variables)] profiling: bool,
110    ) -> BaseCircuitBuilder<Fr> {
111        let mut state = Halo2State {
112            builder,
113            ..Default::default()
114        };
115        state.load_witness(witness);
116
117        let backend = Halo2ConstraintCompiler::<C>::new(dsl_operations.num_public_values);
118        #[cfg(feature = "bench-metrics")]
119        let backend = if profiling {
120            backend.with_profiling()
121        } else {
122            backend
123        };
124        backend.constrain_halo2(&mut state, dsl_operations.operations);
125
126        state.builder
127    }
128
129    /// Executes the prover in testing mode with a circuit definition and witness.
130    ///
131    /// Returns the public instances.
132    pub fn mock<
133        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
134    >(
135        k: usize,
136        dsl_operations: DslOperations<C>,
137        witness: Witness<C>,
138    ) -> Vec<Vec<Fr>> {
139        let builder = Self::builder(CircuitBuilderStage::Mock, k);
140        let mut builder = Self::populate(builder, dsl_operations, witness, true);
141
142        let public_instances = builder.instances();
143        println!("Public instances: {:?}", public_instances);
144
145        builder.calculate_params(Some(20));
146
147        MockProver::run(k as u32, &builder, public_instances.clone())
148            .unwrap()
149            .assert_satisfied();
150        public_instances
151    }
152
153    /// Populates builder, tunes circuit, keygen
154    pub fn keygen<
155        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
156    >(
157        params: &Halo2Params,
158        dsl_operations: DslOperations<C>,
159        witness: Witness<C>,
160    ) -> Halo2ProvingPinning {
161        let k = params.k() as usize;
162        let builder = Self::builder(CircuitBuilderStage::Keygen, k);
163        let mut builder = Self::populate(builder, dsl_operations, witness, true);
164        builder.calculate_params(Some(20));
165
166        // let break_points;
167        // // if pk already exists, read break points from file
168        // let pk = if Path::new("halo2_final.pk").exists() {
169        //     let file = File::open("halo2_final.json").unwrap();
170        //     break_points = serde_json::from_reader(file).unwrap();
171        //     gen_pk(&params, &builder, Some(Path::new("halo2_final.pk")))
172        // } else {
173        //
174        //     pk
175        // };
176        #[cfg(feature = "bench-metrics")]
177        let start = std::time::Instant::now();
178        let pk = keygen_pk2(params, &builder, false).unwrap();
179        #[cfg(feature = "bench-metrics")]
180        metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
181        let break_points = builder.break_points();
182
183        let config_params = builder.config_params.clone();
184        let num_pvs = builder
185            .assigned_instances
186            .iter()
187            .map(|x| x.len())
188            .collect_vec();
189
190        // let file = File::create("halo2_final.json").unwrap();
191        // serde_json::to_writer(file, &break_points).unwrap();
192        Halo2ProvingPinning {
193            pk,
194            metadata: Halo2ProvingMetadata {
195                config_params,
196                break_points,
197                num_pvs,
198            },
199        }
200    }
201
202    pub fn prove<
203        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
204    >(
205        params: &Halo2Params,
206        config_params: BaseCircuitParams,
207        break_points: MultiPhaseThreadBreakPoints,
208        pk: &ProvingKey<G1Affine>,
209        dsl_operations: DslOperations<C>,
210        witness: Witness<C>,
211        profiling: bool,
212    ) -> Snark {
213        let k = config_params.k;
214        #[cfg(feature = "bench-metrics")]
215        let start = std::time::Instant::now();
216        let builder = Self::builder(CircuitBuilderStage::Prover, k)
217            .use_params(config_params)
218            .use_break_points(break_points);
219        let builder = Self::populate(builder, dsl_operations, witness, profiling);
220        #[cfg(feature = "bench-metrics")]
221        {
222            let stats = builder.statistics();
223            let total_advices: usize = stats.gate.total_advice_per_phase.into_iter().sum();
224            let total_lookups: usize = stats.total_lookup_advice_per_phase.into_iter().sum();
225            let total_cell = total_advices + total_lookups + stats.gate.total_fixed;
226            metrics::counter!("main_cells_used").absolute(total_cell as u64);
227        }
228        let snark = gen_snark_shplonk(params, pk, builder, None::<&str>);
229
230        #[cfg(feature = "bench-metrics")]
231        metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
232
233        snark
234    }
235}
236
237#[derive(Serialize, Deserialize)]
238struct SerializedHalo2ProvingPinning {
239    pk_bytes: Vec<u8>,
240    metadata: Halo2ProvingMetadata,
241}
242
243impl Serialize for Halo2ProvingPinning {
244    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
245    where
246        S: Serializer,
247    {
248        let serialized = SerializedHalo2ProvingPinning {
249            pk_bytes: self.pk.to_bytes(SerdeFormat::RawBytes),
250            metadata: self.metadata.clone(),
251        };
252        serialized.serialize(serializer)
253    }
254}
255
256impl<'de> Deserialize<'de> for Halo2ProvingPinning {
257    fn deserialize<D>(deserializer: D) -> Result<Halo2ProvingPinning, D::Error>
258    where
259        D: Deserializer<'de>,
260    {
261        let SerializedHalo2ProvingPinning { pk_bytes, metadata } =
262            SerializedHalo2ProvingPinning::deserialize(deserializer)?;
263
264        let pk = ProvingKey::<G1Affine>::from_bytes::<BaseCircuitBuilder<Fr>>(
265            &pk_bytes,
266            SerdeFormat::RawBytes,
267            metadata.config_params.clone(),
268        )
269        .map_err(|e| de::Error::custom(format!("invalid bytes for proving key: {}", e)))?;
270
271        Ok(Halo2ProvingPinning { pk, metadata })
272    }
273}