// openvm_native_recursion/halo2/mod.rs
1pub mod utils;
2pub mod verifier;
3
4pub mod testing_utils;
5#[cfg(test)]
6mod tests;
7pub mod wrapper;
8
9use std::fmt::Debug;
10
11use itertools::Itertools;
12use openvm_native_compiler::{
13 constraints::halo2::compiler::{Halo2ConstraintCompiler, Halo2State},
14 ir::{Config, DslIr, TracedVec, Witness},
15};
16use openvm_stark_backend::p3_field::extension::BinomialExtensionField;
17use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr};
18use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
19use snark_verifier_sdk::{
20 evm::encode_calldata,
21 halo2::{gen_dummy_snark_from_vk, gen_snark_shplonk},
22 snark_verifier::halo2_base::{
23 gates::{
24 circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage},
25 flex_gate::MultiPhaseThreadBreakPoints,
26 },
27 halo2_proofs::{
28 dev::MockProver,
29 halo2curves::bn256::{Bn256, G1Affine},
30 plonk::{keygen_pk2, ProvingKey},
31 poly::{commitment::Params, kzg::commitment::ParamsKZG},
32 SerdeFormat,
33 },
34 },
35 CircuitExt, Snark, SHPLONK,
36};
37
38use crate::halo2::utils::Halo2ParamsReader;
39
40pub type Halo2Params = ParamsKZG<Bn256>;
41pub use snark_verifier_sdk::snark_verifier::halo2_base::halo2_proofs::halo2curves::bn256::Fr;
42
43#[derive(Debug, Clone)]
45pub struct Halo2Prover;
46
47#[derive(Clone, Deserialize, Serialize)]
48pub struct RawEvmProof {
49 pub instances: Vec<Fr>,
50 pub proof: Vec<u8>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct DslOperations<C: Config> {
55 pub operations: TracedVec<DslIr<C>>,
56 pub num_public_values: usize,
57}
58
59#[derive(Debug, Clone)]
62pub struct Halo2ProvingPinning {
63 pub pk: ProvingKey<G1Affine>,
64 pub metadata: Halo2ProvingMetadata,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct Halo2ProvingMetadata {
69 pub config_params: BaseCircuitParams,
70 pub break_points: MultiPhaseThreadBreakPoints,
71 pub num_pvs: Vec<usize>,
73}
74
75impl RawEvmProof {
76 pub fn verifier_calldata(&self) -> Vec<u8> {
78 encode_calldata(&[self.instances.clone()], &self.proof)
79 }
80}
81
82impl Halo2ProvingPinning {
83 pub fn generate_dummy_snark(&self, reader: &impl Halo2ParamsReader) -> Snark {
84 let k = self.metadata.config_params.k;
85 let params = reader.read_params(k);
86 gen_dummy_snark_from_vk::<SHPLONK>(
87 ¶ms,
88 self.pk.get_vk(),
89 self.metadata.num_pvs.clone(),
90 None,
91 )
92 }
93}
94
95impl Halo2Prover {
96 pub fn builder(stage: CircuitBuilderStage, k: usize) -> BaseCircuitBuilder<Fr> {
97 BaseCircuitBuilder::from_stage(stage)
98 .use_k(k)
99 .use_lookup_bits(k - 1)
100 .use_instance_columns(1)
101 }
102
103 pub fn populate<
104 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
105 >(
106 builder: BaseCircuitBuilder<Fr>,
107 dsl_operations: DslOperations<C>,
108 witness: Witness<C>,
109 #[allow(unused_variables)] profiling: bool,
110 ) -> BaseCircuitBuilder<Fr> {
111 let mut state = Halo2State {
112 builder,
113 ..Default::default()
114 };
115 state.load_witness(witness);
116
117 let backend = Halo2ConstraintCompiler::<C>::new(dsl_operations.num_public_values);
118 #[cfg(feature = "bench-metrics")]
119 let backend = if profiling {
120 backend.with_profiling()
121 } else {
122 backend
123 };
124 backend.constrain_halo2(&mut state, dsl_operations.operations);
125
126 state.builder
127 }
128
129 pub fn mock<
133 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
134 >(
135 k: usize,
136 dsl_operations: DslOperations<C>,
137 witness: Witness<C>,
138 ) -> Vec<Vec<Fr>> {
139 let builder = Self::builder(CircuitBuilderStage::Mock, k);
140 let mut builder = Self::populate(builder, dsl_operations, witness, true);
141
142 let public_instances = builder.instances();
143 println!("Public instances: {:?}", public_instances);
144
145 builder.calculate_params(Some(20));
146
147 MockProver::run(k as u32, &builder, public_instances.clone())
148 .unwrap()
149 .assert_satisfied();
150 public_instances
151 }
152
153 pub fn keygen<
155 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
156 >(
157 params: &Halo2Params,
158 dsl_operations: DslOperations<C>,
159 witness: Witness<C>,
160 ) -> Halo2ProvingPinning {
161 let k = params.k() as usize;
162 let builder = Self::builder(CircuitBuilderStage::Keygen, k);
163 let mut builder = Self::populate(builder, dsl_operations, witness, true);
164 builder.calculate_params(Some(20));
165
166 #[cfg(feature = "bench-metrics")]
177 let start = std::time::Instant::now();
178 let pk = keygen_pk2(params, &builder, false).unwrap();
179 #[cfg(feature = "bench-metrics")]
180 metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
181 let break_points = builder.break_points();
182
183 let config_params = builder.config_params.clone();
184 let num_pvs = builder
185 .assigned_instances
186 .iter()
187 .map(|x| x.len())
188 .collect_vec();
189
190 Halo2ProvingPinning {
193 pk,
194 metadata: Halo2ProvingMetadata {
195 config_params,
196 break_points,
197 num_pvs,
198 },
199 }
200 }
201
202 pub fn prove<
203 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
204 >(
205 params: &Halo2Params,
206 config_params: BaseCircuitParams,
207 break_points: MultiPhaseThreadBreakPoints,
208 pk: &ProvingKey<G1Affine>,
209 dsl_operations: DslOperations<C>,
210 witness: Witness<C>,
211 profiling: bool,
212 ) -> Snark {
213 let k = config_params.k;
214 #[cfg(feature = "bench-metrics")]
215 let start = std::time::Instant::now();
216 let builder = Self::builder(CircuitBuilderStage::Prover, k)
217 .use_params(config_params)
218 .use_break_points(break_points);
219 let builder = Self::populate(builder, dsl_operations, witness, profiling);
220 #[cfg(feature = "bench-metrics")]
221 {
222 let stats = builder.statistics();
223 let total_advices: usize = stats.gate.total_advice_per_phase.into_iter().sum();
224 let total_lookups: usize = stats.total_lookup_advice_per_phase.into_iter().sum();
225 let total_cell = total_advices + total_lookups + stats.gate.total_fixed;
226 metrics::counter!("main_cells_used").absolute(total_cell as u64);
227 }
228 let snark = gen_snark_shplonk(params, pk, builder, None::<&str>);
229
230 #[cfg(feature = "bench-metrics")]
231 metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
232
233 snark
234 }
235}
236
237#[derive(Serialize, Deserialize)]
238struct SerializedHalo2ProvingPinning {
239 pk_bytes: Vec<u8>,
240 metadata: Halo2ProvingMetadata,
241}
242
243impl Serialize for Halo2ProvingPinning {
244 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
245 where
246 S: Serializer,
247 {
248 let serialized = SerializedHalo2ProvingPinning {
249 pk_bytes: self.pk.to_bytes(SerdeFormat::RawBytes),
250 metadata: self.metadata.clone(),
251 };
252 serialized.serialize(serializer)
253 }
254}
255
256impl<'de> Deserialize<'de> for Halo2ProvingPinning {
257 fn deserialize<D>(deserializer: D) -> Result<Halo2ProvingPinning, D::Error>
258 where
259 D: Deserializer<'de>,
260 {
261 let SerializedHalo2ProvingPinning { pk_bytes, metadata } =
262 SerializedHalo2ProvingPinning::deserialize(deserializer)?;
263
264 let pk = ProvingKey::<G1Affine>::from_bytes::<BaseCircuitBuilder<Fr>>(
265 &pk_bytes,
266 SerdeFormat::RawBytes,
267 metadata.config_params.clone(),
268 )
269 .map_err(|e| de::Error::custom(format!("invalid bytes for proving key: {}", e)))?;
270
271 Ok(Halo2ProvingPinning { pk, metadata })
272 }
273}