openvm_circuit/arch/execution_mode/metered/
ctx.rs1use std::num::NonZero;
2
3use openvm_instructions::riscv::{RV32_IMM_AS, RV32_REGISTER_AS};
4
5use super::{
6 memory_ctx::MemoryCtx,
7 segment_ctx::{Segment, SegmentationCtx},
8};
9use crate::{
10 arch::{
11 execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait},
12 SystemConfig, VmExecState,
13 },
14 system::memory::online::GuestMemory,
15};
16
/// Default value of the `PAGE_BITS` const generic parameter of [`MeteredCtx`].
///
/// NOTE(review): presumably the base-2 logarithm of the page size that `MemoryCtx` uses for its
/// bookkeeping — confirm against `memory_ctx.rs`.
pub const DEFAULT_PAGE_BITS: usize = 6;
18
/// Execution context for metered execution: it accumulates per-AIR trace heights while the
/// program runs and records the segments produced by the segmentation logic.
///
/// `PAGE_BITS` is forwarded unchanged to [`MemoryCtx`].
#[derive(Clone, Debug)]
pub struct MeteredCtx<const PAGE_BITS: usize = DEFAULT_PAGE_BITS> {
    /// Current trace height of every AIR, indexed by AIR/chip index.
    pub trace_heights: Vec<u32>,
    /// `is_trace_height_constant[i]` is `true` when AIR `i` was created with a fixed height;
    /// such entries are left untouched when a new segment starts.
    pub is_trace_height_constant: Vec<bool>,
    /// Memory-side bookkeeping that contributes to the adapter, boundary and Merkle AIR heights.
    pub memory_ctx: MemoryCtx<PAGE_BITS>,
    /// Segmentation state: the segments recorded so far and the settings that control when a
    /// new segment is started.
    pub segmentation_ctx: SegmentationCtx,
}
26
27impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
28 pub fn new(
30 constant_trace_heights: Vec<Option<usize>>,
31 air_names: Vec<String>,
32 widths: Vec<usize>,
33 interactions: Vec<usize>,
34 config: &SystemConfig,
35 ) -> Self {
36 let (trace_heights, is_trace_height_constant): (Vec<u32>, Vec<bool>) =
37 constant_trace_heights
38 .iter()
39 .map(|&constant_height| {
40 if let Some(height) = constant_height {
41 (height as u32, true)
42 } else {
43 (0, false)
44 }
45 })
46 .unzip();
47
48 let memory_ctx = MemoryCtx::new(config);
49
50 debug_assert!(
52 air_names[memory_ctx.boundary_idx].contains("Boundary"),
53 "air_name={}",
54 air_names[memory_ctx.boundary_idx]
55 );
56 if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index {
57 debug_assert!(
58 air_names[merkle_tree_index].contains("Merkle"),
59 "air_name={}",
60 air_names[merkle_tree_index]
61 );
62 }
63 debug_assert!(
64 air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
65 "air_name={}",
66 air_names[memory_ctx.adapter_offset]
67 );
68
69 let segmentation_ctx =
70 SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);
71
72 let mut ctx = Self {
73 trace_heights,
74 is_trace_height_constant,
75 memory_ctx,
76 segmentation_ctx,
77 };
78 if !config.continuation_enabled {
79 ctx.segmentation_ctx.segment_check_insns = u64::MAX;
81 }
82
83 ctx.memory_ctx.add_register_merkle_heights();
85
86 ctx
87 }
88
89 pub fn with_max_trace_height(mut self, max_trace_height: u32) -> Self {
90 self.segmentation_ctx.set_max_trace_height(max_trace_height);
91 let max_check_freq = (max_trace_height / 2) as u64;
92 if max_check_freq < self.segmentation_ctx.segment_check_insns {
93 self.segmentation_ctx.segment_check_insns = max_check_freq;
94 }
95 self
96 }
97
98 pub fn with_max_cells(mut self, max_cells: usize) -> Self {
99 self.segmentation_ctx.set_max_cells(max_cells);
100 self
101 }
102
103 pub fn with_max_interactions(mut self, max_interactions: usize) -> Self {
104 self.segmentation_ctx.set_max_interactions(max_interactions);
105 self
106 }
107
108 pub fn segments(&self) -> &[Segment] {
109 &self.segmentation_ctx.segments
110 }
111
112 pub fn into_segments(self) -> Vec<Segment> {
113 self.segmentation_ctx.segments
114 }
115
116 fn reset_segment(&mut self) {
117 self.memory_ctx.clear();
118 for (i, &is_constant) in self.is_trace_height_constant.iter().enumerate() {
119 if !is_constant {
120 self.trace_heights[i] = 0;
121 }
122 }
123 self.memory_ctx.add_register_merkle_heights();
125 }
126
127 #[inline(always)]
128 pub fn check_and_segment(&mut self, instret: u64) {
129 let threshold = self
130 .segmentation_ctx
131 .instret_last_segment_check
132 .wrapping_add(self.segmentation_ctx.segment_check_insns);
133 debug_assert!(
134 threshold >= self.segmentation_ctx.instret_last_segment_check,
135 "overflow in segment check threshold calculation"
136 );
137 if instret < threshold {
138 return;
139 }
140
141 self.memory_ctx
142 .lazy_update_boundary_heights(&mut self.trace_heights);
143 let did_segment = self.segmentation_ctx.check_and_segment(
144 instret,
145 &self.trace_heights,
146 &self.is_trace_height_constant,
147 );
148
149 if did_segment {
150 self.reset_segment();
151 }
152 }
153
154 #[allow(dead_code)]
155 pub fn print_heights(&self) {
156 println!("{:>10} {:<30}", "Height", "Air Name");
157 println!("{}", "-".repeat(42));
158 for (i, height) in self.trace_heights.iter().enumerate() {
159 let air_name = self
160 .segmentation_ctx
161 .air_names
162 .get(i)
163 .map(|s| s.as_str())
164 .unwrap_or("Unknown");
165 println!("{:>10} {:<30}", height, air_name);
166 }
167 }
168}
169
170impl<const PAGE_BITS: usize> ExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
171 #[inline(always)]
172 fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) {
173 debug_assert!(
174 address_space != RV32_IMM_AS,
175 "address space must not be immediate"
176 );
177 debug_assert!(size > 0, "size must be greater than 0, got {}", size);
178 debug_assert!(
179 size.is_power_of_two(),
180 "size must be a power of 2, got {}",
181 size
182 );
183
184 let size_bits = unsafe { NonZero::new_unchecked(size).ilog2() };
187 self.memory_ctx
188 .update_adapter_heights(&mut self.trace_heights, address_space, size_bits);
189
190 if address_space != RV32_REGISTER_AS {
192 self.memory_ctx
193 .update_boundary_merkle_heights(address_space, ptr, size);
194 }
195 }
196
197 #[inline(always)]
198 fn should_suspend<F>(vm_state: &mut VmExecState<F, GuestMemory, Self>) -> bool {
199 vm_state.ctx.check_and_segment(vm_state.instret);
202 false
203 }
204
205 #[inline(always)]
206 fn on_terminate<F>(vm_state: &mut VmExecState<F, GuestMemory, Self>) {
207 vm_state
208 .ctx
209 .memory_ctx
210 .lazy_update_boundary_heights(&mut vm_state.ctx.trace_heights);
211 vm_state
212 .ctx
213 .segmentation_ctx
214 .segment(vm_state.instret, &vm_state.ctx.trace_heights);
215 }
216}
217
218impl<const PAGE_BITS: usize> MeteredExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
219 #[inline(always)]
220 fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) {
221 debug_assert!(
222 chip_idx < self.trace_heights.len(),
223 "chip_idx out of bounds"
224 );
225 unsafe {
227 *self.trace_heights.get_unchecked_mut(chip_idx) = self
228 .trace_heights
229 .get_unchecked(chip_idx)
230 .wrapping_add(height_delta);
231 }
232 }
233}