openvm_sdk/prover/
root.rs

1use async_trait::async_trait;
2use openvm_circuit::arch::{SingleSegmentVmExecutor, Streams};
3use openvm_continuations::verifier::root::types::RootVmVerifierInput;
4use openvm_native_circuit::NativeConfig;
5use openvm_native_recursion::hints::Hintable;
6use openvm_stark_sdk::{
7    config::{baby_bear_poseidon2_root::BabyBearPoseidon2RootEngine, FriParameters},
8    engine::{StarkEngine, StarkFriEngine},
9    openvm_stark_backend::proof::Proof,
10};
11
12use crate::{
13    keygen::RootVerifierProvingKey,
14    prover::vm::{AsyncSingleSegmentVmProver, SingleSegmentVmProver},
15    RootSC, F, SC,
16};
17
18/// Local prover for a root verifier.
19pub struct RootVerifierLocalProver {
20    pub root_verifier_pk: RootVerifierProvingKey,
21    executor_for_heights: SingleSegmentVmExecutor<F, NativeConfig>,
22}
23
24impl RootVerifierLocalProver {
25    pub fn new(root_verifier_pk: RootVerifierProvingKey) -> Self {
26        let executor_for_heights =
27            SingleSegmentVmExecutor::<F, _>::new(root_verifier_pk.vm_pk.vm_config.clone());
28        Self {
29            root_verifier_pk,
30            executor_for_heights,
31        }
32    }
33    pub fn execute_for_air_heights(&self, input: RootVmVerifierInput<SC>) -> Vec<usize> {
34        let result = self
35            .executor_for_heights
36            .execute_and_compute_heights(
37                self.root_verifier_pk.root_committed_exe.exe.clone(),
38                input.write(),
39            )
40            .unwrap();
41        result.air_heights
42    }
43    pub fn vm_config(&self) -> &NativeConfig {
44        &self.root_verifier_pk.vm_pk.vm_config
45    }
46    #[allow(dead_code)]
47    pub(crate) fn fri_params(&self) -> &FriParameters {
48        &self.root_verifier_pk.vm_pk.fri_params
49    }
50}
51
52impl SingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
53    fn prove(&self, input: impl Into<Streams<F>>) -> Proof<RootSC> {
54        let input = input.into();
55        let vm = SingleSegmentVmExecutor::new(self.vm_config().clone());
56        let mut proof_input = vm
57            .execute_and_generate(self.root_verifier_pk.root_committed_exe.clone(), input)
58            .unwrap();
59        assert_eq!(
60            proof_input.per_air.len(),
61            self.root_verifier_pk.air_heights.len(),
62            "All AIRs of root verifier should present"
63        );
64        proof_input.per_air.iter().for_each(|(air_id, input)| {
65            assert_eq!(
66                input.main_trace_height(),
67                self.root_verifier_pk.air_heights[*air_id],
68                "Trace height doesn't match"
69            );
70        });
71        // Reorder the AIRs by heights.
72        let air_id_perm = self.root_verifier_pk.air_id_permutation();
73        air_id_perm.permute(&mut proof_input.per_air);
74        for i in 0..proof_input.per_air.len() {
75            // Overwrite the AIR ID.
76            proof_input.per_air[i].0 = i;
77        }
78        let e = BabyBearPoseidon2RootEngine::new(*self.fri_params());
79        e.prove(&self.root_verifier_pk.vm_pk.vm_pk, proof_input)
80    }
81}
82
83#[async_trait]
84impl AsyncSingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
85    async fn prove(&self, input: impl Into<Streams<F>> + Send + Sync) -> Proof<RootSC> {
86        SingleSegmentVmProver::prove(self, input)
87    }
88}