1use std::sync::Arc;
2
3use backtrace::Backtrace;
4use openvm_instructions::{
5 exe::FnBounds,
6 instruction::{DebugInfo, Instruction},
7 program::Program,
8};
9use openvm_stark_backend::{
10 config::{Domain, StarkGenericConfig},
11 keygen::types::LinearConstraint,
12 p3_commit::PolynomialSpace,
13 p3_field::PrimeField32,
14 prover::types::{CommittedTraceData, ProofInput},
15 utils::metrics_span,
16 Chip,
17};
18
19use super::{
20 ExecutionError, GenerationError, Streams, SystemBase, SystemConfig, VmChipComplex,
21 VmComplexTraceHeights, VmConfig,
22};
23#[cfg(feature = "bench-metrics")]
24use crate::metrics::VmMetrics;
25use crate::{
26 arch::{instructions::*, ExecutionState, InstructionExecutor},
27 system::memory::MemoryImage,
28};
29
/// Number of executed instructions between two consultations of the segmentation strategy.
const SEGMENT_CHECK_INTERVAL: usize = 100;

/// Default per-chip trace-height limit for a segment: `2^22` rows minus a slack of 100
/// (presumably headroom for rows added between periodic segmentation checks — TODO confirm).
const DEFAULT_MAX_SEGMENT_LEN: usize = (1usize << 22) - 100;
/// Default per-chip trace-cell limit for a segment: 120 cells per allowed row.
const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = 120 * DEFAULT_MAX_SEGMENT_LEN;
41
/// Policy deciding when execution should end the current segment and start a new one.
///
/// Implementors must be `Send + Sync` so they can be shared behind an `Arc`, `Debug` so
/// they can be printed as part of a configuration, and unwind-safe so they may be held
/// across `std::panic::catch_unwind` boundaries.
pub trait SegmentationStrategy:
    Send + Sync + std::fmt::Debug + std::panic::RefUnwindSafe + std::panic::UnwindSafe
{
    /// Returns `true` if the current segment should be closed now.
    ///
    /// The slices are per-chip: `trace_heights[i]` and `trace_cells[i]` are expected to
    /// describe the chip named `air_names[i]` (the default implementation indexes
    /// `air_names` with the same index when logging).
    fn should_segment(
        &self,
        air_names: &[String],
        trace_heights: &[usize],
        trace_cells: &[usize],
    ) -> bool;

    /// Returns a stricter variant of this strategy, i.e. one expected to segment earlier.
    fn stricter_strategy(&self) -> Arc<dyn SegmentationStrategy>;
}
61
/// Segmentation strategy that closes a segment once any chip exceeds either a trace-height
/// (row) limit or a trace-cell limit.
#[derive(Clone, Debug)]
pub struct DefaultSegmentationStrategy {
    /// Maximum trace height (rows) any single chip may reach within a segment.
    max_segment_len: usize,
    /// Maximum number of trace cells any single chip may reach within a segment.
    max_cells_per_chip_in_segment: usize,
}
68
69impl Default for DefaultSegmentationStrategy {
70 fn default() -> Self {
71 Self {
72 max_segment_len: DEFAULT_MAX_SEGMENT_LEN,
73 max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT,
74 }
75 }
76}
77
78impl DefaultSegmentationStrategy {
79 pub fn new_with_max_segment_len(max_segment_len: usize) -> Self {
80 Self {
81 max_segment_len,
82 max_cells_per_chip_in_segment: max_segment_len * 120,
83 }
84 }
85
86 pub fn new(max_segment_len: usize, max_cells_per_chip_in_segment: usize) -> Self {
87 Self {
88 max_segment_len,
89 max_cells_per_chip_in_segment,
90 }
91 }
92
93 pub fn max_segment_len(&self) -> usize {
94 self.max_segment_len
95 }
96}
97
/// Divisor applied to both limits of a `DefaultSegmentationStrategy` whenever a stricter
/// strategy is derived from it.
const SEGMENTATION_BACKOFF_FACTOR: usize = 4usize;
99
100impl SegmentationStrategy for DefaultSegmentationStrategy {
101 fn should_segment(
102 &self,
103 air_names: &[String],
104 trace_heights: &[usize],
105 trace_cells: &[usize],
106 ) -> bool {
107 for (i, &height) in trace_heights.iter().enumerate() {
108 if height > self.max_segment_len {
109 tracing::info!(
110 "Should segment because chip {} (name: {}) has height {}",
111 i,
112 air_names[i],
113 height
114 );
115 return true;
116 }
117 }
118 for (i, &num_cells) in trace_cells.iter().enumerate() {
119 if num_cells > self.max_cells_per_chip_in_segment {
120 tracing::info!(
121 "Should segment because chip {} (name: {}) has {} cells",
122 i,
123 air_names[i],
124 num_cells
125 );
126 return true;
127 }
128 }
129 false
130 }
131
132 fn stricter_strategy(&self) -> Arc<dyn SegmentationStrategy> {
133 Arc::new(Self {
134 max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR,
135 max_cells_per_chip_in_segment: self.max_cells_per_chip_in_segment
136 / SEGMENTATION_BACKOFF_FACTOR,
137 })
138 }
139}
140
141pub struct ExecutionSegment<F, VC>
142where
143 F: PrimeField32,
144 VC: VmConfig<F>,
145{
146 pub chip_complex: VmChipComplex<F, VC::Executor, VC::Periphery>,
147 pub final_memory: Option<MemoryImage<F>>,
149
150 pub since_last_segment_check: usize,
151 pub trace_height_constraints: Vec<LinearConstraint>,
152
153 pub(crate) air_names: Vec<String>,
155 #[cfg(feature = "bench-metrics")]
157 pub metrics: VmMetrics,
158}
159
/// Where and why `ExecutionSegment::execute_from_pc` stopped.
pub struct ExecutionSegmentState {
    /// Program counter at the stop point: the pc of the `TERMINATE` instruction when the
    /// program terminated, otherwise the pc of the next instruction to execute.
    pub pc: u32,
    /// `true` if execution stopped because a `TERMINATE` instruction was reached.
    pub is_terminated: bool,
}
164
165impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
166 pub fn new(
168 config: &VC,
169 program: Program<F>,
170 init_streams: Streams<F>,
171 initial_memory: Option<MemoryImage<F>>,
172 trace_height_constraints: Vec<LinearConstraint>,
173 #[allow(unused_variables)] fn_bounds: FnBounds,
174 ) -> Self {
175 let mut chip_complex = config.create_chip_complex().unwrap();
176 chip_complex.set_streams(init_streams);
177 let program = if !config.system().profiling {
178 program.strip_debug_infos()
179 } else {
180 program
181 };
182 chip_complex.set_program(program);
183
184 if let Some(initial_memory) = initial_memory {
185 chip_complex.set_initial_memory(initial_memory);
186 }
187 let air_names = chip_complex.air_names();
188
189 Self {
190 chip_complex,
191 final_memory: None,
192 air_names,
193 trace_height_constraints,
194 #[cfg(feature = "bench-metrics")]
195 metrics: VmMetrics {
196 fn_bounds,
197 ..Default::default()
198 },
199 since_last_segment_check: 0,
200 }
201 }
202
203 pub fn system_config(&self) -> &SystemConfig {
204 self.chip_complex.config()
205 }
206
207 pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) {
208 self.chip_complex
209 .set_override_system_trace_heights(overridden_heights.system);
210 self.chip_complex
211 .set_override_inventory_trace_heights(overridden_heights.inventory);
212 }
213
214 pub fn execute_from_pc(
216 &mut self,
217 mut pc: u32,
218 ) -> Result<ExecutionSegmentState, ExecutionError> {
219 let mut timestamp = self.chip_complex.memory_controller().timestamp();
220 let mut prev_backtrace: Option<Backtrace> = None;
221
222 self.chip_complex
223 .connector_chip_mut()
224 .begin(ExecutionState::new(pc, timestamp));
225
226 let mut did_terminate = false;
227
228 loop {
229 #[allow(unused_variables)]
230 let (opcode, dsl_instr) = {
231 let Self {
232 chip_complex,
233 #[cfg(feature = "bench-metrics")]
234 metrics,
235 ..
236 } = self;
237 let SystemBase {
238 program_chip,
239 memory_controller,
240 ..
241 } = &mut chip_complex.base;
242
243 let (instruction, debug_info) = program_chip.get_instruction(pc)?;
244 tracing::trace!("pc: {pc:#x} | time: {timestamp} | {:?}", instruction);
245
246 #[allow(unused_variables)]
247 let (dsl_instr, trace) = debug_info.as_ref().map_or(
248 (None, None),
249 |DebugInfo {
250 dsl_instruction,
251 trace,
252 }| (Some(dsl_instruction), trace.as_ref()),
253 );
254
255 let &Instruction { opcode, c, .. } = instruction;
256 if opcode == SystemOpcode::TERMINATE.global_opcode() {
257 did_terminate = true;
258 self.chip_complex.connector_chip_mut().end(
259 ExecutionState::new(pc, timestamp),
260 Some(c.as_canonical_u32()),
261 );
262 break;
263 }
264
265 if opcode == SystemOpcode::PHANTOM.global_opcode() {
267 let discriminant = c.as_canonical_u32() as u16;
269 let phantom = SysPhantom::from_repr(discriminant);
270 tracing::trace!("pc: {pc:#x} | system phantom: {phantom:?}");
271 match phantom {
272 Some(SysPhantom::DebugPanic) => {
273 if let Some(mut backtrace) = prev_backtrace {
274 backtrace.resolve();
275 eprintln!("openvm program failure; backtrace:\n{:?}", backtrace);
276 } else {
277 eprintln!("openvm program failure; no backtrace");
278 }
279 return Err(ExecutionError::Fail { pc });
280 }
281 Some(SysPhantom::CtStart) =>
282 {
283 #[cfg(feature = "bench-metrics")]
284 metrics
285 .cycle_tracker
286 .start(dsl_instr.cloned().unwrap_or("Default".to_string()))
287 }
288 Some(SysPhantom::CtEnd) =>
289 {
290 #[cfg(feature = "bench-metrics")]
291 metrics
292 .cycle_tracker
293 .end(dsl_instr.cloned().unwrap_or("Default".to_string()))
294 }
295 _ => {}
296 }
297 }
298 prev_backtrace = trace.cloned();
299
300 if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) {
301 let next_state = InstructionExecutor::execute(
302 executor,
303 memory_controller,
304 instruction,
305 ExecutionState::new(pc, timestamp),
306 )?;
307 assert!(next_state.timestamp > timestamp);
308 pc = next_state.pc;
309 timestamp = next_state.timestamp;
310 } else {
311 return Err(ExecutionError::DisabledOperation { pc, opcode });
312 };
313 (opcode, dsl_instr.cloned())
314 };
315
316 #[cfg(feature = "bench-metrics")]
317 self.update_instruction_metrics(pc, opcode, dsl_instr);
318
319 if self.should_segment() {
320 self.chip_complex
321 .connector_chip_mut()
322 .end(ExecutionState::new(pc, timestamp), None);
323 break;
324 }
325 }
326 self.final_memory = Some(
327 self.chip_complex
328 .base
329 .memory_controller
330 .memory_image()
331 .clone(),
332 );
333
334 Ok(ExecutionSegmentState {
335 pc,
336 is_terminated: did_terminate,
337 })
338 }
339
340 pub fn generate_proof_input<SC: StarkGenericConfig>(
342 #[allow(unused_mut)] mut self,
343 cached_program: Option<CommittedTraceData<SC>>,
344 ) -> Result<ProofInput<SC>, GenerationError>
345 where
346 Domain<SC>: PolynomialSpace<Val = F>,
347 VC::Executor: Chip<SC>,
348 VC::Periphery: Chip<SC>,
349 {
350 metrics_span("trace_gen_time_ms", || {
351 self.chip_complex.generate_proof_input(
352 cached_program,
353 &self.trace_height_constraints,
354 #[cfg(feature = "bench-metrics")]
355 &mut self.metrics,
356 )
357 })
358 }
359
360 fn should_segment(&mut self) -> bool {
362 if !self.system_config().continuation_enabled {
363 return false;
364 }
365 if self.since_last_segment_check != SEGMENT_CHECK_INTERVAL {
367 self.since_last_segment_check += 1;
368 return false;
369 }
370 self.since_last_segment_check = 0;
371 let segmentation_strategy = &self.system_config().segmentation_strategy;
372 segmentation_strategy.should_segment(
373 &self.air_names,
374 &self
375 .chip_complex
376 .dynamic_trace_heights()
377 .collect::<Vec<_>>(),
378 &self.chip_complex.current_trace_cells(),
379 )
380 }
381
382 pub fn current_trace_cells(&self) -> Vec<usize> {
383 self.chip_complex.current_trace_cells()
384 }
385}