openvm_native_recursion/halo2/
mod.rs1pub mod utils;
2pub mod verifier;
3
4pub mod testing_utils;
5#[cfg(test)]
6mod tests;
7pub mod wrapper;
8
9use std::fmt::Debug;
10
11use itertools::Itertools;
12use openvm_native_compiler::{
13 constraints::halo2::compiler::{Halo2ConstraintCompiler, Halo2State},
14 ir::{Config, DslIr, TracedVec, Witness},
15};
16use openvm_stark_backend::p3_field::extension::BinomialExtensionField;
17use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr};
18use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
19use snark_verifier_sdk::{
20 halo2::{gen_dummy_snark_from_vk, gen_snark_shplonk},
21 snark_verifier::halo2_base::{
22 gates::{
23 circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage},
24 flex_gate::MultiPhaseThreadBreakPoints,
25 },
26 halo2_proofs::{
27 dev::MockProver,
28 halo2curves::bn256::{Bn256, G1Affine},
29 plonk::{keygen_pk2, ProvingKey},
30 poly::{commitment::Params, kzg::commitment::ParamsKZG},
31 SerdeFormat,
32 },
33 },
34 CircuitExt, Snark, SHPLONK,
35};
36
37use crate::halo2::utils::Halo2ParamsReader;
38
39pub type Halo2Params = ParamsKZG<Bn256>;
40pub use snark_verifier_sdk::snark_verifier::halo2_base::halo2_proofs::halo2curves::bn256::Fr;
41
/// Stateless namespace for building, key-generating, mock-checking and proving
/// Halo2 circuits compiled from the native DSL.
#[derive(Clone, Debug)]
pub struct Halo2Prover;
45
46#[derive(Clone, Deserialize, Serialize)]
47pub struct RawEvmProof {
48 pub instances: Vec<Fr>,
49 pub proof: Vec<u8>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct DslOperations<C: Config> {
54 pub operations: TracedVec<DslIr<C>>,
55 pub num_public_values: usize,
56}
57
58#[derive(Debug, Clone)]
62pub struct Halo2ProvingPinning {
63 pub pk: ProvingKey<G1Affine>,
64 pub metadata: Halo2ProvingMetadata,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct Halo2ProvingMetadata {
69 pub config_params: BaseCircuitParams,
70 pub break_points: MultiPhaseThreadBreakPoints,
71 pub num_pvs: Vec<usize>,
73}
74
75impl RawEvmProof {
76 #[cfg(feature = "evm-prove")]
78 pub fn verifier_calldata(&self) -> Vec<u8> {
79 snark_verifier_sdk::evm::encode_calldata(&[self.instances.clone()], &self.proof)
80 }
81}
82
83impl Halo2ProvingPinning {
84 pub fn generate_dummy_snark(&self, reader: &impl Halo2ParamsReader) -> Snark {
85 let k = self.metadata.config_params.k;
86 let params = reader.read_params(k);
87 gen_dummy_snark_from_vk::<SHPLONK>(
88 ¶ms,
89 self.pk.get_vk(),
90 self.metadata.num_pvs.clone(),
91 None,
92 )
93 }
94}
95
96impl Halo2Prover {
97 pub fn builder(stage: CircuitBuilderStage, k: usize) -> BaseCircuitBuilder<Fr> {
98 BaseCircuitBuilder::from_stage(stage)
99 .use_k(k)
100 .use_lookup_bits(k - 1)
101 .use_instance_columns(1)
102 }
103
104 pub fn populate<
105 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
106 >(
107 builder: BaseCircuitBuilder<Fr>,
108 dsl_operations: DslOperations<C>,
109 witness: Witness<C>,
110 #[allow(unused_variables)] profiling: bool,
111 ) -> BaseCircuitBuilder<Fr> {
112 let mut state = Halo2State {
113 builder,
114 ..Default::default()
115 };
116 state.load_witness(witness);
117
118 let backend = Halo2ConstraintCompiler::<C>::new(dsl_operations.num_public_values);
119 #[cfg(feature = "metrics")]
120 let backend = if profiling {
121 backend.with_profiling()
122 } else {
123 backend
124 };
125 backend.constrain_halo2(&mut state, dsl_operations.operations);
126
127 state.builder
128 }
129
130 pub fn mock<
134 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
135 >(
136 k: usize,
137 dsl_operations: DslOperations<C>,
138 witness: Witness<C>,
139 ) -> Vec<Vec<Fr>> {
140 let builder = Self::builder(CircuitBuilderStage::Mock, k);
141 let mut builder = Self::populate(builder, dsl_operations, witness, true);
142
143 let public_instances = builder.instances();
144 println!("Public instances: {:?}", public_instances);
145
146 builder.calculate_params(Some(20));
147
148 MockProver::run(k as u32, &builder, public_instances.clone())
149 .unwrap()
150 .assert_satisfied();
151 public_instances
152 }
153
154 pub fn keygen<
156 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
157 >(
158 params: &Halo2Params,
159 dsl_operations: DslOperations<C>,
160 witness: Witness<C>,
161 ) -> Halo2ProvingPinning {
162 let k = params.k() as usize;
163 let builder = Self::builder(CircuitBuilderStage::Keygen, k);
164 let mut builder = Self::populate(builder, dsl_operations, witness, true);
165 builder.calculate_params(Some(20));
166
167 #[cfg(feature = "metrics")]
168 let start = std::time::Instant::now();
169 let pk = keygen_pk2(params, &builder, false).unwrap();
170 #[cfg(feature = "metrics")]
171 metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64);
172 let break_points = builder.break_points();
173
174 let config_params = builder.config_params.clone();
175 let num_pvs = builder
176 .assigned_instances
177 .iter()
178 .map(|x| x.len())
179 .collect_vec();
180
181 Halo2ProvingPinning {
182 pk,
183 metadata: Halo2ProvingMetadata {
184 config_params,
185 break_points,
186 num_pvs,
187 },
188 }
189 }
190
191 pub fn prove<
192 C: Config<N = Bn254Fr, F = BabyBear, EF = BinomialExtensionField<BabyBear, 4>> + Debug,
193 >(
194 params: &Halo2Params,
195 config_params: BaseCircuitParams,
196 break_points: MultiPhaseThreadBreakPoints,
197 pk: &ProvingKey<G1Affine>,
198 dsl_operations: DslOperations<C>,
199 witness: Witness<C>,
200 profiling: bool,
201 ) -> Snark {
202 let k = config_params.k;
203 #[cfg(feature = "metrics")]
204 let start = std::time::Instant::now();
205 let builder = Self::builder(CircuitBuilderStage::Prover, k)
206 .use_params(config_params)
207 .use_break_points(break_points);
208 let builder = Self::populate(builder, dsl_operations, witness, profiling);
209 #[cfg(feature = "metrics")]
210 {
211 let stats = builder.statistics();
212 let total_advices: usize = stats.gate.total_advice_per_phase.into_iter().sum();
213 let total_lookups: usize = stats.total_lookup_advice_per_phase.into_iter().sum();
214 let total_cell = total_advices + total_lookups + stats.gate.total_fixed;
215 metrics::counter!("main_cells_used").absolute(total_cell as u64);
216 }
217 let snark = gen_snark_shplonk(params, pk, builder, None::<&str>);
218
219 #[cfg(feature = "metrics")]
220 metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64);
221
222 snark
223 }
224}
225
226#[derive(Serialize, Deserialize)]
227struct SerializedHalo2ProvingPinning {
228 pk_bytes: Vec<u8>,
229 metadata: Halo2ProvingMetadata,
230}
231
232impl Serialize for Halo2ProvingPinning {
233 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
234 where
235 S: Serializer,
236 {
237 let serialized = SerializedHalo2ProvingPinning {
238 pk_bytes: self.pk.to_bytes(SerdeFormat::RawBytes),
239 metadata: self.metadata.clone(),
240 };
241 serialized.serialize(serializer)
242 }
243}
244
245impl<'de> Deserialize<'de> for Halo2ProvingPinning {
246 fn deserialize<D>(deserializer: D) -> Result<Halo2ProvingPinning, D::Error>
247 where
248 D: Deserializer<'de>,
249 {
250 let SerializedHalo2ProvingPinning { pk_bytes, metadata } =
251 SerializedHalo2ProvingPinning::deserialize(deserializer)?;
252
253 let pk = ProvingKey::<G1Affine>::from_bytes::<BaseCircuitBuilder<Fr>>(
254 &pk_bytes,
255 SerdeFormat::RawBytes,
256 metadata.config_params.clone(),
257 )
258 .map_err(|e| de::Error::custom(format!("invalid bytes for proving key: {}", e)))?;
259
260 Ok(Halo2ProvingPinning { pk, metadata })
261 }
262}