// openvm_circuit/arch/execution_mode/metered/segment_ctx.rs

1use getset::WithSetters;
2use openvm_stark_backend::p3_field::PrimeField32;
3use p3_baby_bear::BabyBear;
4use serde::{Deserialize, Serialize};
5
6pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;
7
8pub const DEFAULT_MAX_TRACE_HEIGHT_BITS: u8 = 22;
9pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = 1 << DEFAULT_MAX_TRACE_HEIGHT_BITS;
10pub const DEFAULT_MAX_CELLS: usize = 1_200_000_000; // 1.2B
11const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
12
13#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)]
14pub struct Segment {
15    pub instret_start: u64,
16    pub num_insns: u64,
17    pub trace_heights: Vec<u32>,
18}
19
20#[derive(Clone, Copy, Debug, WithSetters)]
21pub struct SegmentationLimits {
22    #[getset(set_with = "pub")]
23    pub max_trace_height: u32,
24    #[getset(set_with = "pub")]
25    pub max_cells: usize,
26    #[getset(set_with = "pub")]
27    pub max_interactions: usize,
28}
29
30impl Default for SegmentationLimits {
31    fn default() -> Self {
32        Self {
33            max_trace_height: DEFAULT_MAX_TRACE_HEIGHT,
34            max_cells: DEFAULT_MAX_CELLS,
35            max_interactions: DEFAULT_MAX_INTERACTIONS,
36        }
37    }
38}
39
40impl SegmentationLimits {
41    pub fn new(max_trace_height: u32, max_cells: usize, max_interactions: usize) -> Self {
42        debug_assert!(
43            max_trace_height.is_power_of_two(),
44            "max_trace_height should be a power of two"
45        );
46        Self {
47            max_trace_height,
48            max_cells,
49            max_interactions,
50        }
51    }
52
53    pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
54        debug_assert!(
55            max_trace_height.is_power_of_two(),
56            "max_trace_height should be a power of two"
57        );
58        self.max_trace_height = max_trace_height;
59    }
60}
61
62#[derive(Clone, Debug, WithSetters)]
63pub struct SegmentationCtx {
64    pub segments: Vec<Segment>,
65    pub(crate) air_names: Vec<String>,
66    pub(crate) widths: Vec<usize>,
67    interactions: Vec<usize>,
68    pub(crate) segmentation_limits: SegmentationLimits,
69    pub instret: u64,
70    pub instrets_until_check: u64,
71    pub(super) segment_check_insns: u64,
72    /// Checkpoint of trace heights at last known state where all thresholds satisfied
73    pub(crate) checkpoint_trace_heights: Vec<u32>,
74    /// Instruction count at the checkpoint
75    checkpoint_instret: u64,
76}
77
78impl SegmentationCtx {
79    pub fn new(
80        air_names: Vec<String>,
81        widths: Vec<usize>,
82        interactions: Vec<usize>,
83        segmentation_limits: SegmentationLimits,
84    ) -> Self {
85        assert_eq!(air_names.len(), widths.len());
86        assert_eq!(air_names.len(), interactions.len());
87
88        let num_airs = air_names.len();
89        Self {
90            segments: Vec::new(),
91            air_names,
92            widths,
93            interactions,
94            segmentation_limits,
95            instret: 0,
96            instrets_until_check: DEFAULT_SEGMENT_CHECK_INSNS,
97            segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
98            checkpoint_trace_heights: vec![0; num_airs],
99            checkpoint_instret: 0,
100        }
101    }
102
103    pub fn new_with_default_segmentation_limits(
104        air_names: Vec<String>,
105        widths: Vec<usize>,
106        interactions: Vec<usize>,
107    ) -> Self {
108        assert_eq!(air_names.len(), widths.len());
109        assert_eq!(air_names.len(), interactions.len());
110
111        let num_airs = air_names.len();
112        Self {
113            segments: Vec::new(),
114            air_names,
115            widths,
116            interactions,
117            segmentation_limits: SegmentationLimits::default(),
118            instret: 0,
119            instrets_until_check: DEFAULT_SEGMENT_CHECK_INSNS,
120            segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
121            checkpoint_trace_heights: vec![0; num_airs],
122            checkpoint_instret: 0,
123        }
124    }
125
126    pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
127        self.segmentation_limits
128            .set_max_trace_height(max_trace_height);
129    }
130
131    pub fn set_max_cells(&mut self, max_cells: usize) {
132        self.segmentation_limits.max_cells = max_cells;
133    }
134
135    pub fn set_max_interactions(&mut self, max_interactions: usize) {
136        self.segmentation_limits.max_interactions = max_interactions;
137    }
138
139    /// Calculate the maximum trace height and corresponding air name
140    #[inline(always)]
141    fn calculate_max_trace_height_with_name(&self, trace_heights: &[u32]) -> (u32, &str) {
142        trace_heights
143            .iter()
144            .enumerate()
145            .map(|(i, &height)| (height.next_power_of_two(), i))
146            .max_by_key(|(height, _)| *height)
147            .map(|(height, idx)| (height, self.air_names[idx].as_str()))
148            .unwrap_or((0, "unknown"))
149    }
150
151    /// Calculate the total cells used based on trace heights and widths
152    #[inline(always)]
153    fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
154        debug_assert_eq!(trace_heights.len(), self.widths.len());
155
156        trace_heights
157            .iter()
158            .zip(self.widths.iter())
159            .map(|(&height, &width)| height.next_power_of_two() as usize * width)
160            .sum()
161    }
162
163    /// Calculate the total interactions based on trace heights
164    /// All padding rows contribute a single message to the interactions (+1) since
165    /// we assume chips don't send/receive with nonzero multiplicity on padding rows.
166    #[inline(always)]
167    fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize {
168        debug_assert_eq!(trace_heights.len(), self.interactions.len());
169
170        trace_heights
171            .iter()
172            .zip(self.interactions.iter())
173            .map(|(&height, &interactions)| (height + 1) as usize * interactions)
174            .sum()
175    }
176
177    #[inline(always)]
178    pub(crate) fn should_segment(
179        &self,
180        instret: u64,
181        trace_heights: &[u32],
182        is_trace_height_constant: &[bool],
183    ) -> bool {
184        debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len());
185        debug_assert_eq!(trace_heights.len(), self.air_names.len());
186        debug_assert_eq!(trace_heights.len(), self.widths.len());
187        debug_assert_eq!(trace_heights.len(), self.interactions.len());
188
189        let instret_start = self
190            .segments
191            .last()
192            .map_or(0, |s| s.instret_start + s.num_insns);
193        let num_insns = instret - instret_start;
194
195        // Segment should contain at least one cycle
196        if num_insns == 0 {
197            return false;
198        }
199
200        let mut total_cells = 0;
201        for (i, ((padded_height, width), is_constant)) in trace_heights
202            .iter()
203            .map(|&height| height.next_power_of_two())
204            .zip(self.widths.iter())
205            .zip(is_trace_height_constant.iter())
206            .enumerate()
207        {
208            // Only segment if the height is not constant and exceeds the maximum height after
209            // padding
210            if !is_constant && padded_height > self.segmentation_limits.max_trace_height {
211                let air_name = unsafe { self.air_names.get_unchecked(i) };
212                tracing::info!(
213                    "instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) ",
214                    instret,
215                    padded_height,
216                    self.segmentation_limits.max_trace_height,
217                    i,
218                    air_name,
219                );
220                return true;
221            }
222            total_cells += padded_height as usize * width;
223        }
224
225        if total_cells > self.segmentation_limits.max_cells {
226            tracing::info!(
227                "instret {:10} | total cells ({:10}) > max ({:10})",
228                instret,
229                total_cells,
230                self.segmentation_limits.max_cells
231            );
232            return true;
233        }
234
235        let total_interactions = self.calculate_total_interactions(trace_heights);
236        if total_interactions > self.segmentation_limits.max_interactions {
237            tracing::info!(
238                "instret {:10} | total interactions ({:10}) > max ({:10})",
239                instret,
240                total_interactions,
241                self.segmentation_limits.max_interactions
242            );
243            return true;
244        }
245
246        false
247    }
248
249    #[inline(always)]
250    pub fn check_and_segment(
251        &mut self,
252        instret: u64,
253        trace_heights: &mut [u32],
254        is_trace_height_constant: &[bool],
255    ) -> bool {
256        let should_seg = self.should_segment(instret, trace_heights, is_trace_height_constant);
257
258        if should_seg {
259            self.create_segment_from_checkpoint(instret, trace_heights);
260        }
261        should_seg
262    }
263
264    #[inline(always)]
265    fn create_segment_from_checkpoint(&mut self, instret: u64, trace_heights: &mut [u32]) {
266        let instret_start = self
267            .segments
268            .last()
269            .map_or(0, |s| s.instret_start + s.num_insns);
270
271        let (segment_instret, segment_heights) = if self.checkpoint_instret > instret_start {
272            (
273                self.checkpoint_instret,
274                self.checkpoint_trace_heights.clone(),
275            )
276        } else {
277            let trace_heights_str = trace_heights
278                .iter()
279                .zip(self.air_names.iter())
280                .filter(|(&height, _)| height > 0)
281                .map(|(&height, name)| format!("  {name} = {height}"))
282                .collect::<Vec<_>>()
283                .join("\n");
284            tracing::warn!(
285                "No valid checkpoint, creating segment using instret={instret}\ntrace_heights=[\n{trace_heights_str}\n]"
286            );
287            // No valid checkpoint, use current values
288            (instret, trace_heights.to_vec())
289        };
290
291        let num_insns = segment_instret - instret_start;
292        self.create_segment::<false>(instret_start, num_insns, segment_heights);
293    }
294
295    /// Initialize state for a new segment
296    #[inline(always)]
297    pub(crate) fn initialize_segment(
298        &mut self,
299        trace_heights: &mut [u32],
300        is_trace_height_constant: &[bool],
301    ) {
302        // Reset trace heights by subtracting the last segment's heights
303        let last_segment = self.segments.last().unwrap();
304        self.reset_trace_heights(
305            trace_heights,
306            &last_segment.trace_heights,
307            is_trace_height_constant,
308        );
309    }
310
311    /// Resets trace heights by subtracting segment heights
312    #[inline(always)]
313    fn reset_trace_heights(
314        &self,
315        trace_heights: &mut [u32],
316        segment_heights: &[u32],
317        is_trace_height_constant: &[bool],
318    ) {
319        for ((trace_height, &segment_height), &is_trace_height_constant) in trace_heights
320            .iter_mut()
321            .zip(segment_heights.iter())
322            .zip(is_trace_height_constant.iter())
323        {
324            if !is_trace_height_constant {
325                *trace_height = trace_height.checked_sub(segment_height).unwrap();
326            }
327        }
328    }
329
330    /// Updates the checkpoint with current safe state
331    #[inline(always)]
332    pub(crate) fn update_checkpoint(&mut self, instret: u64, trace_heights: &[u32]) {
333        self.checkpoint_trace_heights.copy_from_slice(trace_heights);
334        self.checkpoint_instret = instret;
335    }
336
337    /// Try segment if there is at least one instruction
338    #[inline(always)]
339    pub fn create_final_segment(&mut self, trace_heights: &[u32]) {
340        self.instret += self.segment_check_insns - self.instrets_until_check;
341        self.instrets_until_check = self.segment_check_insns;
342        let instret_start = self
343            .segments
344            .last()
345            .map_or(0, |s| s.instret_start + s.num_insns);
346
347        let num_insns = self.instret - instret_start;
348        self.create_segment::<true>(instret_start, num_insns, trace_heights.to_vec());
349    }
350
351    /// Push a new segment with logging
352    #[inline(always)]
353    fn create_segment<const IS_FINAL: bool>(
354        &mut self,
355        instret_start: u64,
356        num_insns: u64,
357        trace_heights: Vec<u32>,
358    ) {
359        debug_assert!(
360            num_insns > 0,
361            "Segment should contain at least one instruction"
362        );
363
364        self.log_segment_info::<IS_FINAL>(instret_start, num_insns, &trace_heights);
365        self.segments.push(Segment {
366            instret_start,
367            num_insns,
368            trace_heights,
369        });
370    }
371
372    /// Log segment information
373    #[inline(always)]
374    fn log_segment_info<const IS_FINAL: bool>(
375        &self,
376        instret_start: u64,
377        num_insns: u64,
378        trace_heights: &[u32],
379    ) {
380        let (max_trace_height, air_name) = self.calculate_max_trace_height_with_name(trace_heights);
381        let total_cells = self.calculate_total_cells(trace_heights);
382        let total_interactions = self.calculate_total_interactions(trace_heights);
383
384        let final_marker = if IS_FINAL { " [TERMINATED]" } else { "" };
385
386        tracing::info!(
387            "Segment {:3} | instret {:10} | {:8} instructions | {:10} cells | {:10} interactions | {:8} max height ({}){}",
388            self.segments.len(),
389            instret_start,
390            num_insns,
391            total_cells,
392            total_interactions,
393            max_trace_height,
394            air_name,
395            final_marker
396        );
397    }
398}