1use getset::WithSetters;
2use openvm_stark_backend::p3_field::PrimeField32;
3use p3_baby_bear::BabyBear;
4use serde::{Deserialize, Serialize};
5
/// Default segmentation-check interval, in instructions: the initial value of both
/// `SegmentationCtx::segment_check_insns` and `SegmentationCtx::instrets_until_check`.
pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1_000;

/// `log2` of [`DEFAULT_MAX_TRACE_HEIGHT`].
pub const DEFAULT_MAX_TRACE_HEIGHT_BITS: u8 = 22;
/// Default maximum padded height of a single non-constant AIR trace (`2^22` rows).
pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = 1u32 << DEFAULT_MAX_TRACE_HEIGHT_BITS;
/// Default upper bound on the total number of padded trace cells (sum over all AIRs of
/// padded height × width) in one segment.
pub const DEFAULT_MAX_CELLS: usize = 1_200_000_000;
/// Default upper bound on the total number of interactions in one segment; equal to the
/// BabyBear field order.
// NOTE(review): presumably chosen so interaction counts cannot wrap modulo the field — confirm.
const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
12
/// A contiguous range of executed instructions recorded by `SegmentationCtx`, together with
/// the per-AIR trace heights attributed to that range.
#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)]
pub struct Segment {
    /// Instruction count (`instret`) at which this segment begins.
    pub instret_start: u64,
    /// Number of instructions in this segment; the next segment begins at
    /// `instret_start + num_insns`.
    pub num_insns: u64,
    /// Unpadded trace height of each AIR for this segment, indexed like the per-AIR vectors of
    /// `SegmentationCtx` (padding to a power of two is applied only when limits are evaluated).
    pub trace_heights: Vec<u32>,
}
19
20#[derive(Clone, Copy, Debug, WithSetters)]
21pub struct SegmentationLimits {
22 #[getset(set_with = "pub")]
23 pub max_trace_height: u32,
24 #[getset(set_with = "pub")]
25 pub max_cells: usize,
26 #[getset(set_with = "pub")]
27 pub max_interactions: usize,
28}
29
30impl Default for SegmentationLimits {
31 fn default() -> Self {
32 Self {
33 max_trace_height: DEFAULT_MAX_TRACE_HEIGHT,
34 max_cells: DEFAULT_MAX_CELLS,
35 max_interactions: DEFAULT_MAX_INTERACTIONS,
36 }
37 }
38}
39
40impl SegmentationLimits {
41 pub fn new(max_trace_height: u32, max_cells: usize, max_interactions: usize) -> Self {
42 debug_assert!(
43 max_trace_height.is_power_of_two(),
44 "max_trace_height should be a power of two"
45 );
46 Self {
47 max_trace_height,
48 max_cells,
49 max_interactions,
50 }
51 }
52
53 pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
54 debug_assert!(
55 max_trace_height.is_power_of_two(),
56 "max_trace_height should be a power of two"
57 );
58 self.max_trace_height = max_trace_height;
59 }
60}
61
/// Tracks execution progress and decides where to split it into [`Segment`]s so that each
/// segment stays within the configured [`SegmentationLimits`].
///
/// `air_names`, `widths`, `interactions` and `checkpoint_trace_heights` are parallel per-AIR
/// vectors: index `i` of each refers to the same AIR.
// NOTE(review): no field carries a `#[getset(...)]` attribute, so the `WithSetters` derive appears
// to generate nothing for this struct — confirm and consider dropping it.
#[derive(Clone, Debug, WithSetters)]
pub struct SegmentationCtx {
    /// Segments created so far, in creation order.
    pub segments: Vec<Segment>,
    /// Per-AIR names, used in log messages.
    pub(crate) air_names: Vec<String>,
    /// Per-AIR trace widths; multiplied by padded heights to count cells.
    pub(crate) widths: Vec<usize>,
    /// Per-AIR interaction counts per row; multiplied by `height + 1` to count interactions.
    interactions: Vec<usize>,
    /// Thresholds evaluated when deciding whether to close a segment.
    pub(crate) segmentation_limits: SegmentationLimits,
    /// Instruction counter; `create_final_segment` brings it up to date before closing the
    /// final segment.
    pub instret: u64,
    /// Instructions remaining before the next segmentation check.
    // NOTE(review): presumably decremented by the executor; `create_final_segment` treats
    // `segment_check_insns - instrets_until_check` as instructions not yet added to `instret`.
    pub instrets_until_check: u64,
    /// Segmentation-check interval in instructions (initialized to `DEFAULT_SEGMENT_CHECK_INSNS`).
    pub(super) segment_check_insns: u64,
    /// Trace heights captured by the last `update_checkpoint` call; used as the segment's
    /// heights when a segment is cut at the checkpoint.
    pub(crate) checkpoint_trace_heights: Vec<u32>,
    /// Instruction count captured by the last `update_checkpoint` call.
    checkpoint_instret: u64,
}
77
78impl SegmentationCtx {
79 pub fn new(
80 air_names: Vec<String>,
81 widths: Vec<usize>,
82 interactions: Vec<usize>,
83 segmentation_limits: SegmentationLimits,
84 ) -> Self {
85 assert_eq!(air_names.len(), widths.len());
86 assert_eq!(air_names.len(), interactions.len());
87
88 let num_airs = air_names.len();
89 Self {
90 segments: Vec::new(),
91 air_names,
92 widths,
93 interactions,
94 segmentation_limits,
95 instret: 0,
96 instrets_until_check: DEFAULT_SEGMENT_CHECK_INSNS,
97 segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
98 checkpoint_trace_heights: vec![0; num_airs],
99 checkpoint_instret: 0,
100 }
101 }
102
103 pub fn new_with_default_segmentation_limits(
104 air_names: Vec<String>,
105 widths: Vec<usize>,
106 interactions: Vec<usize>,
107 ) -> Self {
108 assert_eq!(air_names.len(), widths.len());
109 assert_eq!(air_names.len(), interactions.len());
110
111 let num_airs = air_names.len();
112 Self {
113 segments: Vec::new(),
114 air_names,
115 widths,
116 interactions,
117 segmentation_limits: SegmentationLimits::default(),
118 instret: 0,
119 instrets_until_check: DEFAULT_SEGMENT_CHECK_INSNS,
120 segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
121 checkpoint_trace_heights: vec![0; num_airs],
122 checkpoint_instret: 0,
123 }
124 }
125
126 pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
127 self.segmentation_limits
128 .set_max_trace_height(max_trace_height);
129 }
130
131 pub fn set_max_cells(&mut self, max_cells: usize) {
132 self.segmentation_limits.max_cells = max_cells;
133 }
134
135 pub fn set_max_interactions(&mut self, max_interactions: usize) {
136 self.segmentation_limits.max_interactions = max_interactions;
137 }
138
139 #[inline(always)]
141 fn calculate_max_trace_height_with_name(&self, trace_heights: &[u32]) -> (u32, &str) {
142 trace_heights
143 .iter()
144 .enumerate()
145 .map(|(i, &height)| (height.next_power_of_two(), i))
146 .max_by_key(|(height, _)| *height)
147 .map(|(height, idx)| (height, self.air_names[idx].as_str()))
148 .unwrap_or((0, "unknown"))
149 }
150
151 #[inline(always)]
153 fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
154 debug_assert_eq!(trace_heights.len(), self.widths.len());
155
156 trace_heights
157 .iter()
158 .zip(self.widths.iter())
159 .map(|(&height, &width)| height.next_power_of_two() as usize * width)
160 .sum()
161 }
162
163 #[inline(always)]
167 fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize {
168 debug_assert_eq!(trace_heights.len(), self.interactions.len());
169
170 trace_heights
171 .iter()
172 .zip(self.interactions.iter())
173 .map(|(&height, &interactions)| (height + 1) as usize * interactions)
174 .sum()
175 }
176
177 #[inline(always)]
178 pub(crate) fn should_segment(
179 &self,
180 instret: u64,
181 trace_heights: &[u32],
182 is_trace_height_constant: &[bool],
183 ) -> bool {
184 debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len());
185 debug_assert_eq!(trace_heights.len(), self.air_names.len());
186 debug_assert_eq!(trace_heights.len(), self.widths.len());
187 debug_assert_eq!(trace_heights.len(), self.interactions.len());
188
189 let instret_start = self
190 .segments
191 .last()
192 .map_or(0, |s| s.instret_start + s.num_insns);
193 let num_insns = instret - instret_start;
194
195 if num_insns == 0 {
197 return false;
198 }
199
200 let mut total_cells = 0;
201 for (i, ((padded_height, width), is_constant)) in trace_heights
202 .iter()
203 .map(|&height| height.next_power_of_two())
204 .zip(self.widths.iter())
205 .zip(is_trace_height_constant.iter())
206 .enumerate()
207 {
208 if !is_constant && padded_height > self.segmentation_limits.max_trace_height {
211 let air_name = unsafe { self.air_names.get_unchecked(i) };
212 tracing::info!(
213 "instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) ",
214 instret,
215 padded_height,
216 self.segmentation_limits.max_trace_height,
217 i,
218 air_name,
219 );
220 return true;
221 }
222 total_cells += padded_height as usize * width;
223 }
224
225 if total_cells > self.segmentation_limits.max_cells {
226 tracing::info!(
227 "instret {:10} | total cells ({:10}) > max ({:10})",
228 instret,
229 total_cells,
230 self.segmentation_limits.max_cells
231 );
232 return true;
233 }
234
235 let total_interactions = self.calculate_total_interactions(trace_heights);
236 if total_interactions > self.segmentation_limits.max_interactions {
237 tracing::info!(
238 "instret {:10} | total interactions ({:10}) > max ({:10})",
239 instret,
240 total_interactions,
241 self.segmentation_limits.max_interactions
242 );
243 return true;
244 }
245
246 false
247 }
248
249 #[inline(always)]
250 pub fn check_and_segment(
251 &mut self,
252 instret: u64,
253 trace_heights: &mut [u32],
254 is_trace_height_constant: &[bool],
255 ) -> bool {
256 let should_seg = self.should_segment(instret, trace_heights, is_trace_height_constant);
257
258 if should_seg {
259 self.create_segment_from_checkpoint(instret, trace_heights);
260 }
261 should_seg
262 }
263
264 #[inline(always)]
265 fn create_segment_from_checkpoint(&mut self, instret: u64, trace_heights: &mut [u32]) {
266 let instret_start = self
267 .segments
268 .last()
269 .map_or(0, |s| s.instret_start + s.num_insns);
270
271 let (segment_instret, segment_heights) = if self.checkpoint_instret > instret_start {
272 (
273 self.checkpoint_instret,
274 self.checkpoint_trace_heights.clone(),
275 )
276 } else {
277 let trace_heights_str = trace_heights
278 .iter()
279 .zip(self.air_names.iter())
280 .filter(|(&height, _)| height > 0)
281 .map(|(&height, name)| format!(" {name} = {height}"))
282 .collect::<Vec<_>>()
283 .join("\n");
284 tracing::warn!(
285 "No valid checkpoint, creating segment using instret={instret}\ntrace_heights=[\n{trace_heights_str}\n]"
286 );
287 (instret, trace_heights.to_vec())
289 };
290
291 let num_insns = segment_instret - instret_start;
292 self.create_segment::<false>(instret_start, num_insns, segment_heights);
293 }
294
295 #[inline(always)]
297 pub(crate) fn initialize_segment(
298 &mut self,
299 trace_heights: &mut [u32],
300 is_trace_height_constant: &[bool],
301 ) {
302 let last_segment = self.segments.last().unwrap();
304 self.reset_trace_heights(
305 trace_heights,
306 &last_segment.trace_heights,
307 is_trace_height_constant,
308 );
309 }
310
311 #[inline(always)]
313 fn reset_trace_heights(
314 &self,
315 trace_heights: &mut [u32],
316 segment_heights: &[u32],
317 is_trace_height_constant: &[bool],
318 ) {
319 for ((trace_height, &segment_height), &is_trace_height_constant) in trace_heights
320 .iter_mut()
321 .zip(segment_heights.iter())
322 .zip(is_trace_height_constant.iter())
323 {
324 if !is_trace_height_constant {
325 *trace_height = trace_height.checked_sub(segment_height).unwrap();
326 }
327 }
328 }
329
330 #[inline(always)]
332 pub(crate) fn update_checkpoint(&mut self, instret: u64, trace_heights: &[u32]) {
333 self.checkpoint_trace_heights.copy_from_slice(trace_heights);
334 self.checkpoint_instret = instret;
335 }
336
337 #[inline(always)]
339 pub fn create_final_segment(&mut self, trace_heights: &[u32]) {
340 self.instret += self.segment_check_insns - self.instrets_until_check;
341 self.instrets_until_check = self.segment_check_insns;
342 let instret_start = self
343 .segments
344 .last()
345 .map_or(0, |s| s.instret_start + s.num_insns);
346
347 let num_insns = self.instret - instret_start;
348 self.create_segment::<true>(instret_start, num_insns, trace_heights.to_vec());
349 }
350
351 #[inline(always)]
353 fn create_segment<const IS_FINAL: bool>(
354 &mut self,
355 instret_start: u64,
356 num_insns: u64,
357 trace_heights: Vec<u32>,
358 ) {
359 debug_assert!(
360 num_insns > 0,
361 "Segment should contain at least one instruction"
362 );
363
364 self.log_segment_info::<IS_FINAL>(instret_start, num_insns, &trace_heights);
365 self.segments.push(Segment {
366 instret_start,
367 num_insns,
368 trace_heights,
369 });
370 }
371
372 #[inline(always)]
374 fn log_segment_info<const IS_FINAL: bool>(
375 &self,
376 instret_start: u64,
377 num_insns: u64,
378 trace_heights: &[u32],
379 ) {
380 let (max_trace_height, air_name) = self.calculate_max_trace_height_with_name(trace_heights);
381 let total_cells = self.calculate_total_cells(trace_heights);
382 let total_interactions = self.calculate_total_interactions(trace_heights);
383
384 let final_marker = if IS_FINAL { " [TERMINATED]" } else { "" };
385
386 tracing::info!(
387 "Segment {:3} | instret {:10} | {:8} instructions | {:10} cells | {:10} interactions | {:8} max height ({}){}",
388 self.segments.len(),
389 instret_start,
390 num_insns,
391 total_cells,
392 total_interactions,
393 max_trace_height,
394 air_name,
395 final_marker
396 );
397 }
398}