openvm_sdk/prover/
root.rs
1use async_trait::async_trait;
2use openvm_circuit::arch::{SingleSegmentVmExecutor, Streams};
3use openvm_continuations::verifier::root::types::RootVmVerifierInput;
4use openvm_native_circuit::NativeConfig;
5use openvm_native_recursion::hints::Hintable;
6use openvm_stark_sdk::{
7 config::{baby_bear_poseidon2_root::BabyBearPoseidon2RootEngine, FriParameters},
8 engine::{StarkEngine, StarkFriEngine},
9 openvm_stark_backend::proof::Proof,
10};
11
12use crate::{
13 keygen::RootVerifierProvingKey,
14 prover::vm::{AsyncSingleSegmentVmProver, SingleSegmentVmProver},
15 RootSC, F, SC,
16};
17
18pub struct RootVerifierLocalProver {
20 pub root_verifier_pk: RootVerifierProvingKey,
21 executor_for_heights: SingleSegmentVmExecutor<F, NativeConfig>,
22}
23
24impl RootVerifierLocalProver {
25 pub fn new(root_verifier_pk: RootVerifierProvingKey) -> Self {
26 let executor_for_heights =
27 SingleSegmentVmExecutor::<F, _>::new(root_verifier_pk.vm_pk.vm_config.clone());
28 Self {
29 root_verifier_pk,
30 executor_for_heights,
31 }
32 }
33 pub fn execute_for_air_heights(&self, input: RootVmVerifierInput<SC>) -> Vec<usize> {
34 let result = self
35 .executor_for_heights
36 .execute_and_compute_heights(
37 self.root_verifier_pk.root_committed_exe.exe.clone(),
38 input.write(),
39 )
40 .unwrap();
41 result.air_heights
42 }
43 pub fn vm_config(&self) -> &NativeConfig {
44 &self.root_verifier_pk.vm_pk.vm_config
45 }
46 #[allow(dead_code)]
47 pub(crate) fn fri_params(&self) -> &FriParameters {
48 &self.root_verifier_pk.vm_pk.fri_params
49 }
50}
51
52impl SingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
53 fn prove(&self, input: impl Into<Streams<F>>) -> Proof<RootSC> {
54 let input = input.into();
55 let vm = SingleSegmentVmExecutor::new(self.vm_config().clone());
56 let mut proof_input = vm
57 .execute_and_generate(self.root_verifier_pk.root_committed_exe.clone(), input)
58 .unwrap();
59 assert_eq!(
60 proof_input.per_air.len(),
61 self.root_verifier_pk.air_heights.len(),
62 "All AIRs of root verifier should present"
63 );
64 proof_input.per_air.iter().for_each(|(air_id, input)| {
65 assert_eq!(
66 input.main_trace_height(),
67 self.root_verifier_pk.air_heights[*air_id],
68 "Trace height doesn't match"
69 );
70 });
71 let air_id_perm = self.root_verifier_pk.air_id_permutation();
73 air_id_perm.permute(&mut proof_input.per_air);
74 for i in 0..proof_input.per_air.len() {
75 proof_input.per_air[i].0 = i;
77 }
78 let e = BabyBearPoseidon2RootEngine::new(*self.fri_params());
79 e.prove(&self.root_verifier_pk.vm_pk.vm_pk, proof_input)
80 }
81}
82
83#[async_trait]
84impl AsyncSingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
85 async fn prove(&self, input: impl Into<Streams<F>> + Send + Sync) -> Proof<RootSC> {
86 SingleSegmentVmProver::prove(self, input)
87 }
88}