openvm_sdk/prover/
root.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use async_trait::async_trait;
use openvm_circuit::arch::{SingleSegmentVmExecutor, Streams};
use openvm_native_circuit::NativeConfig;
use openvm_native_recursion::hints::Hintable;
use openvm_stark_sdk::{
    config::baby_bear_poseidon2_root::BabyBearPoseidon2RootEngine,
    engine::{StarkEngine, StarkFriEngine},
    openvm_stark_backend::prover::types::Proof,
};

use crate::{
    keygen::RootVerifierProvingKey,
    prover::vm::{AsyncSingleSegmentVmProver, SingleSegmentVmProver},
    verifier::root::types::RootVmVerifierInput,
    RootSC, F, SC,
};

/// Local prover for a root verifier.
///
/// Holds the root verifier proving key together with an executor built from the key's VM
/// config. [`RootVerifierLocalProver::execute_for_air_heights`] uses that executor to compute
/// AIR heights without generating a proof.
pub struct RootVerifierLocalProver {
    /// Proving key: VM config, committed root verifier program, and the expected AIR heights
    /// that generated traces are checked against when proving.
    pub root_verifier_pk: RootVerifierProvingKey,
    /// Executor built once in `new` and reused by `execute_for_air_heights`.
    /// Proving constructs its own executor instead of using this one.
    executor_for_heights: SingleSegmentVmExecutor<F, NativeConfig>,
}

impl RootVerifierLocalProver {
    /// Creates a prover from `root_verifier_pk`.
    ///
    /// An executor for the key's VM config is built once here, so that
    /// [`Self::execute_for_air_heights`] can be called repeatedly without rebuilding it.
    pub fn new(root_verifier_pk: RootVerifierProvingKey) -> Self {
        let executor_for_heights =
            SingleSegmentVmExecutor::<F, _>::new(root_verifier_pk.vm_pk.vm_config.clone());
        Self {
            root_verifier_pk,
            executor_for_heights,
        }
    }

    /// Executes the committed root verifier program on `input` and returns the `air_heights`
    /// reported by the execution result.
    ///
    /// This only executes the program (`execute`); it does not generate traces or a proof.
    ///
    /// # Panics
    ///
    /// Panics if executing the root verifier program fails.
    pub fn execute_for_air_heights(&self, input: RootVmVerifierInput<SC>) -> Vec<usize> {
        self.executor_for_heights
            .execute(
                self.root_verifier_pk.root_committed_exe.exe.clone(),
                input.write(),
            )
            .expect("root verifier execution failed while computing AIR heights")
            .air_heights
    }
}

impl SingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
    /// Executes the root verifier program on `input`, generates its traces, and proves them
    /// with the root engine configured by the key's FRI parameters.
    ///
    /// Before proving, every generated trace height is checked against the height stored in
    /// the proving key. The per-AIR inputs are then reordered by the key's AIR id
    /// permutation, and each AIR id is rewritten to its new position.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - execution or trace generation fails;
    /// - the number of AIR inputs differs from the number of expected heights;
    /// - any AIR's main trace height differs from the expected one.
    fn prove(&self, input: impl Into<Streams<F>>) -> Proof<RootSC> {
        let input = input.into();
        let vm = SingleSegmentVmExecutor::new(self.root_verifier_pk.vm_pk.vm_config.clone());
        let mut proof_input = vm
            .execute_and_generate(self.root_verifier_pk.root_committed_exe.clone(), input)
            .expect("root verifier execution and trace generation failed");

        // The root proving key was generated for fixed trace heights, so the generated
        // traces must match them exactly.
        assert_eq!(
            proof_input.per_air.len(),
            self.root_verifier_pk.air_heights.len(),
            "All AIRs of root verifier should present"
        );
        for (air_id, air_input) in &proof_input.per_air {
            assert_eq!(
                air_input.main_trace_height(),
                self.root_verifier_pk.air_heights[*air_id],
                "Trace height doesn't match"
            );
        }

        // Reorder the AIRs by heights, then renumber them so each AIR id equals its
        // position after the permutation.
        let air_id_perm = self.root_verifier_pk.air_id_permutation();
        air_id_perm.permute(&mut proof_input.per_air);
        for (new_air_id, (air_id, _)) in proof_input.per_air.iter_mut().enumerate() {
            *air_id = new_air_id;
        }

        let e = BabyBearPoseidon2RootEngine::new(self.root_verifier_pk.vm_pk.fri_params);
        e.prove(&self.root_verifier_pk.vm_pk.vm_pk, proof_input)
    }
}

#[async_trait]
impl AsyncSingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
    async fn prove(&self, input: impl Into<Streams<F>> + Send + Sync) -> Proof<RootSC> {
        SingleSegmentVmProver::prove(self, input)
    }
}