openvm_sdk/prover/
root.rs
1use async_trait::async_trait;
2use openvm_circuit::arch::{SingleSegmentVmExecutor, Streams};
3use openvm_continuations::verifier::root::types::RootVmVerifierInput;
4use openvm_native_circuit::NativeConfig;
5use openvm_native_recursion::hints::Hintable;
6use openvm_stark_sdk::{
7 config::{baby_bear_poseidon2_root::BabyBearPoseidon2RootEngine, FriParameters},
8 engine::{StarkEngine, StarkFriEngine},
9 openvm_stark_backend::proof::Proof,
10};
11
12use crate::{
13 keygen::RootVerifierProvingKey,
14 prover::vm::{AsyncSingleSegmentVmProver, SingleSegmentVmProver},
15 RootSC, F, SC,
16};
17
/// Prover that generates root verifier proofs in the current process.
///
/// It holds the root verifier proving key, which supplies the committed program, VM config,
/// FRI parameters and expected trace heights. It also holds an executor used only to
/// compute AIR trace heights.
pub struct RootVerifierLocalProver {
    /// Proving key for the root verifier VM (committed exe, VM config/proving key,
    /// FRI parameters, `vm_heights` and `air_heights`).
    pub root_verifier_pk: RootVerifierProvingKey,
    /// Executor built from the root verifier VM config; used by `execute_for_air_heights`
    /// to compute per-AIR trace heights without generating a proof.
    executor_for_heights: SingleSegmentVmExecutor<F, NativeConfig>,
}
23
24impl RootVerifierLocalProver {
25 pub fn new(root_verifier_pk: RootVerifierProvingKey) -> Self {
26 let executor_for_heights =
27 SingleSegmentVmExecutor::<F, _>::new(root_verifier_pk.vm_pk.vm_config.clone());
28 Self {
29 root_verifier_pk,
30 executor_for_heights,
31 }
32 }
33 pub fn execute_for_air_heights(&self, input: RootVmVerifierInput<SC>) -> Vec<usize> {
34 let result = self
35 .executor_for_heights
36 .execute_and_compute_heights(
37 self.root_verifier_pk.root_committed_exe.exe.clone(),
38 input.write(),
39 )
40 .unwrap();
41 result.air_heights
42 }
43 pub fn vm_config(&self) -> &NativeConfig {
44 &self.root_verifier_pk.vm_pk.vm_config
45 }
46 #[allow(dead_code)]
47 pub(crate) fn fri_params(&self) -> &FriParameters {
48 &self.root_verifier_pk.vm_pk.fri_params
49 }
50}
51
52impl SingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
53 fn prove(&self, input: impl Into<Streams<F>>) -> Proof<RootSC> {
54 let input = input.into();
55 let mut vm = SingleSegmentVmExecutor::new(self.vm_config().clone());
56 vm.set_override_trace_heights(self.root_verifier_pk.vm_heights.clone());
57 let mut proof_input = vm
58 .execute_and_generate(self.root_verifier_pk.root_committed_exe.clone(), input)
59 .unwrap();
60 assert_eq!(
61 proof_input.per_air.len(),
62 self.root_verifier_pk.air_heights.len(),
63 "All AIRs of root verifier should present"
64 );
65 proof_input.per_air.iter().for_each(|(air_id, input)| {
66 assert_eq!(
67 input.main_trace_height(),
68 self.root_verifier_pk.air_heights[*air_id],
69 "Trace height doesn't match"
70 );
71 });
72 let air_id_perm = self.root_verifier_pk.air_id_permutation();
74 air_id_perm.permute(&mut proof_input.per_air);
75 for i in 0..proof_input.per_air.len() {
76 proof_input.per_air[i].0 = i;
78 }
79 let e = BabyBearPoseidon2RootEngine::new(*self.fri_params());
80 e.prove(&self.root_verifier_pk.vm_pk.vm_pk, proof_input)
81 }
82}
83
84#[async_trait]
85impl AsyncSingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
86 async fn prove(&self, input: impl Into<Streams<F>> + Send + Sync) -> Proof<RootSC> {
87 SingleSegmentVmProver::prove(self, input)
88 }
89}