openvm_sdk/prover/
agg.rs

1use std::sync::Arc;
2
3use openvm_circuit::arch::{
4    instructions::exe::VmExe, ContinuationVmProof, PreflightExecutor, SingleSegmentVmProver,
5    VirtualMachineError, VmBuilder, VmExecutionConfig, VmInstance,
6};
7#[cfg(feature = "evm-prove")]
8use openvm_continuations::verifier::root::types::RootVmVerifierInput;
9use openvm_continuations::verifier::{
10    internal::types::{InternalVmVerifierInput, VmStarkProof},
11    leaf::types::LeafVmVerifierInput,
12};
13use openvm_native_circuit::{NativeConfig, NATIVE_MAX_TRACE_HEIGHTS};
14use openvm_native_recursion::hints::Hintable;
15use openvm_stark_sdk::{engine::StarkFriEngine, openvm_stark_backend::proof::Proof};
16use tracing::{info_span, instrument};
17
18use crate::{
19    config::AggregationTreeConfig, keygen::AggProvingKey, prover::vm::new_local_prover,
20    util::check_max_constraint_degrees, F, SC,
21};
22#[cfg(feature = "evm-prove")]
23use crate::{prover::RootVerifierLocalProver, RootSC};
24
25pub struct AggStarkProver<E, NativeBuilder>
26where
27    E: StarkFriEngine<SC = SC>,
28    NativeBuilder: VmBuilder<E, VmConfig = NativeConfig>,
29{
30    leaf_prover: VmInstance<E, NativeBuilder>,
31    leaf_controller: LeafProvingController,
32
33    pub internal_prover: VmInstance<E, NativeBuilder>,
34    #[cfg(feature = "evm-prove")]
35    root_prover: RootVerifierLocalProver,
36    pub num_children_internal: usize,
37    pub max_internal_wrapper_layers: usize,
38}
39
/// Settings that drive proving of the leaf layer of the aggregation tree.
pub struct LeafProvingController {
    /// Each leaf proof aggregates `<= num_children` App VM proofs.
    pub num_children: usize,
}
44
45impl<E, NativeBuilder> AggStarkProver<E, NativeBuilder>
46where
47    E: StarkFriEngine<SC = SC>,
48    NativeBuilder: VmBuilder<E, VmConfig = NativeConfig> + Clone,
49    <NativeConfig as VmExecutionConfig<F>>::Executor:
50        PreflightExecutor<F, <NativeBuilder as VmBuilder<E>>::RecordArena>,
51{
52    pub fn new(
53        native_builder: NativeBuilder,
54        agg_pk: &AggProvingKey,
55        leaf_verifier_exe: Arc<VmExe<F>>,
56        tree_config: AggregationTreeConfig,
57    ) -> Result<Self, VirtualMachineError> {
58        let leaf_prover = new_local_prover(
59            native_builder.clone(),
60            &agg_pk.leaf_vm_pk,
61            leaf_verifier_exe,
62        )?;
63        let internal_prover = new_local_prover(
64            native_builder,
65            &agg_pk.internal_vm_pk,
66            agg_pk.internal_committed_exe.exe.clone(),
67        )?;
68        #[cfg(feature = "evm-prove")]
69        let root_prover = RootVerifierLocalProver::new(&agg_pk.root_verifier_pk)?;
70        Ok(Self::new_from_instances(
71            leaf_prover,
72            internal_prover,
73            #[cfg(feature = "evm-prove")]
74            root_prover,
75            tree_config,
76        ))
77    }
78
79    pub fn new_from_instances(
80        leaf_instance: VmInstance<E, NativeBuilder>,
81        internal_instance: VmInstance<E, NativeBuilder>,
82        #[cfg(feature = "evm-prove")] root_instance: RootVerifierLocalProver,
83        tree_config: AggregationTreeConfig,
84    ) -> Self {
85        let leaf_controller = LeafProvingController {
86            num_children: tree_config.num_children_leaf,
87        };
88        Self {
89            leaf_prover: leaf_instance,
90            leaf_controller,
91            internal_prover: internal_instance,
92            #[cfg(feature = "evm-prove")]
93            root_prover: root_instance,
94            num_children_internal: tree_config.num_children_internal,
95            max_internal_wrapper_layers: tree_config.max_internal_wrapper_layers,
96        }
97    }
98
99    pub fn with_num_children_leaf(mut self, num_children_leaf: usize) -> Self {
100        self.leaf_controller.num_children = num_children_leaf;
101        self
102    }
103
104    pub fn with_num_children_internal(mut self, num_children_internal: usize) -> Self {
105        self.num_children_internal = num_children_internal;
106        self
107    }
108
109    pub fn with_max_internal_wrapper_layers(mut self, max_internal_wrapper_layers: usize) -> Self {
110        self.max_internal_wrapper_layers = max_internal_wrapper_layers;
111        self
112    }
113
114    /// Generate the root proof for outer recursion.
115    #[cfg(feature = "evm-prove")]
116    pub fn generate_root_proof(
117        &mut self,
118        app_proofs: ContinuationVmProof<SC>,
119    ) -> Result<Proof<RootSC>, VirtualMachineError> {
120        let root_verifier_input = self.generate_root_verifier_input(app_proofs)?;
121        self.generate_root_proof_impl(root_verifier_input)
122    }
123
124    pub fn generate_leaf_proofs(
125        &mut self,
126        app_proofs: &ContinuationVmProof<SC>,
127    ) -> Result<Vec<Proof<SC>>, VirtualMachineError> {
128        check_max_constraint_degrees(
129            self.leaf_prover.vm.config().as_ref(),
130            &self.leaf_prover.vm.engine.fri_params(),
131        );
132        self.leaf_controller
133            .generate_proof(&mut self.leaf_prover, app_proofs)
134    }
135
136    /// This is typically only used for the halo2 verifier.
137    #[cfg(feature = "evm-prove")]
138    pub fn generate_root_verifier_input(
139        &mut self,
140        app_proofs: ContinuationVmProof<SC>,
141    ) -> Result<RootVmVerifierInput<SC>, VirtualMachineError> {
142        let leaf_proofs = self.generate_leaf_proofs(&app_proofs)?;
143        let public_values = app_proofs.user_public_values.public_values;
144        let e2e_stark_proof = self.aggregate_leaf_proofs(leaf_proofs, public_values)?;
145        let wrapped_stark_proof = self.wrap_e2e_stark_proof(e2e_stark_proof)?;
146        Ok(wrapped_stark_proof)
147    }
148
149    pub fn aggregate_leaf_proofs(
150        &mut self,
151        leaf_proofs: Vec<Proof<SC>>,
152        public_values: Vec<F>,
153    ) -> Result<VmStarkProof<SC>, VirtualMachineError> {
154        check_max_constraint_degrees(
155            self.internal_prover.vm.config().as_ref(),
156            &self.internal_prover.vm.engine.fri_params(),
157        );
158
159        let mut internal_node_idx = -1;
160        let mut internal_node_height = 0;
161        let mut proofs = leaf_proofs;
162        // We will always generate at least one internal proof, even if there is only one leaf
163        // proof, in order to shrink the proof size
164        while proofs.len() > 1 || internal_node_height == 0 {
165            let internal_inputs = InternalVmVerifierInput::chunk_leaf_or_internal_proofs(
166                (*self.internal_prover.program_commitment()).into(),
167                &proofs,
168                self.num_children_internal,
169            );
170            proofs = info_span!(
171                "agg_layer",
172                group = format!("internal.{internal_node_height}")
173            )
174            .in_scope(|| {
175                #[cfg(feature = "metrics")]
176                {
177                    metrics::counter!("fri.log_blowup")
178                        .absolute(self.internal_prover.vm.engine.fri_params().log_blowup as u64);
179                    metrics::counter!("num_children").absolute(self.num_children_internal as u64);
180                }
181                internal_inputs
182                    .into_iter()
183                    .map(|input| {
184                        internal_node_idx += 1;
185                        info_span!("single_internal_agg", idx = internal_node_idx,).in_scope(|| {
186                            SingleSegmentVmProver::prove(
187                                &mut self.internal_prover,
188                                input.write(),
189                                NATIVE_MAX_TRACE_HEIGHTS,
190                            )
191                        })
192                    })
193                    .collect::<Result<Vec<_>, _>>()
194            })?;
195            internal_node_height += 1;
196        }
197        let proof = proofs.pop().unwrap();
198        Ok(VmStarkProof {
199            inner: proof,
200            user_public_values: public_values,
201        })
202    }
203
204    /// Wrap the e2e stark proof until its heights meet the requirements of the root verifier.
205    #[cfg(feature = "evm-prove")]
206    fn wrap_e2e_stark_proof(
207        &mut self,
208        e2e_stark_proof: VmStarkProof<SC>,
209    ) -> Result<RootVmVerifierInput<SC>, VirtualMachineError> {
210        let internal_commit = (*self.internal_prover.program_commitment()).into();
211        let internal_prover = &mut self.internal_prover;
212        let root_prover = &mut self.root_prover;
213        let max_internal_wrapper_layers = self.max_internal_wrapper_layers;
214        fn heights_le(a: &[u32], b: &[u32]) -> bool {
215            assert_eq!(a.len(), b.len());
216            a.iter().zip(b.iter()).all(|(a, b)| a <= b)
217        }
218
219        let VmStarkProof {
220            inner: mut proof,
221            user_public_values,
222        } = e2e_stark_proof;
223        let mut wrapper_layers = 0;
224        loop {
225            let input = RootVmVerifierInput {
226                proofs: vec![proof.clone()],
227                public_values: user_public_values.clone(),
228            };
229            let actual_air_heights = root_prover.execute_for_air_heights(input)?;
230            // Root verifier can handle the internal proof. We can stop here.
231            if heights_le(&actual_air_heights, root_prover.fixed_air_heights()) {
232                break;
233            }
234            if wrapper_layers >= max_internal_wrapper_layers {
235                panic!("The heights of the root verifier still exceed the required heights after {} wrapper layers", max_internal_wrapper_layers);
236            }
237            wrapper_layers += 1;
238            let input = InternalVmVerifierInput {
239                self_program_commit: internal_commit,
240                proofs: vec![proof.clone()],
241            };
242            proof = info_span!(
243                "wrapper_layer",
244                group = format!("internal_wrapper.{wrapper_layers}")
245            )
246            .in_scope(|| {
247                #[cfg(feature = "metrics")]
248                {
249                    metrics::counter!("fri.log_blowup")
250                        .absolute(internal_prover.vm.engine.fri_params().log_blowup as u64);
251                }
252                SingleSegmentVmProver::prove(
253                    internal_prover,
254                    input.write(),
255                    NATIVE_MAX_TRACE_HEIGHTS,
256                )
257            })?;
258        }
259        Ok(RootVmVerifierInput {
260            proofs: vec![proof],
261            public_values: user_public_values,
262        })
263    }
264
265    #[cfg(feature = "evm-prove")]
266    #[instrument(name = "agg_layer", skip_all, fields(group = "root", idx = 0))]
267    fn generate_root_proof_impl(
268        &mut self,
269        root_input: RootVmVerifierInput<SC>,
270    ) -> Result<Proof<RootSC>, VirtualMachineError> {
271        check_max_constraint_degrees(
272            self.root_prover.vm_config().as_ref(),
273            self.root_prover.fri_params(),
274        );
275        let input = root_input.write();
276        #[cfg(feature = "metrics")]
277        metrics::counter!("fri.log_blowup")
278            .absolute(self.root_prover.fri_params().log_blowup as u64);
279        SingleSegmentVmProver::prove(&mut self.root_prover, input, NATIVE_MAX_TRACE_HEIGHTS)
280    }
281}
282
283impl LeafProvingController {
284    pub fn with_num_children(mut self, num_children_leaf: usize) -> Self {
285        self.num_children = num_children_leaf;
286        self
287    }
288
289    #[instrument(name = "agg_layer", skip_all, fields(group = "leaf"))]
290    pub fn generate_proof<E, NativeBuilder>(
291        &self,
292        prover: &mut VmInstance<E, NativeBuilder>,
293        app_proofs: &ContinuationVmProof<SC>,
294    ) -> Result<Vec<Proof<SC>>, VirtualMachineError>
295    where
296        E: StarkFriEngine<SC = SC>,
297        NativeBuilder: VmBuilder<E, VmConfig = NativeConfig>,
298        <NativeConfig as VmExecutionConfig<F>>::Executor:
299            PreflightExecutor<F, <NativeBuilder as VmBuilder<E>>::RecordArena>,
300    {
301        #[cfg(feature = "metrics")]
302        {
303            metrics::counter!("fri.log_blowup")
304                .absolute(prover.vm.engine.fri_params().log_blowup as u64);
305            metrics::counter!("num_children").absolute(self.num_children as u64);
306        }
307        let leaf_inputs =
308            LeafVmVerifierInput::chunk_continuation_vm_proof(app_proofs, self.num_children);
309        tracing::info!("num_leaf_proofs={}", leaf_inputs.len());
310        leaf_inputs
311            .into_iter()
312            .enumerate()
313            .map(|(leaf_node_idx, input)| {
314                info_span!("single_leaf_agg", idx = leaf_node_idx).in_scope(|| {
315                    SingleSegmentVmProver::prove(
316                        prover,
317                        input.write_to_stream(),
318                        NATIVE_MAX_TRACE_HEIGHTS,
319                    )
320                })
321            })
322            .collect()
323    }
324}