// openvm_sdk/prover/vm/local.rs

1use std::{marker::PhantomData, mem, sync::Arc};
2
3use async_trait::async_trait;
4use openvm_circuit::{
5    arch::{
6        hasher::poseidon2::vm_poseidon2_hasher, GenerationError, SingleSegmentVmExecutor, Streams,
7        VirtualMachine, VmComplexTraceHeights, VmConfig,
8    },
9    system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe},
10};
11use openvm_stark_backend::{
12    config::{StarkGenericConfig, Val},
13    p3_field::PrimeField32,
14    proof::Proof,
15    Chip,
16};
17use openvm_stark_sdk::{config::FriParameters, engine::StarkFriEngine};
18use tracing::info_span;
19
20use crate::prover::vm::{
21    types::VmProvingKey, AsyncContinuationVmProver, AsyncSingleSegmentVmProver,
22    ContinuationVmProof, ContinuationVmProver, SingleSegmentVmProver,
23};
24
25pub struct VmLocalProver<SC: StarkGenericConfig, VC, E: StarkFriEngine<SC>> {
26    pub pk: Arc<VmProvingKey<SC, VC>>,
27    pub committed_exe: Arc<VmCommittedExe<SC>>,
28    overridden_heights: Option<VmComplexTraceHeights>,
29    _marker: PhantomData<E>,
30}
31
32impl<SC: StarkGenericConfig, VC, E: StarkFriEngine<SC>> VmLocalProver<SC, VC, E> {
33    pub fn new(pk: Arc<VmProvingKey<SC, VC>>, committed_exe: Arc<VmCommittedExe<SC>>) -> Self {
34        Self {
35            pk,
36            committed_exe,
37            overridden_heights: None,
38            _marker: PhantomData,
39        }
40    }
41
42    pub fn new_with_overridden_trace_heights(
43        pk: Arc<VmProvingKey<SC, VC>>,
44        committed_exe: Arc<VmCommittedExe<SC>>,
45        overridden_heights: Option<VmComplexTraceHeights>,
46    ) -> Self {
47        Self {
48            pk,
49            committed_exe,
50            overridden_heights,
51            _marker: PhantomData,
52        }
53    }
54
55    pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) {
56        self.overridden_heights = Some(overridden_heights);
57    }
58
59    pub fn vm_config(&self) -> &VC {
60        &self.pk.vm_config
61    }
62    #[allow(dead_code)]
63    pub(crate) fn fri_params(&self) -> &FriParameters {
64        &self.pk.fri_params
65    }
66}
67
68const MAX_SEGMENTATION_RETRIES: usize = 4;
69
70impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>> ContinuationVmProver<SC>
71    for VmLocalProver<SC, VC, E>
72where
73    Val<SC>: PrimeField32,
74    VC::Executor: Chip<SC>,
75    VC::Periphery: Chip<SC>,
76{
77    fn prove(&self, input: impl Into<Streams<Val<SC>>>) -> ContinuationVmProof<SC> {
78        assert!(self.pk.vm_config.system().continuation_enabled);
79        let e = E::new(self.pk.fri_params);
80        let trace_height_constraints = self.pk.vm_pk.trace_height_constraints.clone();
81        let mut vm = VirtualMachine::new_with_overridden_trace_heights(
82            e,
83            self.pk.vm_config.clone(),
84            self.overridden_heights.clone(),
85        );
86        vm.set_trace_height_constraints(trace_height_constraints.clone());
87        let mut final_memory = None;
88        let VmCommittedExe {
89            exe,
90            committed_program,
91        } = self.committed_exe.as_ref();
92        let input = input.into();
93
94        // This loop should typically iterate exactly once. Only in exceptional cases will the
95        // segmentation produce an invalid segment and we will have to retry.
96        let mut retries = 0;
97        let per_segment = loop {
98            match vm.executor.execute_and_then(
99                exe.clone(),
100                input.clone(),
101                |seg_idx, mut seg| {
102                    final_memory = mem::take(&mut seg.final_memory);
103                    let proof_input = info_span!("trace_gen", segment = seg_idx)
104                        .in_scope(|| seg.generate_proof_input(Some(committed_program.clone())))?;
105                    info_span!("prove_segment", segment = seg_idx)
106                        .in_scope(|| Ok(vm.engine.prove(&self.pk.vm_pk, proof_input)))
107                },
108                GenerationError::Execution,
109            ) {
110                Ok(per_segment) => break per_segment,
111                Err(GenerationError::Execution(err)) => panic!("execution error: {err}"),
112                Err(GenerationError::TraceHeightsLimitExceeded) => {
113                    if retries >= MAX_SEGMENTATION_RETRIES {
114                        panic!(
115                            "trace heights limit exceeded after {MAX_SEGMENTATION_RETRIES} retries"
116                        );
117                    }
118                    retries += 1;
119                    tracing::info!(
120                        "trace heights limit exceeded; retrying execution (attempt {retries})"
121                    );
122                    let sys_config = vm.executor.config.system_mut();
123                    let new_seg_strat = sys_config.segmentation_strategy.stricter_strategy();
124                    sys_config.set_segmentation_strategy(new_seg_strat);
125                    // continue
126                }
127            };
128        };
129
130        let user_public_values = UserPublicValuesProof::compute(
131            self.pk.vm_config.system().memory_config.memory_dimensions(),
132            self.pk.vm_config.system().num_public_values,
133            &vm_poseidon2_hasher(),
134            final_memory.as_ref().unwrap(),
135        );
136        ContinuationVmProof {
137            per_segment,
138            user_public_values,
139        }
140    }
141}
142
143#[async_trait]
144impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>>
145    AsyncContinuationVmProver<SC> for VmLocalProver<SC, VC, E>
146where
147    VmLocalProver<SC, VC, E>: Send + Sync,
148    Val<SC>: PrimeField32,
149    VC::Executor: Chip<SC>,
150    VC::Periphery: Chip<SC>,
151{
152    async fn prove(
153        &self,
154        input: impl Into<Streams<Val<SC>>> + Send + Sync,
155    ) -> ContinuationVmProof<SC> {
156        ContinuationVmProver::prove(self, input)
157    }
158}
159
160impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>> SingleSegmentVmProver<SC>
161    for VmLocalProver<SC, VC, E>
162where
163    Val<SC>: PrimeField32,
164    VC::Executor: Chip<SC>,
165    VC::Periphery: Chip<SC>,
166{
167    fn prove(&self, input: impl Into<Streams<Val<SC>>>) -> Proof<SC> {
168        assert!(!self.pk.vm_config.system().continuation_enabled);
169        let e = E::new(self.pk.fri_params);
170        // note: use SingleSegmentVmExecutor so there's not a "segment" label in metrics
171        let executor = {
172            let mut executor = SingleSegmentVmExecutor::new(self.pk.vm_config.clone());
173            executor.set_trace_height_constraints(self.pk.vm_pk.trace_height_constraints.clone());
174            executor
175        };
176        let proof_input = executor
177            .execute_and_generate(self.committed_exe.clone(), input)
178            .unwrap();
179        let vm = VirtualMachine::new(e, executor.config);
180        vm.prove_single(&self.pk.vm_pk, proof_input)
181    }
182}
183
184#[async_trait]
185impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>>
186    AsyncSingleSegmentVmProver<SC> for VmLocalProver<SC, VC, E>
187where
188    VmLocalProver<SC, VC, E>: Send + Sync,
189    Val<SC>: PrimeField32,
190    VC::Executor: Chip<SC>,
191    VC::Periphery: Chip<SC>,
192{
193    async fn prove(&self, input: impl Into<Streams<Val<SC>>> + Send + Sync) -> Proof<SC> {
194        SingleSegmentVmProver::prove(self, input)
195    }
196}