1use std::{marker::PhantomData, mem, sync::Arc};
2
3use async_trait::async_trait;
4use openvm_circuit::{
5 arch::{
6 hasher::poseidon2::vm_poseidon2_hasher, GenerationError, SingleSegmentVmExecutor, Streams,
7 VirtualMachine, VmComplexTraceHeights, VmConfig,
8 },
9 system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe},
10};
11use openvm_stark_backend::{
12 config::{StarkGenericConfig, Val},
13 p3_field::PrimeField32,
14 proof::Proof,
15 Chip,
16};
17use openvm_stark_sdk::{config::FriParameters, engine::StarkFriEngine};
18use tracing::info_span;
19
20use crate::prover::vm::{
21 types::VmProvingKey, AsyncContinuationVmProver, AsyncSingleSegmentVmProver,
22 ContinuationVmProof, ContinuationVmProver, SingleSegmentVmProver,
23};
24
25pub struct VmLocalProver<SC: StarkGenericConfig, VC, E: StarkFriEngine<SC>> {
26 pub pk: Arc<VmProvingKey<SC, VC>>,
27 pub committed_exe: Arc<VmCommittedExe<SC>>,
28 overridden_heights: Option<VmComplexTraceHeights>,
29 _marker: PhantomData<E>,
30}
31
32impl<SC: StarkGenericConfig, VC, E: StarkFriEngine<SC>> VmLocalProver<SC, VC, E> {
33 pub fn new(pk: Arc<VmProvingKey<SC, VC>>, committed_exe: Arc<VmCommittedExe<SC>>) -> Self {
34 Self {
35 pk,
36 committed_exe,
37 overridden_heights: None,
38 _marker: PhantomData,
39 }
40 }
41
42 pub fn new_with_overridden_trace_heights(
43 pk: Arc<VmProvingKey<SC, VC>>,
44 committed_exe: Arc<VmCommittedExe<SC>>,
45 overridden_heights: Option<VmComplexTraceHeights>,
46 ) -> Self {
47 Self {
48 pk,
49 committed_exe,
50 overridden_heights,
51 _marker: PhantomData,
52 }
53 }
54
55 pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) {
56 self.overridden_heights = Some(overridden_heights);
57 }
58
59 pub fn vm_config(&self) -> &VC {
60 &self.pk.vm_config
61 }
62 #[allow(dead_code)]
63 pub(crate) fn fri_params(&self) -> &FriParameters {
64 &self.pk.fri_params
65 }
66}
67
/// How many times continuation proving re-runs execution with a stricter
/// segmentation strategy after the trace heights limit is exceeded, before giving up.
const MAX_SEGMENTATION_RETRIES: usize = 4;
69
70impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>> ContinuationVmProver<SC>
71 for VmLocalProver<SC, VC, E>
72where
73 Val<SC>: PrimeField32,
74 VC::Executor: Chip<SC>,
75 VC::Periphery: Chip<SC>,
76{
77 fn prove(&self, input: impl Into<Streams<Val<SC>>>) -> ContinuationVmProof<SC> {
78 assert!(self.pk.vm_config.system().continuation_enabled);
79 let e = E::new(self.pk.fri_params);
80 let trace_height_constraints = self.pk.vm_pk.trace_height_constraints.clone();
81 let mut vm = VirtualMachine::new_with_overridden_trace_heights(
82 e,
83 self.pk.vm_config.clone(),
84 self.overridden_heights.clone(),
85 );
86 vm.set_trace_height_constraints(trace_height_constraints.clone());
87 let mut final_memory = None;
88 let VmCommittedExe {
89 exe,
90 committed_program,
91 } = self.committed_exe.as_ref();
92 let input = input.into();
93
94 let mut retries = 0;
97 let per_segment = loop {
98 match vm.executor.execute_and_then(
99 exe.clone(),
100 input.clone(),
101 |seg_idx, mut seg| {
102 final_memory = mem::take(&mut seg.final_memory);
103 let proof_input = info_span!("trace_gen", segment = seg_idx)
104 .in_scope(|| seg.generate_proof_input(Some(committed_program.clone())))?;
105 info_span!("prove_segment", segment = seg_idx)
106 .in_scope(|| Ok(vm.engine.prove(&self.pk.vm_pk, proof_input)))
107 },
108 GenerationError::Execution,
109 ) {
110 Ok(per_segment) => break per_segment,
111 Err(GenerationError::Execution(err)) => panic!("execution error: {err}"),
112 Err(GenerationError::TraceHeightsLimitExceeded) => {
113 if retries >= MAX_SEGMENTATION_RETRIES {
114 panic!(
115 "trace heights limit exceeded after {MAX_SEGMENTATION_RETRIES} retries"
116 );
117 }
118 retries += 1;
119 tracing::info!(
120 "trace heights limit exceeded; retrying execution (attempt {retries})"
121 );
122 let sys_config = vm.executor.config.system_mut();
123 let new_seg_strat = sys_config.segmentation_strategy.stricter_strategy();
124 sys_config.set_segmentation_strategy(new_seg_strat);
125 }
127 };
128 };
129
130 let user_public_values = UserPublicValuesProof::compute(
131 self.pk.vm_config.system().memory_config.memory_dimensions(),
132 self.pk.vm_config.system().num_public_values,
133 &vm_poseidon2_hasher(),
134 final_memory.as_ref().unwrap(),
135 );
136 ContinuationVmProof {
137 per_segment,
138 user_public_values,
139 }
140 }
141}
142
143#[async_trait]
144impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>>
145 AsyncContinuationVmProver<SC> for VmLocalProver<SC, VC, E>
146where
147 VmLocalProver<SC, VC, E>: Send + Sync,
148 Val<SC>: PrimeField32,
149 VC::Executor: Chip<SC>,
150 VC::Periphery: Chip<SC>,
151{
152 async fn prove(
153 &self,
154 input: impl Into<Streams<Val<SC>>> + Send + Sync,
155 ) -> ContinuationVmProof<SC> {
156 ContinuationVmProver::prove(self, input)
157 }
158}
159
160impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>> SingleSegmentVmProver<SC>
161 for VmLocalProver<SC, VC, E>
162where
163 Val<SC>: PrimeField32,
164 VC::Executor: Chip<SC>,
165 VC::Periphery: Chip<SC>,
166{
167 fn prove(&self, input: impl Into<Streams<Val<SC>>>) -> Proof<SC> {
168 assert!(!self.pk.vm_config.system().continuation_enabled);
169 let e = E::new(self.pk.fri_params);
170 let executor = {
172 let mut executor = SingleSegmentVmExecutor::new(self.pk.vm_config.clone());
173 executor.set_trace_height_constraints(self.pk.vm_pk.trace_height_constraints.clone());
174 executor
175 };
176 let proof_input = executor
177 .execute_and_generate(self.committed_exe.clone(), input)
178 .unwrap();
179 let vm = VirtualMachine::new(e, executor.config);
180 vm.prove_single(&self.pk.vm_pk, proof_input)
181 }
182}
183
184#[async_trait]
185impl<SC: StarkGenericConfig, VC: VmConfig<Val<SC>>, E: StarkFriEngine<SC>>
186 AsyncSingleSegmentVmProver<SC> for VmLocalProver<SC, VC, E>
187where
188 VmLocalProver<SC, VC, E>: Send + Sync,
189 Val<SC>: PrimeField32,
190 VC::Executor: Chip<SC>,
191 VC::Periphery: Chip<SC>,
192{
193 async fn prove(&self, input: impl Into<Streams<Val<SC>>> + Send + Sync) -> Proof<SC> {
194 SingleSegmentVmProver::prove(self, input)
195 }
196}