1use std::sync::Arc;
2
3use openvm_circuit::arch::{
4 instructions::exe::VmExe, ContinuationVmProof, PreflightExecutor, SingleSegmentVmProver,
5 VirtualMachineError, VmBuilder, VmExecutionConfig, VmInstance,
6};
7#[cfg(feature = "evm-prove")]
8use openvm_continuations::verifier::root::types::RootVmVerifierInput;
9use openvm_continuations::verifier::{
10 internal::types::{InternalVmVerifierInput, VmStarkProof},
11 leaf::types::LeafVmVerifierInput,
12};
13use openvm_native_circuit::{NativeConfig, NATIVE_MAX_TRACE_HEIGHTS};
14use openvm_native_recursion::hints::Hintable;
15use openvm_stark_sdk::{engine::StarkFriEngine, openvm_stark_backend::proof::Proof};
16use tracing::{info_span, instrument};
17
18use crate::{
19 config::AggregationTreeConfig, keygen::AggProvingKey, prover::vm::new_local_prover,
20 util::check_max_constraint_degrees, F, SC,
21};
22#[cfg(feature = "evm-prove")]
23use crate::{prover::RootVerifierLocalProver, RootSC};
24
25pub struct AggStarkProver<E, NativeBuilder>
26where
27 E: StarkFriEngine<SC = SC>,
28 NativeBuilder: VmBuilder<E, VmConfig = NativeConfig>,
29{
30 leaf_prover: VmInstance<E, NativeBuilder>,
31 leaf_controller: LeafProvingController,
32
33 pub internal_prover: VmInstance<E, NativeBuilder>,
34 #[cfg(feature = "evm-prove")]
35 root_prover: RootVerifierLocalProver,
36 pub num_children_internal: usize,
37 pub max_internal_wrapper_layers: usize,
38}
39
/// Configuration for the leaf layer of the aggregation tree.
#[derive(Clone, Debug)]
pub struct LeafProvingController {
    /// Number of App VM proofs handed to each leaf verifier input. It is used as the chunking
    /// parameter for `LeafVmVerifierInput::chunk_continuation_vm_proof`.
    pub num_children: usize,
}
44
45impl<E, NativeBuilder> AggStarkProver<E, NativeBuilder>
46where
47 E: StarkFriEngine<SC = SC>,
48 NativeBuilder: VmBuilder<E, VmConfig = NativeConfig> + Clone,
49 <NativeConfig as VmExecutionConfig<F>>::Executor:
50 PreflightExecutor<F, <NativeBuilder as VmBuilder<E>>::RecordArena>,
51{
52 pub fn new(
53 native_builder: NativeBuilder,
54 agg_pk: &AggProvingKey,
55 leaf_verifier_exe: Arc<VmExe<F>>,
56 tree_config: AggregationTreeConfig,
57 ) -> Result<Self, VirtualMachineError> {
58 let leaf_prover = new_local_prover(
59 native_builder.clone(),
60 &agg_pk.leaf_vm_pk,
61 leaf_verifier_exe,
62 )?;
63 let internal_prover = new_local_prover(
64 native_builder,
65 &agg_pk.internal_vm_pk,
66 agg_pk.internal_committed_exe.exe.clone(),
67 )?;
68 #[cfg(feature = "evm-prove")]
69 let root_prover = RootVerifierLocalProver::new(&agg_pk.root_verifier_pk)?;
70 Ok(Self::new_from_instances(
71 leaf_prover,
72 internal_prover,
73 #[cfg(feature = "evm-prove")]
74 root_prover,
75 tree_config,
76 ))
77 }
78
79 pub fn new_from_instances(
80 leaf_instance: VmInstance<E, NativeBuilder>,
81 internal_instance: VmInstance<E, NativeBuilder>,
82 #[cfg(feature = "evm-prove")] root_instance: RootVerifierLocalProver,
83 tree_config: AggregationTreeConfig,
84 ) -> Self {
85 let leaf_controller = LeafProvingController {
86 num_children: tree_config.num_children_leaf,
87 };
88 Self {
89 leaf_prover: leaf_instance,
90 leaf_controller,
91 internal_prover: internal_instance,
92 #[cfg(feature = "evm-prove")]
93 root_prover: root_instance,
94 num_children_internal: tree_config.num_children_internal,
95 max_internal_wrapper_layers: tree_config.max_internal_wrapper_layers,
96 }
97 }
98
99 pub fn with_num_children_leaf(mut self, num_children_leaf: usize) -> Self {
100 self.leaf_controller.num_children = num_children_leaf;
101 self
102 }
103
104 pub fn with_num_children_internal(mut self, num_children_internal: usize) -> Self {
105 self.num_children_internal = num_children_internal;
106 self
107 }
108
109 pub fn with_max_internal_wrapper_layers(mut self, max_internal_wrapper_layers: usize) -> Self {
110 self.max_internal_wrapper_layers = max_internal_wrapper_layers;
111 self
112 }
113
114 #[cfg(feature = "evm-prove")]
116 pub fn generate_root_proof(
117 &mut self,
118 app_proofs: ContinuationVmProof<SC>,
119 ) -> Result<Proof<RootSC>, VirtualMachineError> {
120 let root_verifier_input = self.generate_root_verifier_input(app_proofs)?;
121 self.generate_root_proof_impl(root_verifier_input)
122 }
123
124 pub fn generate_leaf_proofs(
125 &mut self,
126 app_proofs: &ContinuationVmProof<SC>,
127 ) -> Result<Vec<Proof<SC>>, VirtualMachineError> {
128 check_max_constraint_degrees(
129 self.leaf_prover.vm.config().as_ref(),
130 &self.leaf_prover.vm.engine.fri_params(),
131 );
132 self.leaf_controller
133 .generate_proof(&mut self.leaf_prover, app_proofs)
134 }
135
136 #[cfg(feature = "evm-prove")]
138 pub fn generate_root_verifier_input(
139 &mut self,
140 app_proofs: ContinuationVmProof<SC>,
141 ) -> Result<RootVmVerifierInput<SC>, VirtualMachineError> {
142 let leaf_proofs = self.generate_leaf_proofs(&app_proofs)?;
143 let public_values = app_proofs.user_public_values.public_values;
144 let e2e_stark_proof = self.aggregate_leaf_proofs(leaf_proofs, public_values)?;
145 let wrapped_stark_proof = self.wrap_e2e_stark_proof(e2e_stark_proof)?;
146 Ok(wrapped_stark_proof)
147 }
148
149 pub fn aggregate_leaf_proofs(
150 &mut self,
151 leaf_proofs: Vec<Proof<SC>>,
152 public_values: Vec<F>,
153 ) -> Result<VmStarkProof<SC>, VirtualMachineError> {
154 check_max_constraint_degrees(
155 self.internal_prover.vm.config().as_ref(),
156 &self.internal_prover.vm.engine.fri_params(),
157 );
158
159 let mut internal_node_idx = -1;
160 let mut internal_node_height = 0;
161 let mut proofs = leaf_proofs;
162 while proofs.len() > 1 || internal_node_height == 0 {
165 let internal_inputs = InternalVmVerifierInput::chunk_leaf_or_internal_proofs(
166 (*self.internal_prover.program_commitment()).into(),
167 &proofs,
168 self.num_children_internal,
169 );
170 proofs = info_span!(
171 "agg_layer",
172 group = format!("internal.{internal_node_height}")
173 )
174 .in_scope(|| {
175 #[cfg(feature = "metrics")]
176 {
177 metrics::counter!("fri.log_blowup")
178 .absolute(self.internal_prover.vm.engine.fri_params().log_blowup as u64);
179 metrics::counter!("num_children").absolute(self.num_children_internal as u64);
180 }
181 internal_inputs
182 .into_iter()
183 .map(|input| {
184 internal_node_idx += 1;
185 info_span!("single_internal_agg", idx = internal_node_idx,).in_scope(|| {
186 SingleSegmentVmProver::prove(
187 &mut self.internal_prover,
188 input.write(),
189 NATIVE_MAX_TRACE_HEIGHTS,
190 )
191 })
192 })
193 .collect::<Result<Vec<_>, _>>()
194 })?;
195 internal_node_height += 1;
196 }
197 let proof = proofs.pop().unwrap();
198 Ok(VmStarkProof {
199 inner: proof,
200 user_public_values: public_values,
201 })
202 }
203
204 #[cfg(feature = "evm-prove")]
206 fn wrap_e2e_stark_proof(
207 &mut self,
208 e2e_stark_proof: VmStarkProof<SC>,
209 ) -> Result<RootVmVerifierInput<SC>, VirtualMachineError> {
210 let internal_commit = (*self.internal_prover.program_commitment()).into();
211 let internal_prover = &mut self.internal_prover;
212 let root_prover = &mut self.root_prover;
213 let max_internal_wrapper_layers = self.max_internal_wrapper_layers;
214 fn heights_le(a: &[u32], b: &[u32]) -> bool {
215 assert_eq!(a.len(), b.len());
216 a.iter().zip(b.iter()).all(|(a, b)| a <= b)
217 }
218
219 let VmStarkProof {
220 inner: mut proof,
221 user_public_values,
222 } = e2e_stark_proof;
223 let mut wrapper_layers = 0;
224 loop {
225 let input = RootVmVerifierInput {
226 proofs: vec![proof.clone()],
227 public_values: user_public_values.clone(),
228 };
229 let actual_air_heights = root_prover.execute_for_air_heights(input)?;
230 if heights_le(&actual_air_heights, root_prover.fixed_air_heights()) {
232 break;
233 }
234 if wrapper_layers >= max_internal_wrapper_layers {
235 panic!("The heights of the root verifier still exceed the required heights after {} wrapper layers", max_internal_wrapper_layers);
236 }
237 wrapper_layers += 1;
238 let input = InternalVmVerifierInput {
239 self_program_commit: internal_commit,
240 proofs: vec![proof.clone()],
241 };
242 proof = info_span!(
243 "wrapper_layer",
244 group = format!("internal_wrapper.{wrapper_layers}")
245 )
246 .in_scope(|| {
247 #[cfg(feature = "metrics")]
248 {
249 metrics::counter!("fri.log_blowup")
250 .absolute(internal_prover.vm.engine.fri_params().log_blowup as u64);
251 }
252 SingleSegmentVmProver::prove(
253 internal_prover,
254 input.write(),
255 NATIVE_MAX_TRACE_HEIGHTS,
256 )
257 })?;
258 }
259 Ok(RootVmVerifierInput {
260 proofs: vec![proof],
261 public_values: user_public_values,
262 })
263 }
264
265 #[cfg(feature = "evm-prove")]
266 #[instrument(name = "agg_layer", skip_all, fields(group = "root", idx = 0))]
267 fn generate_root_proof_impl(
268 &mut self,
269 root_input: RootVmVerifierInput<SC>,
270 ) -> Result<Proof<RootSC>, VirtualMachineError> {
271 check_max_constraint_degrees(
272 self.root_prover.vm_config().as_ref(),
273 self.root_prover.fri_params(),
274 );
275 let input = root_input.write();
276 #[cfg(feature = "metrics")]
277 metrics::counter!("fri.log_blowup")
278 .absolute(self.root_prover.fri_params().log_blowup as u64);
279 SingleSegmentVmProver::prove(&mut self.root_prover, input, NATIVE_MAX_TRACE_HEIGHTS)
280 }
281}
282
283impl LeafProvingController {
284 pub fn with_num_children(mut self, num_children_leaf: usize) -> Self {
285 self.num_children = num_children_leaf;
286 self
287 }
288
289 #[instrument(name = "agg_layer", skip_all, fields(group = "leaf"))]
290 pub fn generate_proof<E, NativeBuilder>(
291 &self,
292 prover: &mut VmInstance<E, NativeBuilder>,
293 app_proofs: &ContinuationVmProof<SC>,
294 ) -> Result<Vec<Proof<SC>>, VirtualMachineError>
295 where
296 E: StarkFriEngine<SC = SC>,
297 NativeBuilder: VmBuilder<E, VmConfig = NativeConfig>,
298 <NativeConfig as VmExecutionConfig<F>>::Executor:
299 PreflightExecutor<F, <NativeBuilder as VmBuilder<E>>::RecordArena>,
300 {
301 #[cfg(feature = "metrics")]
302 {
303 metrics::counter!("fri.log_blowup")
304 .absolute(prover.vm.engine.fri_params().log_blowup as u64);
305 metrics::counter!("num_children").absolute(self.num_children as u64);
306 }
307 let leaf_inputs =
308 LeafVmVerifierInput::chunk_continuation_vm_proof(app_proofs, self.num_children);
309 tracing::info!("num_leaf_proofs={}", leaf_inputs.len());
310 leaf_inputs
311 .into_iter()
312 .enumerate()
313 .map(|(leaf_node_idx, input)| {
314 info_span!("single_leaf_agg", idx = leaf_node_idx).in_scope(|| {
315 SingleSegmentVmProver::prove(
316 prover,
317 input.write_to_stream(),
318 NATIVE_MAX_TRACE_HEIGHTS,
319 )
320 })
321 })
322 .collect()
323 }
324}