openvm_sdk/prover/
agg.rs

1use std::sync::Arc;
2
3use openvm_circuit::arch::ContinuationVmProof;
4use openvm_continuations::verifier::{
5    internal::types::InternalVmVerifierInput, leaf::types::LeafVmVerifierInput,
6    root::types::RootVmVerifierInput,
7};
8use openvm_native_circuit::NativeConfig;
9use openvm_native_recursion::hints::Hintable;
10use openvm_stark_sdk::{
11    config::baby_bear_poseidon2::BabyBearPoseidon2Engine, openvm_stark_backend::proof::Proof,
12};
13use tracing::info_span;
14
15use crate::{
16    keygen::AggStarkProvingKey,
17    prover::{
18        vm::{local::VmLocalProver, SingleSegmentVmProver},
19        RootVerifierLocalProver,
20    },
21    NonRootCommittedExe, RootSC, F, SC,
22};
23
/// Default number of App VM proofs verified by a single leaf proof.
pub const DEFAULT_NUM_CHILDREN_LEAF: usize = 1;
/// Default number of child proofs verified by a single internal proof.
const DEFAULT_NUM_CHILDREN_INTERNAL: usize = 2;
/// Default cap on the extra internal layers stacked on a lone internal proof so that the
/// root verifier fits its fixed AIR heights.
const DEFAULT_MAX_INTERNAL_WRAPPER_LAYERS: usize = 4;
27
28pub struct AggStarkProver {
29    leaf_prover: VmLocalProver<SC, NativeConfig, BabyBearPoseidon2Engine>,
30    leaf_controller: LeafProvingController,
31
32    internal_prover: VmLocalProver<SC, NativeConfig, BabyBearPoseidon2Engine>,
33    root_prover: RootVerifierLocalProver,
34
35    pub num_children_internal: usize,
36    pub max_internal_wrapper_layers: usize,
37}
38
/// Controls how the App VM proofs of a continuation proof are grouped into leaf proofs.
pub struct LeafProvingController {
    /// Maximum number of App VM proofs aggregated by each leaf proof: every leaf proof
    /// verifies at most `num_children` of them.
    pub num_children: usize,
}
43
44impl AggStarkProver {
45    pub fn new(
46        agg_stark_pk: AggStarkProvingKey,
47        leaf_committed_exe: Arc<NonRootCommittedExe>,
48    ) -> Self {
49        let leaf_prover = VmLocalProver::<SC, NativeConfig, BabyBearPoseidon2Engine>::new(
50            agg_stark_pk.leaf_vm_pk,
51            leaf_committed_exe,
52        );
53        let leaf_controller = LeafProvingController {
54            num_children: DEFAULT_NUM_CHILDREN_LEAF,
55        };
56        let internal_prover = VmLocalProver::<SC, NativeConfig, BabyBearPoseidon2Engine>::new(
57            agg_stark_pk.internal_vm_pk,
58            agg_stark_pk.internal_committed_exe,
59        );
60        let root_prover = RootVerifierLocalProver::new(agg_stark_pk.root_verifier_pk);
61        Self {
62            leaf_prover,
63            leaf_controller,
64            internal_prover,
65            root_prover,
66            num_children_internal: DEFAULT_NUM_CHILDREN_INTERNAL,
67            max_internal_wrapper_layers: DEFAULT_MAX_INTERNAL_WRAPPER_LAYERS,
68        }
69    }
70
71    pub fn with_num_children_leaf(mut self, num_children_leaf: usize) -> Self {
72        self.leaf_controller.num_children = num_children_leaf;
73        self
74    }
75
76    pub fn with_num_children_internal(mut self, num_children_internal: usize) -> Self {
77        self.num_children_internal = num_children_internal;
78        self
79    }
80
81    pub fn with_max_internal_wrapper_layers(mut self, max_internal_wrapper_layers: usize) -> Self {
82        self.max_internal_wrapper_layers = max_internal_wrapper_layers;
83        self
84    }
85
86    /// Generate a proof to aggregate app proofs.
87    pub fn generate_agg_proof(&self, app_proofs: ContinuationVmProof<SC>) -> Proof<RootSC> {
88        let root_verifier_input = self.generate_root_verifier_input(app_proofs);
89        self.generate_root_proof_impl(root_verifier_input)
90    }
91
92    pub fn generate_root_verifier_input(
93        &self,
94        app_proofs: ContinuationVmProof<SC>,
95    ) -> RootVmVerifierInput<SC> {
96        let leaf_proofs = self
97            .leaf_controller
98            .generate_proof(&self.leaf_prover, &app_proofs);
99        let public_values = app_proofs.user_public_values.public_values;
100        let internal_proof = self.generate_internal_proof_impl(leaf_proofs, &public_values);
101        RootVmVerifierInput {
102            proofs: vec![internal_proof],
103            public_values,
104        }
105    }
106
107    fn generate_internal_proof_impl(
108        &self,
109        leaf_proofs: Vec<Proof<SC>>,
110        public_values: &[F],
111    ) -> Proof<SC> {
112        let mut internal_node_idx = -1;
113        let mut internal_node_height = 0;
114        let mut proofs = leaf_proofs;
115        let mut wrapper_layers = 0;
116        loop {
117            if proofs.len() == 1 {
118                let actual_air_heights =
119                    self.root_prover
120                        .execute_for_air_heights(RootVmVerifierInput {
121                            proofs: vec![proofs[0].clone()],
122                            public_values: public_values.to_vec(),
123                        });
124                // Root verifier can handle the internal proof. We can stop here.
125                if heights_le(
126                    &actual_air_heights,
127                    &self.root_prover.root_verifier_pk.air_heights,
128                ) {
129                    break;
130                }
131                if wrapper_layers >= self.max_internal_wrapper_layers {
132                    panic!("The heights of the root verifier still exceed the required heights after {} wrapper layers", self.max_internal_wrapper_layers);
133                }
134                wrapper_layers += 1;
135            }
136            let internal_inputs = InternalVmVerifierInput::chunk_leaf_or_internal_proofs(
137                self.internal_prover
138                    .committed_exe
139                    .get_program_commit()
140                    .into(),
141                &proofs,
142                self.num_children_internal,
143            );
144            proofs = info_span!(
145                "agg_layer",
146                group = format!("internal.{internal_node_height}")
147            )
148            .in_scope(|| {
149                #[cfg(feature = "bench-metrics")]
150                {
151                    metrics::counter!("fri.log_blowup")
152                        .absolute(self.internal_prover.fri_params().log_blowup as u64);
153                    metrics::counter!("num_children").absolute(self.num_children_internal as u64);
154                }
155                internal_inputs
156                    .into_iter()
157                    .map(|input| {
158                        internal_node_idx += 1;
159                        info_span!("single_internal_agg", idx = internal_node_idx,).in_scope(|| {
160                            SingleSegmentVmProver::prove(&self.internal_prover, input.write())
161                        })
162                    })
163                    .collect()
164            });
165            internal_node_height += 1;
166        }
167        proofs.pop().unwrap()
168    }
169
170    fn generate_root_proof_impl(&self, root_input: RootVmVerifierInput<SC>) -> Proof<RootSC> {
171        info_span!("agg_layer", group = "root", idx = 0).in_scope(|| {
172            let input = root_input.write();
173            #[cfg(feature = "bench-metrics")]
174            metrics::counter!("fri.log_blowup")
175                .absolute(self.root_prover.fri_params().log_blowup as u64);
176            SingleSegmentVmProver::prove(&self.root_prover, input)
177        })
178    }
179}
180
181impl LeafProvingController {
182    pub fn with_num_children(mut self, num_children_leaf: usize) -> Self {
183        self.num_children = num_children_leaf;
184        self
185    }
186
187    pub fn generate_proof(
188        &self,
189        prover: &VmLocalProver<SC, NativeConfig, BabyBearPoseidon2Engine>,
190        app_proofs: &ContinuationVmProof<SC>,
191    ) -> Vec<Proof<SC>> {
192        info_span!("agg_layer", group = "leaf").in_scope(|| {
193            #[cfg(feature = "bench-metrics")]
194            {
195                metrics::counter!("fri.log_blowup").absolute(prover.fri_params().log_blowup as u64);
196                metrics::counter!("num_children").absolute(self.num_children as u64);
197            }
198            let leaf_inputs =
199                LeafVmVerifierInput::chunk_continuation_vm_proof(app_proofs, self.num_children);
200            tracing::info!("num_leaf_proofs={}", leaf_inputs.len());
201            leaf_inputs
202                .into_iter()
203                .enumerate()
204                .map(|(leaf_node_idx, input)| {
205                    info_span!("single_leaf_agg", idx = leaf_node_idx)
206                        .in_scope(|| SingleSegmentVmProver::prove(prover, input.write_to_stream()))
207                })
208                .collect::<Vec<_>>()
209        })
210    }
211}
212
/// Returns `true` when every height in `a` is at most the height at the same index in `b`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
fn heights_le(a: &[usize], b: &[usize]) -> bool {
    assert_eq!(a.len(), b.len());
    // "All pairs satisfy <=" is equivalent to "no pair has a strictly larger left side".
    !a.iter().zip(b).any(|(actual, limit)| actual > limit)
}