1use std::sync::Arc;
2
3use openvm_circuit::arch::ContinuationVmProof;
4use openvm_continuations::verifier::{
5 internal::types::InternalVmVerifierInput, leaf::types::LeafVmVerifierInput,
6 root::types::RootVmVerifierInput,
7};
8use openvm_native_circuit::NativeConfig;
9use openvm_native_recursion::hints::Hintable;
10use openvm_stark_sdk::{
11 config::baby_bear_poseidon2::BabyBearPoseidon2Engine, openvm_stark_backend::proof::Proof,
12};
13use tracing::info_span;
14
15use crate::{
16 keygen::AggStarkProvingKey,
17 prover::{
18 vm::{local::VmLocalProver, SingleSegmentVmProver},
19 RootVerifierLocalProver,
20 },
21 NonRootCommittedExe, RootSC, F, SC,
22};
23
/// Default `num_children` for the leaf stage: how many app segment proofs are grouped
/// into each leaf verifier input.
pub const DEFAULT_NUM_CHILDREN_LEAF: usize = 1;
/// Default chunk size for each internal aggregation layer.
const DEFAULT_NUM_CHILDREN_INTERNAL: usize = 2;
/// Default cap on extra internal "wrapper" layers added around a single remaining proof
/// whose root-verifier AIR heights are still too large.
const DEFAULT_MAX_INTERNAL_WRAPPER_LAYERS: usize = 4;
27
28pub struct AggStarkProver {
29 leaf_prover: VmLocalProver<SC, NativeConfig, BabyBearPoseidon2Engine>,
30 leaf_controller: LeafProvingController,
31
32 internal_prover: VmLocalProver<SC, NativeConfig, BabyBearPoseidon2Engine>,
33 root_prover: RootVerifierLocalProver,
34
35 pub num_children_internal: usize,
36 pub max_internal_wrapper_layers: usize,
37}
38
/// Configuration for producing leaf proofs from a continuation (app) proof.
pub struct LeafProvingController {
    /// Number of app segment proofs grouped per leaf input. It is passed as the chunk size to
    /// `LeafVmVerifierInput::chunk_continuation_vm_proof`.
    pub num_children: usize,
}
43
44impl AggStarkProver {
45 pub fn new(
46 agg_stark_pk: AggStarkProvingKey,
47 leaf_committed_exe: Arc<NonRootCommittedExe>,
48 ) -> Self {
49 let leaf_prover = VmLocalProver::<SC, NativeConfig, BabyBearPoseidon2Engine>::new(
50 agg_stark_pk.leaf_vm_pk,
51 leaf_committed_exe,
52 );
53 let leaf_controller = LeafProvingController {
54 num_children: DEFAULT_NUM_CHILDREN_LEAF,
55 };
56 let internal_prover = VmLocalProver::<SC, NativeConfig, BabyBearPoseidon2Engine>::new(
57 agg_stark_pk.internal_vm_pk,
58 agg_stark_pk.internal_committed_exe,
59 );
60 let root_prover = RootVerifierLocalProver::new(agg_stark_pk.root_verifier_pk);
61 Self {
62 leaf_prover,
63 leaf_controller,
64 internal_prover,
65 root_prover,
66 num_children_internal: DEFAULT_NUM_CHILDREN_INTERNAL,
67 max_internal_wrapper_layers: DEFAULT_MAX_INTERNAL_WRAPPER_LAYERS,
68 }
69 }
70
71 pub fn with_num_children_leaf(mut self, num_children_leaf: usize) -> Self {
72 self.leaf_controller.num_children = num_children_leaf;
73 self
74 }
75
76 pub fn with_num_children_internal(mut self, num_children_internal: usize) -> Self {
77 self.num_children_internal = num_children_internal;
78 self
79 }
80
81 pub fn with_max_internal_wrapper_layers(mut self, max_internal_wrapper_layers: usize) -> Self {
82 self.max_internal_wrapper_layers = max_internal_wrapper_layers;
83 self
84 }
85
86 pub fn generate_agg_proof(&self, app_proofs: ContinuationVmProof<SC>) -> Proof<RootSC> {
88 let root_verifier_input = self.generate_root_verifier_input(app_proofs);
89 self.generate_root_proof_impl(root_verifier_input)
90 }
91
92 pub fn generate_root_verifier_input(
93 &self,
94 app_proofs: ContinuationVmProof<SC>,
95 ) -> RootVmVerifierInput<SC> {
96 let leaf_proofs = self
97 .leaf_controller
98 .generate_proof(&self.leaf_prover, &app_proofs);
99 let public_values = app_proofs.user_public_values.public_values;
100 let internal_proof = self.generate_internal_proof_impl(leaf_proofs, &public_values);
101 RootVmVerifierInput {
102 proofs: vec![internal_proof],
103 public_values,
104 }
105 }
106
107 fn generate_internal_proof_impl(
108 &self,
109 leaf_proofs: Vec<Proof<SC>>,
110 public_values: &[F],
111 ) -> Proof<SC> {
112 let mut internal_node_idx = -1;
113 let mut internal_node_height = 0;
114 let mut proofs = leaf_proofs;
115 let mut wrapper_layers = 0;
116 loop {
117 if proofs.len() == 1 {
118 let actual_air_heights =
119 self.root_prover
120 .execute_for_air_heights(RootVmVerifierInput {
121 proofs: vec![proofs[0].clone()],
122 public_values: public_values.to_vec(),
123 });
124 if heights_le(
126 &actual_air_heights,
127 &self.root_prover.root_verifier_pk.air_heights,
128 ) {
129 break;
130 }
131 if wrapper_layers >= self.max_internal_wrapper_layers {
132 panic!("The heights of the root verifier still exceed the required heights after {} wrapper layers", self.max_internal_wrapper_layers);
133 }
134 wrapper_layers += 1;
135 }
136 let internal_inputs = InternalVmVerifierInput::chunk_leaf_or_internal_proofs(
137 self.internal_prover
138 .committed_exe
139 .get_program_commit()
140 .into(),
141 &proofs,
142 self.num_children_internal,
143 );
144 proofs = info_span!(
145 "agg_layer",
146 group = format!("internal.{internal_node_height}")
147 )
148 .in_scope(|| {
149 #[cfg(feature = "bench-metrics")]
150 {
151 metrics::counter!("fri.log_blowup")
152 .absolute(self.internal_prover.fri_params().log_blowup as u64);
153 metrics::counter!("num_children").absolute(self.num_children_internal as u64);
154 }
155 internal_inputs
156 .into_iter()
157 .map(|input| {
158 internal_node_idx += 1;
159 info_span!("single_internal_agg", idx = internal_node_idx,).in_scope(|| {
160 SingleSegmentVmProver::prove(&self.internal_prover, input.write())
161 })
162 })
163 .collect()
164 });
165 internal_node_height += 1;
166 }
167 proofs.pop().unwrap()
168 }
169
170 fn generate_root_proof_impl(&self, root_input: RootVmVerifierInput<SC>) -> Proof<RootSC> {
171 info_span!("agg_layer", group = "root", idx = 0).in_scope(|| {
172 let input = root_input.write();
173 #[cfg(feature = "bench-metrics")]
174 metrics::counter!("fri.log_blowup")
175 .absolute(self.root_prover.fri_params().log_blowup as u64);
176 SingleSegmentVmProver::prove(&self.root_prover, input)
177 })
178 }
179}
180
181impl LeafProvingController {
182 pub fn with_num_children(mut self, num_children_leaf: usize) -> Self {
183 self.num_children = num_children_leaf;
184 self
185 }
186
187 pub fn generate_proof(
188 &self,
189 prover: &VmLocalProver<SC, NativeConfig, BabyBearPoseidon2Engine>,
190 app_proofs: &ContinuationVmProof<SC>,
191 ) -> Vec<Proof<SC>> {
192 info_span!("agg_layer", group = "leaf").in_scope(|| {
193 #[cfg(feature = "bench-metrics")]
194 {
195 metrics::counter!("fri.log_blowup").absolute(prover.fri_params().log_blowup as u64);
196 metrics::counter!("num_children").absolute(self.num_children as u64);
197 }
198 let leaf_inputs =
199 LeafVmVerifierInput::chunk_continuation_vm_proof(app_proofs, self.num_children);
200 tracing::info!("num_leaf_proofs={}", leaf_inputs.len());
201 leaf_inputs
202 .into_iter()
203 .enumerate()
204 .map(|(leaf_node_idx, input)| {
205 info_span!("single_leaf_agg", idx = leaf_node_idx)
206 .in_scope(|| SingleSegmentVmProver::prove(prover, input.write_to_stream()))
207 })
208 .collect::<Vec<_>>()
209 })
210 }
211}
212
/// Returns `true` when each entry of `a` is at most the entry at the same index in `b`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
fn heights_le(a: &[usize], b: &[usize]) -> bool {
    assert_eq!(a.len(), b.len());
    // Element-wise `<=` holds exactly when no position has `a` strictly above `b`.
    !a.iter().zip(b).any(|(actual, limit)| actual > limit)
}
216}