// openvm_circuit/arch/segment.rs

1use std::sync::Arc;
2
3use backtrace::Backtrace;
4use openvm_instructions::{
5    exe::FnBounds,
6    instruction::{DebugInfo, Instruction},
7    program::Program,
8};
9use openvm_stark_backend::{
10    config::{Domain, StarkGenericConfig},
11    keygen::types::LinearConstraint,
12    p3_commit::PolynomialSpace,
13    p3_field::PrimeField32,
14    prover::types::{CommittedTraceData, ProofInput},
15    utils::metrics_span,
16    Chip,
17};
18
19use super::{
20    ExecutionError, GenerationError, Streams, SystemBase, SystemConfig, VmChipComplex,
21    VmComplexTraceHeights, VmConfig,
22};
23#[cfg(feature = "bench-metrics")]
24use crate::metrics::VmMetrics;
25use crate::{
26    arch::{instructions::*, ExecutionState, InstructionExecutor},
27    system::memory::MemoryImage,
28};
29
/// The segmentation strategy is consulted roughly once per this many executed instructions,
/// rather than after every single instruction.
const SEGMENT_CHECK_INTERVAL: usize = 100;

/// Default per-chip trace-height limit for one segment: `2^22` rows minus a margin of 100.
// NOTE(review): the margin equals `SEGMENT_CHECK_INTERVAL`, presumably so chips that keep
// growing between two segmentation checks still stay below `2^22` rows — confirm.
const DEFAULT_MAX_SEGMENT_LEN: usize = (1usize << 22) - 100;

/// Default per-chip limit on the number of trace cells in one segment: 120 cells per row at the
/// default height limit.
///
/// The factor 120 is a heuristic derived from the reth benchmark, where
/// `VmAirWrapper<Rv32BaseAluAdapterAir, BaseAluCoreAir<4, 8>>` is the chip with the most cells
/// in a segment. That chip's trace width is 36 and its after-challenge trace width is 80.
// NOTE(review): 36 + 80 = 116, so 120 is presumably that total rounded up — confirm.
const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = 120 * DEFAULT_MAX_SEGMENT_LEN;
41
/// Policy that decides when execution should stop and start a new segment.
///
/// Implementors must be `Debug`, thread-safe (`Send + Sync`) and unwind-safe. A stricter
/// variant of a strategy is handed out as an `Arc<dyn SegmentationStrategy>`.
pub trait SegmentationStrategy:
    std::fmt::Debug
    + Send
    + Sync
    + std::panic::UnwindSafe
    + std::panic::RefUnwindSafe
{
    /// Decides, from the current per-chip trace heights and cell counts, whether the current
    /// segment should be closed.
    ///
    /// `air_names` is supplied only so implementations can report which chip triggered the
    /// decision (debugging aid).
    fn should_segment(
        &self,
        air_names: &[String],
        trace_heights: &[usize],
        trace_cells: &[usize],
    ) -> bool;

    /// Returns a strategy that segments more aggressively than `self`.
    ///
    /// Used when a segment produced under `should_segment` turns out to be infeasible.
    /// Execution is then re-run with the returned, stricter strategy.
    fn stricter_strategy(&self) -> Arc<dyn SegmentationStrategy>;
}
61
/// Default segmentation strategy: close the segment as soon as any single chip exceeds either
/// the height limit or the cell limit.
#[derive(Clone, Debug)]
pub struct DefaultSegmentationStrategy {
    /// Maximum trace height (rows) any one chip may reach within a segment.
    max_segment_len: usize,
    /// Maximum number of trace cells any one chip may reach within a segment.
    max_cells_per_chip_in_segment: usize,
}
68
69impl Default for DefaultSegmentationStrategy {
70    fn default() -> Self {
71        Self {
72            max_segment_len: DEFAULT_MAX_SEGMENT_LEN,
73            max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT,
74        }
75    }
76}
77
78impl DefaultSegmentationStrategy {
79    pub fn new_with_max_segment_len(max_segment_len: usize) -> Self {
80        Self {
81            max_segment_len,
82            max_cells_per_chip_in_segment: max_segment_len * 120,
83        }
84    }
85
86    pub fn new(max_segment_len: usize, max_cells_per_chip_in_segment: usize) -> Self {
87        Self {
88            max_segment_len,
89            max_cells_per_chip_in_segment,
90        }
91    }
92
93    pub fn max_segment_len(&self) -> usize {
94        self.max_segment_len
95    }
96}
97
/// Divisor applied to both limits of a `DefaultSegmentationStrategy` each time a stricter
/// strategy is requested.
const SEGMENTATION_BACKOFF_FACTOR: usize = 4;
99
100impl SegmentationStrategy for DefaultSegmentationStrategy {
101    fn should_segment(
102        &self,
103        air_names: &[String],
104        trace_heights: &[usize],
105        trace_cells: &[usize],
106    ) -> bool {
107        for (i, &height) in trace_heights.iter().enumerate() {
108            if height > self.max_segment_len {
109                tracing::info!(
110                    "Should segment because chip {} (name: {}) has height {}",
111                    i,
112                    air_names[i],
113                    height
114                );
115                return true;
116            }
117        }
118        for (i, &num_cells) in trace_cells.iter().enumerate() {
119            if num_cells > self.max_cells_per_chip_in_segment {
120                tracing::info!(
121                    "Should segment because chip {} (name: {}) has {} cells",
122                    i,
123                    air_names[i],
124                    num_cells
125                );
126                return true;
127            }
128        }
129        false
130    }
131
132    fn stricter_strategy(&self) -> Arc<dyn SegmentationStrategy> {
133        Arc::new(Self {
134            max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR,
135            max_cells_per_chip_in_segment: self.max_cells_per_chip_in_segment
136                / SEGMENTATION_BACKOFF_FACTOR,
137        })
138    }
139}
140
141pub struct ExecutionSegment<F, VC>
142where
143    F: PrimeField32,
144    VC: VmConfig<F>,
145{
146    pub chip_complex: VmChipComplex<F, VC::Executor, VC::Periphery>,
147    /// Memory image after segment was executed. Not used in trace generation.
148    pub final_memory: Option<MemoryImage<F>>,
149
150    pub since_last_segment_check: usize,
151    pub trace_height_constraints: Vec<LinearConstraint>,
152
153    /// Air names for debug purposes only.
154    pub(crate) air_names: Vec<String>,
155    /// Metrics collected for this execution segment alone.
156    #[cfg(feature = "bench-metrics")]
157    pub metrics: VmMetrics,
158}
159
/// Where execution of a segment stopped.
pub struct ExecutionSegmentState {
    /// Program counter at the stopping point. This is the pc to resume from after a
    /// segmentation break, or the pc of the `TERMINATE` instruction.
    pub pc: u32,
    /// `true` when execution stopped on a `TERMINATE` instruction. `false` when it stopped
    /// because a new segment was requested.
    pub is_terminated: bool,
}
164
165impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
166    /// Creates a new execution segment from a program and initial state, using parent VM config
167    pub fn new(
168        config: &VC,
169        program: Program<F>,
170        init_streams: Streams<F>,
171        initial_memory: Option<MemoryImage<F>>,
172        trace_height_constraints: Vec<LinearConstraint>,
173        #[allow(unused_variables)] fn_bounds: FnBounds,
174    ) -> Self {
175        let mut chip_complex = config.create_chip_complex().unwrap();
176        chip_complex.set_streams(init_streams);
177        let program = if !config.system().profiling {
178            program.strip_debug_infos()
179        } else {
180            program
181        };
182        chip_complex.set_program(program);
183
184        if let Some(initial_memory) = initial_memory {
185            chip_complex.set_initial_memory(initial_memory);
186        }
187        let air_names = chip_complex.air_names();
188
189        Self {
190            chip_complex,
191            final_memory: None,
192            air_names,
193            trace_height_constraints,
194            #[cfg(feature = "bench-metrics")]
195            metrics: VmMetrics {
196                fn_bounds,
197                ..Default::default()
198            },
199            since_last_segment_check: 0,
200        }
201    }
202
203    pub fn system_config(&self) -> &SystemConfig {
204        self.chip_complex.config()
205    }
206
207    pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) {
208        self.chip_complex
209            .set_override_system_trace_heights(overridden_heights.system);
210        self.chip_complex
211            .set_override_inventory_trace_heights(overridden_heights.inventory);
212    }
213
214    /// Stopping is triggered by should_segment()
215    pub fn execute_from_pc(
216        &mut self,
217        mut pc: u32,
218    ) -> Result<ExecutionSegmentState, ExecutionError> {
219        let mut timestamp = self.chip_complex.memory_controller().timestamp();
220        let mut prev_backtrace: Option<Backtrace> = None;
221
222        self.chip_complex
223            .connector_chip_mut()
224            .begin(ExecutionState::new(pc, timestamp));
225
226        let mut did_terminate = false;
227
228        loop {
229            #[allow(unused_variables)]
230            let (opcode, dsl_instr) = {
231                let Self {
232                    chip_complex,
233                    #[cfg(feature = "bench-metrics")]
234                    metrics,
235                    ..
236                } = self;
237                let SystemBase {
238                    program_chip,
239                    memory_controller,
240                    ..
241                } = &mut chip_complex.base;
242
243                let (instruction, debug_info) = program_chip.get_instruction(pc)?;
244                tracing::trace!("pc: {pc:#x} | time: {timestamp} | {:?}", instruction);
245
246                #[allow(unused_variables)]
247                let (dsl_instr, trace) = debug_info.as_ref().map_or(
248                    (None, None),
249                    |DebugInfo {
250                         dsl_instruction,
251                         trace,
252                     }| (Some(dsl_instruction), trace.as_ref()),
253                );
254
255                let &Instruction { opcode, c, .. } = instruction;
256                if opcode == SystemOpcode::TERMINATE.global_opcode() {
257                    did_terminate = true;
258                    self.chip_complex.connector_chip_mut().end(
259                        ExecutionState::new(pc, timestamp),
260                        Some(c.as_canonical_u32()),
261                    );
262                    break;
263                }
264
265                // Some phantom instruction handling is more convenient to do here than in PhantomChip.
266                if opcode == SystemOpcode::PHANTOM.global_opcode() {
267                    // Note: the discriminant is the lower 16 bits of the c operand.
268                    let discriminant = c.as_canonical_u32() as u16;
269                    let phantom = SysPhantom::from_repr(discriminant);
270                    tracing::trace!("pc: {pc:#x} | system phantom: {phantom:?}");
271                    match phantom {
272                        Some(SysPhantom::DebugPanic) => {
273                            if let Some(mut backtrace) = prev_backtrace {
274                                backtrace.resolve();
275                                eprintln!("openvm program failure; backtrace:\n{:?}", backtrace);
276                            } else {
277                                eprintln!("openvm program failure; no backtrace");
278                            }
279                            return Err(ExecutionError::Fail { pc });
280                        }
281                        Some(SysPhantom::CtStart) =>
282                        {
283                            #[cfg(feature = "bench-metrics")]
284                            metrics
285                                .cycle_tracker
286                                .start(dsl_instr.cloned().unwrap_or("Default".to_string()))
287                        }
288                        Some(SysPhantom::CtEnd) =>
289                        {
290                            #[cfg(feature = "bench-metrics")]
291                            metrics
292                                .cycle_tracker
293                                .end(dsl_instr.cloned().unwrap_or("Default".to_string()))
294                        }
295                        _ => {}
296                    }
297                }
298                prev_backtrace = trace.cloned();
299
300                if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) {
301                    let next_state = InstructionExecutor::execute(
302                        executor,
303                        memory_controller,
304                        instruction,
305                        ExecutionState::new(pc, timestamp),
306                    )?;
307                    assert!(next_state.timestamp > timestamp);
308                    pc = next_state.pc;
309                    timestamp = next_state.timestamp;
310                } else {
311                    return Err(ExecutionError::DisabledOperation { pc, opcode });
312                };
313                (opcode, dsl_instr.cloned())
314            };
315
316            #[cfg(feature = "bench-metrics")]
317            self.update_instruction_metrics(pc, opcode, dsl_instr);
318
319            if self.should_segment() {
320                self.chip_complex
321                    .connector_chip_mut()
322                    .end(ExecutionState::new(pc, timestamp), None);
323                break;
324            }
325        }
326        self.final_memory = Some(
327            self.chip_complex
328                .base
329                .memory_controller
330                .memory_image()
331                .clone(),
332        );
333
334        Ok(ExecutionSegmentState {
335            pc,
336            is_terminated: did_terminate,
337        })
338    }
339
340    /// Generate ProofInput to prove the segment. Should be called after ::execute
341    pub fn generate_proof_input<SC: StarkGenericConfig>(
342        #[allow(unused_mut)] mut self,
343        cached_program: Option<CommittedTraceData<SC>>,
344    ) -> Result<ProofInput<SC>, GenerationError>
345    where
346        Domain<SC>: PolynomialSpace<Val = F>,
347        VC::Executor: Chip<SC>,
348        VC::Periphery: Chip<SC>,
349    {
350        metrics_span("trace_gen_time_ms", || {
351            self.chip_complex.generate_proof_input(
352                cached_program,
353                &self.trace_height_constraints,
354                #[cfg(feature = "bench-metrics")]
355                &mut self.metrics,
356            )
357        })
358    }
359
360    /// Returns bool of whether to switch to next segment or not. This is called every clock cycle inside of Core trace generation.
361    fn should_segment(&mut self) -> bool {
362        if !self.system_config().continuation_enabled {
363            return false;
364        }
365        // Avoid checking segment too often.
366        if self.since_last_segment_check != SEGMENT_CHECK_INTERVAL {
367            self.since_last_segment_check += 1;
368            return false;
369        }
370        self.since_last_segment_check = 0;
371        let segmentation_strategy = &self.system_config().segmentation_strategy;
372        segmentation_strategy.should_segment(
373            &self.air_names,
374            &self
375                .chip_complex
376                .dynamic_trace_heights()
377                .collect::<Vec<_>>(),
378            &self.chip_complex.current_trace_cells(),
379        )
380    }
381
382    pub fn current_trace_cells(&self) -> Vec<usize> {
383        self.chip_complex.current_trace_cells()
384    }
385}