// openvm_native_recursion/halo2/mod.rs

1pub mod utils;
2pub mod verifier;
3
4pub mod testing_utils;
5#[cfg(test)]
6mod tests;
7pub mod wrapper;
8
9use std::fmt::Debug;
10
11use itertools::Itertools;
12use openvm_native_compiler::{
13    constraints::halo2::compiler::{Halo2ConstraintCompiler, Halo2State},
14    ir::{Config, DslIr, TracedVec, Witness},
15};
16use openvm_stark_backend::p3_field::extension::BinomialExtensionField;
17use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr};
18use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
19use snark_verifier_sdk::{
20    halo2::{gen_dummy_snark_from_vk, gen_snark_shplonk},
21    snark_verifier::halo2_base::{
22        gates::{
23            circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage},
24            flex_gate::MultiPhaseThreadBreakPoints,
25        },
26        halo2_proofs::{
27            dev::MockProver,
28            halo2curves::bn256::{Bn256, G1Affine},
29            plonk::{keygen_pk2, ProvingKey},
30            poly::{commitment::Params, kzg::commitment::ParamsKZG},
31            SerdeFormat,
32        },
33    },
34    CircuitExt, Snark, SHPLONK,
35};
36
37use crate::halo2::utils::Halo2ParamsReader;
38
39pub type Halo2Params = ParamsKZG<Bn256>;
40pub use snark_verifier_sdk::snark_verifier::halo2_base::halo2_proofs::halo2curves::bn256::Fr;
41
/// Stateless entry point for turning native-DSL programs into Halo2 circuits.
///
/// It builds and populates those circuits, mock-checks them, generates keys for them, and
/// proves them.
#[derive(Clone, Debug)]
pub struct Halo2Prover;
45
46#[derive(Clone, Deserialize, Serialize)]
47pub struct RawEvmProof {
48    pub instances: Vec<Fr>,
49    pub proof: Vec<u8>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct DslOperations<C: Config> {
54    pub operations: TracedVec<DslIr<C>>,
55    pub num_public_values: usize,
56}
57
58/// Necessary metadata to prove a Halo2 circuit
59/// Attention: Deserializer of this struct is not generic. It only works for verifier/wrapper
60/// circuit.
61#[derive(Debug, Clone)]
62pub struct Halo2ProvingPinning {
63    pub pk: ProvingKey<G1Affine>,
64    pub metadata: Halo2ProvingMetadata,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct Halo2ProvingMetadata {
69    pub config_params: BaseCircuitParams,
70    pub break_points: MultiPhaseThreadBreakPoints,
71    /// Number of public values per column in order.
72    pub num_pvs: Vec<usize>,
73}
74
75impl RawEvmProof {
76    /// Return bytes calldata to be passed to the verifier contract.
77    #[cfg(feature = "evm-prove")]
78    pub fn verifier_calldata(&self) -> Vec<u8> {
79        snark_verifier_sdk::evm::encode_calldata(&[self.instances.clone()], &self.proof)
80    }
81}
82
83impl Halo2ProvingPinning {
84    pub fn generate_dummy_snark(&self, reader: &impl Halo2ParamsReader) -> Snark {
85        let k = self.metadata.config_params.k;
86        let params = reader.read_params(k);
87        gen_dummy_snark_from_vk::<SHPLONK>(
88            &params,
89            self.pk.get_vk(),
90            self.metadata.num_pvs.clone(),
91            None,
92        )
93    }
94}
95
96impl Halo2Prover {
97    pub fn builder(stage: CircuitBuilderStage, k: usize) -> BaseCircuitBuilder<Fr> {
98        BaseCircuitBuilder::from_stage(stage)
99            .use_k(k)
100            .use_lookup_bits(k - 1)
101            .use_instance_columns(1)
102    }
103
104    pub fn populate<
105        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
106    >(
107        builder: BaseCircuitBuilder<Fr>,
108        dsl_operations: DslOperations<C>,
109        witness: Witness<C>,
110        #[allow(unused_variables)] profiling: bool,
111    ) -> BaseCircuitBuilder<Fr> {
112        let mut state = Halo2State {
113            builder,
114            ..Default::default()
115        };
116        state.load_witness(witness);
117
118        let backend = Halo2ConstraintCompiler::<C>::new(dsl_operations.num_public_values);
119        #[cfg(feature = "metrics")]
120        let backend = if profiling {
121            backend.with_profiling()
122        } else {
123            backend
124        };
125        backend.constrain_halo2(&mut state, dsl_operations.operations);
126
127        state.builder
128    }
129
130    /// Executes the prover in testing mode with a circuit definition and witness.
131    ///
132    /// Returns the public instances.
133    pub fn mock<
134        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
135    >(
136        k: usize,
137        dsl_operations: DslOperations<C>,
138        witness: Witness<C>,
139    ) -> Vec<Vec<Fr>> {
140        let builder = Self::builder(CircuitBuilderStage::Mock, k);
141        let mut builder = Self::populate(builder, dsl_operations, witness, true);
142
143        let public_instances = builder.instances();
144        println!("Public instances: {:?}", public_instances);
145
146        builder.calculate_params(Some(20));
147
148        MockProver::run(k as u32, &builder, public_instances.clone())
149            .unwrap()
150            .assert_satisfied();
151        public_instances
152    }
153
154    /// Populates builder, tunes circuit, keygen
155    pub fn keygen<
156        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
157    >(
158        params: &Halo2Params,
159        dsl_operations: DslOperations<C>,
160        witness: Witness<C>,
161    ) -> Halo2ProvingPinning {
162        let k = params.k() as usize;
163        let builder = Self::builder(CircuitBuilderStage::Keygen, k);
164        let mut builder = Self::populate(builder, dsl_operations, witness, true);
165        builder.calculate_params(Some(20));
166
167        #[cfg(feature = "metrics")]
168        let start = std::time::Instant::now();
169        let pk = keygen_pk2(params, &builder, false).unwrap();
170        #[cfg(feature = "metrics")]
171        metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
172        let break_points = builder.break_points();
173
174        let config_params = builder.config_params.clone();
175        let num_pvs = builder
176            .assigned_instances
177            .iter()
178            .map(|x| x.len())
179            .collect_vec();
180
181        Halo2ProvingPinning {
182            pk,
183            metadata: Halo2ProvingMetadata {
184                config_params,
185                break_points,
186                num_pvs,
187            },
188        }
189    }
190
191    pub fn prove<
192        C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
193    >(
194        params: &Halo2Params,
195        config_params: BaseCircuitParams,
196        break_points: MultiPhaseThreadBreakPoints,
197        pk: &ProvingKey<G1Affine>,
198        dsl_operations: DslOperations<C>,
199        witness: Witness<C>,
200        profiling: bool,
201    ) -> Snark {
202        let k = config_params.k;
203        #[cfg(feature = "metrics")]
204        let start = std::time::Instant::now();
205        let builder = Self::builder(CircuitBuilderStage::Prover, k)
206            .use_params(config_params)
207            .use_break_points(break_points);
208        let builder = Self::populate(builder, dsl_operations, witness, profiling);
209        #[cfg(feature = "metrics")]
210        {
211            let stats = builder.statistics();
212            let total_advices: usize = stats.gate.total_advice_per_phase.into_iter().sum();
213            let total_lookups: usize = stats.total_lookup_advice_per_phase.into_iter().sum();
214            let total_cell = total_advices + total_lookups + stats.gate.total_fixed;
215            metrics::counter!("main_cells_used").absolute(total_cell as u64);
216        }
217        let snark = gen_snark_shplonk(params, pk, builder, None::<&str>);
218
219        #[cfg(feature = "metrics")]
220        metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
221
222        snark
223    }
224}
225
226#[derive(Serialize, Deserialize)]
227struct SerializedHalo2ProvingPinning {
228    pk_bytes: Vec<u8>,
229    metadata: Halo2ProvingMetadata,
230}
231
232impl Serialize for Halo2ProvingPinning {
233    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
234    where
235        S: Serializer,
236    {
237        let serialized = SerializedHalo2ProvingPinning {
238            pk_bytes: self.pk.to_bytes(SerdeFormat::RawBytes),
239            metadata: self.metadata.clone(),
240        };
241        serialized.serialize(serializer)
242    }
243}
244
245impl<'de> Deserialize<'de> for Halo2ProvingPinning {
246    fn deserialize<D>(deserializer: D) -> Result<Halo2ProvingPinning, D::Error>
247    where
248        D: Deserializer<'de>,
249    {
250        let SerializedHalo2ProvingPinning { pk_bytes, metadata } =
251            SerializedHalo2ProvingPinning::deserialize(deserializer)?;
252
253        let pk = ProvingKey::<G1Affine>::from_bytes::<BaseCircuitBuilder<Fr>>(
254            &pk_bytes,
255            SerdeFormat::RawBytes,
256            metadata.config_params.clone(),
257        )
258        .map_err(|e| de::Error::custom(format!("invalid bytes for proving key: {}", e)))?;
259
260        Ok(Halo2ProvingPinning { pk, metadata })
261    }
262}