openvm_circuit/arch/execution_mode/metered/
segment_ctx.rs1use getset::WithSetters;
2use openvm_stark_backend::p3_field::PrimeField32;
3use p3_baby_bear::BabyBear;
4use serde::{Deserialize, Serialize};
5
6pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;
7
8pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = (1 << 23) - 10000;
9pub const DEFAULT_MAX_CELLS: usize = 2_000_000_000; const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
11
12#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)]
13pub struct Segment {
14 pub instret_start: u64,
15 pub num_insns: u64,
16 pub trace_heights: Vec<u32>,
17}
18
19#[derive(Clone, Copy, Debug, WithSetters)]
20pub struct SegmentationLimits {
21 #[getset(set_with = "pub")]
22 pub max_trace_height: u32,
23 #[getset(set_with = "pub")]
24 pub max_cells: usize,
25 #[getset(set_with = "pub")]
26 pub max_interactions: usize,
27}
28
29impl Default for SegmentationLimits {
30 fn default() -> Self {
31 Self {
32 max_trace_height: DEFAULT_MAX_TRACE_HEIGHT,
33 max_cells: DEFAULT_MAX_CELLS,
34 max_interactions: DEFAULT_MAX_INTERACTIONS,
35 }
36 }
37}
38
39#[derive(Clone, Debug, WithSetters)]
40pub struct SegmentationCtx {
41 pub segments: Vec<Segment>,
42 pub(crate) air_names: Vec<String>,
43 widths: Vec<usize>,
44 interactions: Vec<usize>,
45 pub(crate) segmentation_limits: SegmentationLimits,
46 pub instret_last_segment_check: u64,
47 #[getset(set_with = "pub")]
48 pub segment_check_insns: u64,
49}
50
51impl SegmentationCtx {
52 pub fn new(
53 air_names: Vec<String>,
54 widths: Vec<usize>,
55 interactions: Vec<usize>,
56 segmentation_limits: SegmentationLimits,
57 ) -> Self {
58 assert_eq!(air_names.len(), widths.len());
59 assert_eq!(air_names.len(), interactions.len());
60
61 Self {
62 segments: Vec::new(),
63 air_names,
64 widths,
65 interactions,
66 segmentation_limits,
67 segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
68 instret_last_segment_check: 0,
69 }
70 }
71
72 pub fn new_with_default_segmentation_limits(
73 air_names: Vec<String>,
74 widths: Vec<usize>,
75 interactions: Vec<usize>,
76 ) -> Self {
77 assert_eq!(air_names.len(), widths.len());
78 assert_eq!(air_names.len(), interactions.len());
79
80 Self {
81 segments: Vec::new(),
82 air_names,
83 widths,
84 interactions,
85 segmentation_limits: SegmentationLimits::default(),
86 segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
87 instret_last_segment_check: 0,
88 }
89 }
90
91 pub fn set_max_trace_height(&mut self, max_trace_height: u32) {
92 self.segmentation_limits.max_trace_height = max_trace_height;
93 }
94
95 pub fn set_max_cells(&mut self, max_cells: usize) {
96 self.segmentation_limits.max_cells = max_cells;
97 }
98
99 pub fn set_max_interactions(&mut self, max_interactions: usize) {
100 self.segmentation_limits.max_interactions = max_interactions;
101 }
102
103 #[inline(always)]
105 fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
106 debug_assert_eq!(trace_heights.len(), self.widths.len());
107
108 let widths_slice = unsafe { self.widths.get_unchecked(..trace_heights.len()) };
110
111 trace_heights
112 .iter()
113 .zip(widths_slice)
114 .map(|(&height, &width)| height as usize * width)
115 .sum()
116 }
117
118 #[inline(always)]
120 fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize {
121 debug_assert_eq!(trace_heights.len(), self.interactions.len());
122
123 let interactions_slice = unsafe { self.interactions.get_unchecked(..trace_heights.len()) };
125
126 trace_heights
127 .iter()
128 .zip(interactions_slice)
129 .map(|(&height, &interactions)| (height + 1) as usize * interactions)
131 .sum()
132 }
133
134 #[inline(always)]
135 fn should_segment(
136 &self,
137 instret: u64,
138 trace_heights: &[u32],
139 is_trace_height_constant: &[bool],
140 ) -> bool {
141 debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len());
142 debug_assert_eq!(trace_heights.len(), self.air_names.len());
143
144 let instret_start = self
145 .segments
146 .last()
147 .map_or(0, |s| s.instret_start + s.num_insns);
148 let num_insns = instret - instret_start;
149
150 if num_insns == 0 {
152 return false;
153 }
154
155 for (i, (height, is_constant)) in trace_heights
156 .iter()
157 .zip(is_trace_height_constant.iter())
158 .enumerate()
159 {
160 if !is_constant && *height > self.segmentation_limits.max_trace_height {
162 let air_name = &self.air_names[i];
163 tracing::info!(
164 "Segment {:2} | instret {:9} | chip {} ({}) height ({:8}) > max ({:8})",
165 self.segments.len(),
166 instret,
167 i,
168 air_name,
169 height,
170 self.segmentation_limits.max_trace_height
171 );
172 return true;
173 }
174 }
175
176 let total_cells = self.calculate_total_cells(trace_heights);
177 if total_cells > self.segmentation_limits.max_cells {
178 tracing::info!(
179 "Segment {:2} | instret {:9} | total cells ({:10}) > max ({:10})",
180 self.segments.len(),
181 instret,
182 total_cells,
183 self.segmentation_limits.max_cells
184 );
185 return true;
186 }
187
188 let total_interactions = self.calculate_total_interactions(trace_heights);
189 if total_interactions > self.segmentation_limits.max_interactions {
190 tracing::info!(
191 "Segment {:2} | instret {:9} | total interactions ({:11}) > max ({:11})",
192 self.segments.len(),
193 instret,
194 total_interactions,
195 self.segmentation_limits.max_interactions
196 );
197 return true;
198 }
199
200 false
201 }
202
203 #[inline(always)]
204 pub fn check_and_segment(
205 &mut self,
206 instret: u64,
207 trace_heights: &[u32],
208 is_trace_height_constant: &[bool],
209 ) -> bool {
210 let ret = self.should_segment(instret, trace_heights, is_trace_height_constant);
211 if ret {
212 self.segment(instret, trace_heights);
213 }
214 self.instret_last_segment_check = instret;
215
216 ret
217 }
218
219 #[inline(always)]
221 pub fn segment(&mut self, instret: u64, trace_heights: &[u32]) {
222 let instret_start = self
223 .segments
224 .last()
225 .map_or(0, |s| s.instret_start + s.num_insns);
226 let num_insns = instret - instret_start;
227
228 debug_assert!(num_insns > 0, "Segment should contain at least one cycle");
229
230 self.segments.push(Segment {
231 instret_start,
232 num_insns,
233 trace_heights: trace_heights.to_vec(),
234 });
235 }
236}