openvm_circuit/arch/execution_mode/metered/
segment_ctx.rs

1use getset::WithSetters;
2use openvm_stark_backend::p3_field::PrimeField32;
3use p3_baby_bear::BabyBear;
4use serde::{Deserialize, Serialize};
5
/// Default value of `SegmentationCtx::segment_check_insns`.
///
/// NOTE(review): presumably the number of instructions executed between two segmentation
/// checks; the code that consumes this value is not in this file section — confirm.
pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;

/// Default value of `SegmentationLimits::max_trace_height`: `2^23 - 10000`.
///
/// NOTE(review): the 10000 subtracted from the power of two looks like headroom below
/// `2^23`; the reason is not stated in this file — confirm.
pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = (1 << 23) - 10000;
/// Default value of `SegmentationLimits::max_cells`.
pub const DEFAULT_MAX_CELLS: usize = 2_000_000_000; // 2B
/// Default value of `SegmentationLimits::max_interactions`: `BabyBear::ORDER_U32`.
const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
11
/// One recorded segment: a contiguous range of instructions together with the trace heights
/// that were supplied when the segment was closed.
#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)]
pub struct Segment {
    /// Instruction count at which the segment begins. `SegmentationCtx::segment` sets this to
    /// the end of the previously recorded segment, or 0 for the first one.
    pub instret_start: u64,
    /// Number of instructions covered by the segment.
    pub num_insns: u64,
    /// Copy of the trace heights passed to `SegmentationCtx::segment`, one entry per AIR.
    pub trace_heights: Vec<u32>,
}
18
/// Thresholds that `SegmentationCtx::should_segment` compares against to decide whether the
/// current segment must be closed.
#[derive(Clone, Copy, Debug, WithSetters)]
pub struct SegmentationLimits {
    /// Maximum height allowed for any single trace whose height is not marked constant.
    #[getset(set_with = "pub")]
    pub max_trace_height: u32,
    /// Maximum total number of cells, computed as the sum of `height * width` over all traces.
    #[getset(set_with = "pub")]
    pub max_cells: usize,
    /// Maximum total number of interactions, computed as the sum of
    /// `(height + 1) * interactions` over all traces.
    #[getset(set_with = "pub")]
    pub max_interactions: usize,
}
28
29impl Default for SegmentationLimits {
30    fn default() -> Self {
31        Self {
32            max_trace_height: DEFAULT_MAX_TRACE_HEIGHT,
33            max_cells: DEFAULT_MAX_CELLS,
34            max_interactions: DEFAULT_MAX_INTERACTIONS,
35        }
36    }
37}
38
/// State used to decide where execution is split into segments, plus the list of segments
/// recorded so far.
///
/// `air_names`, `widths` and `interactions` are parallel vectors with one entry per AIR; the
/// constructors assert that they have equal lengths.
#[derive(Clone, Debug, WithSetters)]
pub struct SegmentationCtx {
    /// Segments recorded so far, in the order they were closed.
    pub segments: Vec<Segment>,
    /// Name of each AIR; used only in the log message emitted when a height limit is exceeded.
    pub(crate) air_names: Vec<String>,
    /// Trace width of each AIR, multiplied by the trace height to count cells.
    widths: Vec<usize>,
    /// Interaction count of each AIR, multiplied by `height + 1` to count interactions.
    interactions: Vec<usize>,
    /// Limits that trigger segmentation when exceeded.
    pub(crate) segmentation_limits: SegmentationLimits,
    /// The `instret` value passed to the most recent `check_and_segment` call (0 initially).
    pub instret_last_segment_check: u64,
    /// Initialised to `DEFAULT_SEGMENT_CHECK_INSNS`; not read inside this file section.
    /// NOTE(review): presumably the instruction interval between segmentation checks — confirm
    /// against the caller.
    #[getset(set_with = "pub")]
    pub segment_check_insns: u64,
}
50
51impl SegmentationCtx {
52    pub fn new(
53        air_names: Vec<String>,
54        widths: Vec<usize>,
55        interactions: Vec<usize>,
56        segmentation_limits: SegmentationLimits,
57    ) -> Self {
58        assert_eq!(air_names.len(), widths.len());
59        assert_eq!(air_names.len(), interactions.len());
60
61        Self {
62            segments: Vec::new(),
63            air_names,
64            widths,
65            interactions,
66            segmentation_limits,
67            segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
68            instret_last_segment_check: 0,
69        }
70    }
71
72    pub fn new_with_default_segmentation_limits(
73        air_names: Vec<String>,
74        widths: Vec<usize>,
75        interactions: Vec<usize>,
76    ) -> Self {
77        assert_eq!(air_names.len(), widths.len());
78        assert_eq!(air_names.len(), interactions.len());
79
80        Self {
81            segments: Vec::new(),
82            air_names,
83            widths,
84            interactions,
85            segmentation_limits: SegmentationLimits::default(),
86            segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
87            instret_last_segment_check: 0,
88        }
89    }
90
91    pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
92        self.segmentation_limits.max_trace_height = max_trace_height;
93    }
94
95    pub fn set_max_cells(&mut self, max_cells: usize) {
96        self.segmentation_limits.max_cells = max_cells;
97    }
98
99    pub fn set_max_interactions(&mut self, max_interactions: usize) {
100        self.segmentation_limits.max_interactions = max_interactions;
101    }
102
103    /// Calculate the total cells used based on trace heights and widths
104    #[inline(always)]
105    fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
106        debug_assert_eq!(trace_heights.len(), self.widths.len());
107
108        // SAFETY: Length equality is asserted during initialization
109        let widths_slice = unsafe { self.widths.get_unchecked(..trace_heights.len()) };
110
111        trace_heights
112            .iter()
113            .zip(widths_slice)
114            .map(|(&height, &width)| height as usize * width)
115            .sum()
116    }
117
118    /// Calculate the total interactions based on trace heights and interaction counts
119    #[inline(always)]
120    fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize {
121        debug_assert_eq!(trace_heights.len(), self.interactions.len());
122
123        // SAFETY: Length equality is asserted during initialization
124        let interactions_slice = unsafe { self.interactions.get_unchecked(..trace_heights.len()) };
125
126        trace_heights
127            .iter()
128            .zip(interactions_slice)
129            // We add 1 for the zero messages from the padding rows
130            .map(|(&height, &interactions)| (height + 1) as usize * interactions)
131            .sum()
132    }
133
134    #[inline(always)]
135    fn should_segment(
136        &self,
137        instret: u64,
138        trace_heights: &[u32],
139        is_trace_height_constant: &[bool],
140    ) -> bool {
141        debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len());
142        debug_assert_eq!(trace_heights.len(), self.air_names.len());
143
144        let instret_start = self
145            .segments
146            .last()
147            .map_or(0, |s| s.instret_start + s.num_insns);
148        let num_insns = instret - instret_start;
149
150        // Segment should contain at least one cycle
151        if num_insns == 0 {
152            return false;
153        }
154
155        for (i, (height, is_constant)) in trace_heights
156            .iter()
157            .zip(is_trace_height_constant.iter())
158            .enumerate()
159        {
160            // Only segment if the height is not constant and exceeds the maximum height
161            if !is_constant && *height > self.segmentation_limits.max_trace_height {
162                let air_name = &self.air_names[i];
163                tracing::info!(
164                    "Segment {:2} | instret {:9} | chip {} ({}) height ({:8}) > max ({:8})",
165                    self.segments.len(),
166                    instret,
167                    i,
168                    air_name,
169                    height,
170                    self.segmentation_limits.max_trace_height
171                );
172                return true;
173            }
174        }
175
176        let total_cells = self.calculate_total_cells(trace_heights);
177        if total_cells > self.segmentation_limits.max_cells {
178            tracing::info!(
179                "Segment {:2} | instret {:9} | total cells ({:10}) > max ({:10})",
180                self.segments.len(),
181                instret,
182                total_cells,
183                self.segmentation_limits.max_cells
184            );
185            return true;
186        }
187
188        let total_interactions = self.calculate_total_interactions(trace_heights);
189        if total_interactions > self.segmentation_limits.max_interactions {
190            tracing::info!(
191                "Segment {:2} | instret {:9} | total interactions ({:11}) > max ({:11})",
192                self.segments.len(),
193                instret,
194                total_interactions,
195                self.segmentation_limits.max_interactions
196            );
197            return true;
198        }
199
200        false
201    }
202
203    #[inline(always)]
204    pub fn check_and_segment(
205        &mut self,
206        instret: u64,
207        trace_heights: &[u32],
208        is_trace_height_constant: &[bool],
209    ) -> bool {
210        let ret = self.should_segment(instret, trace_heights, is_trace_height_constant);
211        if ret {
212            self.segment(instret, trace_heights);
213        }
214        self.instret_last_segment_check = instret;
215
216        ret
217    }
218
219    /// Try segment if there is at least one cycle
220    #[inline(always)]
221    pub fn segment(&mut self, instret: u64, trace_heights: &[u32]) {
222        let instret_start = self
223            .segments
224            .last()
225            .map_or(0, |s| s.instret_start + s.num_insns);
226        let num_insns = instret - instret_start;
227
228        debug_assert!(num_insns > 0, "Segment should contain at least one cycle");
229
230        self.segments.push(Segment {
231            instret_start,
232            num_insns,
233            trace_heights: trace_heights.to_vec(),
234        });
235    }
236}