openvm_circuit/arch/execution_mode/metered/
segment_ctx.rs1use getset::WithSetters;
2use openvm_stark_backend::p3_field::PrimeField32;
3use p3_baby_bear::BabyBear;
4use serde::{Deserialize, Serialize};
5
6pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;
7
8pub const DEFAULT_MAX_TRACE_HEIGHT_BITS: u8 = 22;
9pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = 1 << DEFAULT_MAX_TRACE_HEIGHT_BITS;
10pub const DEFAULT_MAX_CELLS: usize = 1_200_000_000; const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
12
13#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)]
14pub struct Segment {
15 pub instret_start: u64,
16 pub num_insns: u64,
17 pub trace_heights: Vec<u32>,
18}
19
20#[derive(Clone, Copy, Debug, WithSetters)]
21pub struct SegmentationLimits {
22 #[getset(set_with = "pub")]
23 pub max_trace_height: u32,
24 #[getset(set_with = "pub")]
25 pub max_cells: usize,
26 #[getset(set_with = "pub")]
27 pub max_interactions: usize,
28}
29
30impl Default for SegmentationLimits {
31 fn default() -> Self {
32 Self {
33 max_trace_height: DEFAULT_MAX_TRACE_HEIGHT,
34 max_cells: DEFAULT_MAX_CELLS,
35 max_interactions: DEFAULT_MAX_INTERACTIONS,
36 }
37 }
38}
39
40#[derive(Clone, Debug, WithSetters)]
41pub struct SegmentationCtx {
42 pub segments: Vec<Segment>,
43 pub(crate) air_names: Vec<String>,
44 pub(crate) widths: Vec<usize>,
45 interactions: Vec<usize>,
46 pub(crate) segmentation_limits: SegmentationLimits,
47 pub instret_last_segment_check: u64,
48 #[getset(set_with = "pub")]
49 pub segment_check_insns: u64,
50 pub(crate) checkpoint_trace_heights: Vec<u32>,
52 checkpoint_instret: u64,
54}
55
56impl SegmentationCtx {
57 pub fn new(
58 air_names: Vec<String>,
59 widths: Vec<usize>,
60 interactions: Vec<usize>,
61 segmentation_limits: SegmentationLimits,
62 ) -> Self {
63 assert_eq!(air_names.len(), widths.len());
64 assert_eq!(air_names.len(), interactions.len());
65
66 let num_airs = air_names.len();
67 Self {
68 segments: Vec::new(),
69 air_names,
70 widths,
71 interactions,
72 segmentation_limits,
73 segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
74 instret_last_segment_check: 0,
75 checkpoint_trace_heights: vec![0; num_airs],
76 checkpoint_instret: 0,
77 }
78 }
79
80 pub fn new_with_default_segmentation_limits(
81 air_names: Vec<String>,
82 widths: Vec<usize>,
83 interactions: Vec<usize>,
84 ) -> Self {
85 assert_eq!(air_names.len(), widths.len());
86 assert_eq!(air_names.len(), interactions.len());
87
88 let num_airs = air_names.len();
89 Self {
90 segments: Vec::new(),
91 air_names,
92 widths,
93 interactions,
94 segmentation_limits: SegmentationLimits::default(),
95 segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
96 instret_last_segment_check: 0,
97 checkpoint_trace_heights: vec![0; num_airs],
98 checkpoint_instret: 0,
99 }
100 }
101
102 pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
103 debug_assert!(
104 max_trace_height.is_power_of_two(),
105 "max_trace_height should be a power of two"
106 );
107 self.segmentation_limits.max_trace_height = max_trace_height;
108 }
109
110 pub fn set_max_cells(&mut self, max_cells: usize) {
111 self.segmentation_limits.max_cells = max_cells;
112 }
113
114 pub fn set_max_interactions(&mut self, max_interactions: usize) {
115 self.segmentation_limits.max_interactions = max_interactions;
116 }
117
118 #[inline(always)]
120 fn calculate_max_trace_height_with_name(&self, trace_heights: &[u32]) -> (u32, &str) {
121 trace_heights
122 .iter()
123 .enumerate()
124 .map(|(i, &height)| (height.next_power_of_two(), i))
125 .max_by_key(|(height, _)| *height)
126 .map(|(height, idx)| (height, self.air_names[idx].as_str()))
127 .unwrap_or((0, "unknown"))
128 }
129
130 #[inline(always)]
132 fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
133 debug_assert_eq!(trace_heights.len(), self.widths.len());
134
135 trace_heights
136 .iter()
137 .zip(self.widths.iter())
138 .map(|(&height, &width)| height.next_power_of_two() as usize * width)
139 .sum()
140 }
141
142 #[inline(always)]
146 fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize {
147 debug_assert_eq!(trace_heights.len(), self.interactions.len());
148
149 trace_heights
150 .iter()
151 .zip(self.interactions.iter())
152 .map(|(&height, &interactions)| (height + 1) as usize * interactions)
153 .sum()
154 }
155
156 #[inline(always)]
157 fn should_segment(
158 &self,
159 instret: u64,
160 trace_heights: &[u32],
161 is_trace_height_constant: &[bool],
162 ) -> bool {
163 debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len());
164 debug_assert_eq!(trace_heights.len(), self.air_names.len());
165 debug_assert_eq!(trace_heights.len(), self.widths.len());
166 debug_assert_eq!(trace_heights.len(), self.interactions.len());
167
168 let instret_start = self
169 .segments
170 .last()
171 .map_or(0, |s| s.instret_start + s.num_insns);
172 let num_insns = instret - instret_start;
173
174 if num_insns == 0 {
176 return false;
177 }
178
179 let mut total_cells = 0;
180 for (i, ((padded_height, width), is_constant)) in trace_heights
181 .iter()
182 .map(|&height| height.next_power_of_two())
183 .zip(self.widths.iter())
184 .zip(is_trace_height_constant.iter())
185 .enumerate()
186 {
187 if !is_constant && padded_height > self.segmentation_limits.max_trace_height {
190 let air_name = unsafe { self.air_names.get_unchecked(i) };
191 tracing::info!(
192 "instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) ",
193 instret,
194 i,
195 air_name,
196 padded_height,
197 self.segmentation_limits.max_trace_height
198 );
199 return true;
200 }
201 total_cells += padded_height as usize * width;
202 }
203
204 if total_cells > self.segmentation_limits.max_cells {
205 tracing::info!(
206 "instret {:10} | total cells ({:10}) > max ({:10})",
207 instret,
208 total_cells,
209 self.segmentation_limits.max_cells
210 );
211 return true;
212 }
213
214 let total_interactions = self.calculate_total_interactions(trace_heights);
215 if total_interactions > self.segmentation_limits.max_interactions {
216 tracing::info!(
217 "instret {:10} | total interactions ({:10}) > max ({:10})",
218 instret,
219 total_interactions,
220 self.segmentation_limits.max_interactions
221 );
222 return true;
223 }
224
225 false
226 }
227
228 #[inline(always)]
229 pub fn check_and_segment(
230 &mut self,
231 instret: u64,
232 trace_heights: &mut [u32],
233 is_trace_height_constant: &[bool],
234 ) -> bool {
235 let should_seg = self.should_segment(instret, trace_heights, is_trace_height_constant);
236
237 if should_seg {
238 self.create_segment_from_checkpoint(instret, trace_heights, is_trace_height_constant);
239 } else {
240 self.update_checkpoint(instret, trace_heights);
241 }
242
243 self.instret_last_segment_check = instret;
244 should_seg
245 }
246
247 #[inline(always)]
248 fn create_segment_from_checkpoint(
249 &mut self,
250 instret: u64,
251 trace_heights: &mut [u32],
252 is_trace_height_constant: &[bool],
253 ) {
254 let instret_start = self
255 .segments
256 .last()
257 .map_or(0, |s| s.instret_start + s.num_insns);
258
259 let (segment_instret, segment_heights) = if self.checkpoint_instret > instret_start {
260 (
261 self.checkpoint_instret,
262 self.checkpoint_trace_heights.clone(),
263 )
264 } else {
265 tracing::warn!(
266 "No valid checkpoint, creating segment using instret={instret}, trace_heights={:?}",
267 trace_heights
268 );
269 (instret, trace_heights.to_vec())
271 };
272
273 self.reset_trace_heights(trace_heights, &segment_heights, is_trace_height_constant);
275 self.checkpoint_instret = 0;
276
277 let num_insns = segment_instret - instret_start;
278 self.create_segment::<false>(instret_start, num_insns, segment_heights);
279 }
280
281 #[inline(always)]
283 fn reset_trace_heights(
284 &self,
285 trace_heights: &mut [u32],
286 segment_heights: &[u32],
287 is_trace_height_constant: &[bool],
288 ) {
289 for ((trace_height, &segment_height), &is_trace_height_constant) in trace_heights
290 .iter_mut()
291 .zip(segment_heights.iter())
292 .zip(is_trace_height_constant.iter())
293 {
294 if !is_trace_height_constant {
295 *trace_height = trace_height.checked_sub(segment_height).unwrap();
296 }
297 }
298 }
299
300 #[inline(always)]
302 fn update_checkpoint(&mut self, instret: u64, trace_heights: &[u32]) {
303 self.checkpoint_trace_heights.copy_from_slice(trace_heights);
304 self.checkpoint_instret = instret;
305 }
306
307 #[inline(always)]
309 pub fn create_final_segment(&mut self, instret: u64, trace_heights: &[u32]) {
310 let instret_start = self
311 .segments
312 .last()
313 .map_or(0, |s| s.instret_start + s.num_insns);
314
315 let num_insns = instret - instret_start;
316 self.create_segment::<true>(instret_start, num_insns, trace_heights.to_vec());
317 }
318
319 #[inline(always)]
321 fn create_segment<const IS_FINAL: bool>(
322 &mut self,
323 instret_start: u64,
324 num_insns: u64,
325 trace_heights: Vec<u32>,
326 ) {
327 debug_assert!(
328 num_insns > 0,
329 "Segment should contain at least one instruction"
330 );
331
332 self.log_segment_info::<IS_FINAL>(instret_start, num_insns, &trace_heights);
333 self.segments.push(Segment {
334 instret_start,
335 num_insns,
336 trace_heights,
337 });
338 }
339
340 #[inline(always)]
342 fn log_segment_info<const IS_FINAL: bool>(
343 &self,
344 instret_start: u64,
345 num_insns: u64,
346 trace_heights: &[u32],
347 ) {
348 let (max_trace_height, air_name) = self.calculate_max_trace_height_with_name(trace_heights);
349 let total_cells = self.calculate_total_cells(trace_heights);
350 let total_interactions = self.calculate_total_interactions(trace_heights);
351
352 let final_marker = if IS_FINAL { " [TERMINATED]" } else { "" };
353
354 tracing::info!(
355 "Segment {:3} | instret {:10} | {:8} instructions | {:10} cells | {:10} interactions | {:8} max height ({}){}",
356 self.segments.len(),
357 instret_start,
358 num_insns,
359 total_cells,
360 total_interactions,
361 max_trace_height,
362 air_name,
363 final_marker
364 );
365 }
366}