openvm_circuit/arch/execution_mode/metered/
segment_ctx.rs

1use getset::WithSetters;
2use openvm_stark_backend::p3_field::PrimeField32;
3use p3_baby_bear::BabyBear;
4use serde::{Deserialize, Serialize};
5
6pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;
7
8pub const DEFAULT_MAX_TRACE_HEIGHT_BITS: u8 = 22;
9pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = 1 << DEFAULT_MAX_TRACE_HEIGHT_BITS;
10pub const DEFAULT_MAX_CELLS: usize = 1_200_000_000; // 1.2B
11const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
12
13#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)]
14pub struct Segment {
15    pub instret_start: u64,
16    pub num_insns: u64,
17    pub trace_heights: Vec<u32>,
18}
19
20#[derive(Clone, Copy, Debug, WithSetters)]
21pub struct SegmentationLimits {
22    #[getset(set_with = "pub")]
23    pub max_trace_height: u32,
24    #[getset(set_with = "pub")]
25    pub max_cells: usize,
26    #[getset(set_with = "pub")]
27    pub max_interactions: usize,
28}
29
30impl Default for SegmentationLimits {
31    fn default() -> Self {
32        Self {
33            max_trace_height: DEFAULT_MAX_TRACE_HEIGHT,
34            max_cells: DEFAULT_MAX_CELLS,
35            max_interactions: DEFAULT_MAX_INTERACTIONS,
36        }
37    }
38}
39
40#[derive(Clone, Debug, WithSetters)]
41pub struct SegmentationCtx {
42    pub segments: Vec<Segment>,
43    pub(crate) air_names: Vec<String>,
44    pub(crate) widths: Vec<usize>,
45    interactions: Vec<usize>,
46    pub(crate) segmentation_limits: SegmentationLimits,
47    pub instret_last_segment_check: u64,
48    #[getset(set_with = "pub")]
49    pub segment_check_insns: u64,
50    /// Checkpoint of trace heights at last known state where all thresholds satisfied
51    pub(crate) checkpoint_trace_heights: Vec<u32>,
52    /// Instruction count at the checkpoint
53    checkpoint_instret: u64,
54}
55
56impl SegmentationCtx {
57    pub fn new(
58        air_names: Vec<String>,
59        widths: Vec<usize>,
60        interactions: Vec<usize>,
61        segmentation_limits: SegmentationLimits,
62    ) -> Self {
63        assert_eq!(air_names.len(), widths.len());
64        assert_eq!(air_names.len(), interactions.len());
65
66        let num_airs = air_names.len();
67        Self {
68            segments: Vec::new(),
69            air_names,
70            widths,
71            interactions,
72            segmentation_limits,
73            segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
74            instret_last_segment_check: 0,
75            checkpoint_trace_heights: vec![0; num_airs],
76            checkpoint_instret: 0,
77        }
78    }
79
80    pub fn new_with_default_segmentation_limits(
81        air_names: Vec<String>,
82        widths: Vec<usize>,
83        interactions: Vec<usize>,
84    ) -> Self {
85        assert_eq!(air_names.len(), widths.len());
86        assert_eq!(air_names.len(), interactions.len());
87
88        let num_airs = air_names.len();
89        Self {
90            segments: Vec::new(),
91            air_names,
92            widths,
93            interactions,
94            segmentation_limits: SegmentationLimits::default(),
95            segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
96            instret_last_segment_check: 0,
97            checkpoint_trace_heights: vec![0; num_airs],
98            checkpoint_instret: 0,
99        }
100    }
101
102    pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
103        debug_assert!(
104            max_trace_height.is_power_of_two(),
105            "max_trace_height should be a power of two"
106        );
107        self.segmentation_limits.max_trace_height = max_trace_height;
108    }
109
110    pub fn set_max_cells(&mut self, max_cells: usize) {
111        self.segmentation_limits.max_cells = max_cells;
112    }
113
114    pub fn set_max_interactions(&mut self, max_interactions: usize) {
115        self.segmentation_limits.max_interactions = max_interactions;
116    }
117
118    /// Calculate the maximum trace height and corresponding air name
119    #[inline(always)]
120    fn calculate_max_trace_height_with_name(&self, trace_heights: &[u32]) -> (u32, &str) {
121        trace_heights
122            .iter()
123            .enumerate()
124            .map(|(i, &height)| (height.next_power_of_two(), i))
125            .max_by_key(|(height, _)| *height)
126            .map(|(height, idx)| (height, self.air_names[idx].as_str()))
127            .unwrap_or((0, "unknown"))
128    }
129
130    /// Calculate the total cells used based on trace heights and widths
131    #[inline(always)]
132    fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
133        debug_assert_eq!(trace_heights.len(), self.widths.len());
134
135        trace_heights
136            .iter()
137            .zip(self.widths.iter())
138            .map(|(&height, &width)| height.next_power_of_two() as usize * width)
139            .sum()
140    }
141
142    /// Calculate the total interactions based on trace heights
143    /// All padding rows contribute a single message to the interactions (+1) since
144    /// we assume chips don't send/receive with nonzero multiplicity on padding rows.
145    #[inline(always)]
146    fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize {
147        debug_assert_eq!(trace_heights.len(), self.interactions.len());
148
149        trace_heights
150            .iter()
151            .zip(self.interactions.iter())
152            .map(|(&height, &interactions)| (height + 1) as usize * interactions)
153            .sum()
154    }
155
156    #[inline(always)]
157    fn should_segment(
158        &self,
159        instret: u64,
160        trace_heights: &[u32],
161        is_trace_height_constant: &[bool],
162    ) -> bool {
163        debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len());
164        debug_assert_eq!(trace_heights.len(), self.air_names.len());
165        debug_assert_eq!(trace_heights.len(), self.widths.len());
166        debug_assert_eq!(trace_heights.len(), self.interactions.len());
167
168        let instret_start = self
169            .segments
170            .last()
171            .map_or(0, |s| s.instret_start + s.num_insns);
172        let num_insns = instret - instret_start;
173
174        // Segment should contain at least one cycle
175        if num_insns == 0 {
176            return false;
177        }
178
179        let mut total_cells = 0;
180        for (i, ((padded_height, width), is_constant)) in trace_heights
181            .iter()
182            .map(|&height| height.next_power_of_two())
183            .zip(self.widths.iter())
184            .zip(is_trace_height_constant.iter())
185            .enumerate()
186        {
187            // Only segment if the height is not constant and exceeds the maximum height after
188            // padding
189            if !is_constant && padded_height > self.segmentation_limits.max_trace_height {
190                let air_name = unsafe { self.air_names.get_unchecked(i) };
191                tracing::info!(
192                    "instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) ",
193                    instret,
194                    i,
195                    air_name,
196                    padded_height,
197                    self.segmentation_limits.max_trace_height
198                );
199                return true;
200            }
201            total_cells += padded_height as usize * width;
202        }
203
204        if total_cells > self.segmentation_limits.max_cells {
205            tracing::info!(
206                "instret {:10} | total cells ({:10}) > max ({:10})",
207                instret,
208                total_cells,
209                self.segmentation_limits.max_cells
210            );
211            return true;
212        }
213
214        let total_interactions = self.calculate_total_interactions(trace_heights);
215        if total_interactions > self.segmentation_limits.max_interactions {
216            tracing::info!(
217                "instret {:10} | total interactions ({:10}) > max ({:10})",
218                instret,
219                total_interactions,
220                self.segmentation_limits.max_interactions
221            );
222            return true;
223        }
224
225        false
226    }
227
228    #[inline(always)]
229    pub fn check_and_segment(
230        &mut self,
231        instret: u64,
232        trace_heights: &mut [u32],
233        is_trace_height_constant: &[bool],
234    ) -> bool {
235        let should_seg = self.should_segment(instret, trace_heights, is_trace_height_constant);
236
237        if should_seg {
238            self.create_segment_from_checkpoint(instret, trace_heights, is_trace_height_constant);
239        } else {
240            self.update_checkpoint(instret, trace_heights);
241        }
242
243        self.instret_last_segment_check = instret;
244        should_seg
245    }
246
247    #[inline(always)]
248    fn create_segment_from_checkpoint(
249        &mut self,
250        instret: u64,
251        trace_heights: &mut [u32],
252        is_trace_height_constant: &[bool],
253    ) {
254        let instret_start = self
255            .segments
256            .last()
257            .map_or(0, |s| s.instret_start + s.num_insns);
258
259        let (segment_instret, segment_heights) = if self.checkpoint_instret > instret_start {
260            (
261                self.checkpoint_instret,
262                self.checkpoint_trace_heights.clone(),
263            )
264        } else {
265            tracing::warn!(
266                "No valid checkpoint, creating segment using instret={instret}, trace_heights={:?}",
267                trace_heights
268            );
269            // No valid checkpoint, use current values
270            (instret, trace_heights.to_vec())
271        };
272
273        // Reset current trace heights and checkpoint
274        self.reset_trace_heights(trace_heights, &segment_heights, is_trace_height_constant);
275        self.checkpoint_instret = 0;
276
277        let num_insns = segment_instret - instret_start;
278        self.create_segment::<false>(instret_start, num_insns, segment_heights);
279    }
280
281    /// Resets trace heights by subtracting segment heights
282    #[inline(always)]
283    fn reset_trace_heights(
284        &self,
285        trace_heights: &mut [u32],
286        segment_heights: &[u32],
287        is_trace_height_constant: &[bool],
288    ) {
289        for ((trace_height, &segment_height), &is_trace_height_constant) in trace_heights
290            .iter_mut()
291            .zip(segment_heights.iter())
292            .zip(is_trace_height_constant.iter())
293        {
294            if !is_trace_height_constant {
295                *trace_height = trace_height.checked_sub(segment_height).unwrap();
296            }
297        }
298    }
299
300    /// Updates the checkpoint with current safe state
301    #[inline(always)]
302    fn update_checkpoint(&mut self, instret: u64, trace_heights: &[u32]) {
303        self.checkpoint_trace_heights.copy_from_slice(trace_heights);
304        self.checkpoint_instret = instret;
305    }
306
307    /// Try segment if there is at least one instruction
308    #[inline(always)]
309    pub fn create_final_segment(&mut self, instret: u64, trace_heights: &[u32]) {
310        let instret_start = self
311            .segments
312            .last()
313            .map_or(0, |s| s.instret_start + s.num_insns);
314
315        let num_insns = instret - instret_start;
316        self.create_segment::<true>(instret_start, num_insns, trace_heights.to_vec());
317    }
318
319    /// Push a new segment with logging
320    #[inline(always)]
321    fn create_segment<const IS_FINAL: bool>(
322        &mut self,
323        instret_start: u64,
324        num_insns: u64,
325        trace_heights: Vec<u32>,
326    ) {
327        debug_assert!(
328            num_insns > 0,
329            "Segment should contain at least one instruction"
330        );
331
332        self.log_segment_info::<IS_FINAL>(instret_start, num_insns, &trace_heights);
333        self.segments.push(Segment {
334            instret_start,
335            num_insns,
336            trace_heights,
337        });
338    }
339
340    /// Log segment information
341    #[inline(always)]
342    fn log_segment_info<const IS_FINAL: bool>(
343        &self,
344        instret_start: u64,
345        num_insns: u64,
346        trace_heights: &[u32],
347    ) {
348        let (max_trace_height, air_name) = self.calculate_max_trace_height_with_name(trace_heights);
349        let total_cells = self.calculate_total_cells(trace_heights);
350        let total_interactions = self.calculate_total_interactions(trace_heights);
351
352        let final_marker = if IS_FINAL { " [TERMINATED]" } else { "" };
353
354        tracing::info!(
355            "Segment {:3} | instret {:10} | {:8} instructions | {:10} cells | {:10} interactions | {:8} max height ({}){}",
356            self.segments.len(),
357            instret_start,
358            num_insns,
359            total_cells,
360            total_interactions,
361            max_trace_height,
362            air_name,
363            final_marker
364        );
365    }
366}