// openvm_circuit/arch/execution_mode/metered/ctx.rs

1use std::num::NonZero;
2
3use getset::{Getters, Setters, WithSetters};
4use itertools::Itertools;
5use openvm_instructions::riscv::{RV32_IMM_AS, RV32_REGISTER_AS};
6
7use super::{
8    memory_ctx::MemoryCtx,
9    segment_ctx::{Segment, SegmentationCtx},
10};
11use crate::{
12    arch::{
13        execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait},
14        SystemConfig, VmExecState,
15    },
16    system::memory::online::GuestMemory,
17};
18
/// Default value of the `PAGE_BITS` const parameter of [`MeteredCtx`], which is forwarded to
/// its `MemoryCtx<PAGE_BITS>`.
pub const DEFAULT_PAGE_BITS: usize = 6;
20
21#[derive(Clone, Debug, Getters, Setters, WithSetters)]
22pub struct MeteredCtx<const PAGE_BITS: usize = DEFAULT_PAGE_BITS> {
23    pub trace_heights: Vec<u32>,
24    pub is_trace_height_constant: Vec<bool>,
25    pub memory_ctx: MemoryCtx<PAGE_BITS>,
26    pub segmentation_ctx: SegmentationCtx,
27    #[getset(get = "pub", set = "pub", set_with = "pub")]
28    suspend_on_segment: bool,
29}
30
31impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
32    // Note[jpw]: prefer to use `build_metered_ctx` in `VmExecutor` or `VirtualMachine`.
33    pub fn new(
34        constant_trace_heights: Vec<Option<usize>>,
35        air_names: Vec<String>,
36        widths: Vec<usize>,
37        interactions: Vec<usize>,
38        config: &SystemConfig,
39    ) -> Self {
40        let (trace_heights, is_trace_height_constant): (Vec<u32>, Vec<bool>) =
41            constant_trace_heights
42                .iter()
43                .map(|&constant_height| {
44                    if let Some(height) = constant_height {
45                        (height as u32, true)
46                    } else {
47                        (0, false)
48                    }
49                })
50                .unzip();
51
52        let segmentation_ctx =
53            SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);
54        let memory_ctx = MemoryCtx::new(config, segmentation_ctx.segment_check_insns);
55
56        // Assert that the indices are correct
57        debug_assert!(
58            segmentation_ctx.air_names[memory_ctx.boundary_idx].contains("Boundary"),
59            "air_name={}",
60            segmentation_ctx.air_names[memory_ctx.boundary_idx]
61        );
62        if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index {
63            debug_assert!(
64                segmentation_ctx.air_names[merkle_tree_index].contains("Merkle"),
65                "air_name={}",
66                segmentation_ctx.air_names[merkle_tree_index]
67            );
68        }
69        debug_assert!(
70            segmentation_ctx.air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
71            "air_name={}",
72            segmentation_ctx.air_names[memory_ctx.adapter_offset]
73        );
74
75        let mut ctx = Self {
76            trace_heights,
77            is_trace_height_constant,
78            memory_ctx,
79            segmentation_ctx,
80            suspend_on_segment: false,
81        };
82
83        // Add merkle height contributions for all registers
84        ctx.memory_ctx.add_register_merkle_heights();
85        ctx.memory_ctx
86            .lazy_update_boundary_heights(&mut ctx.trace_heights);
87
88        ctx
89    }
90
91    /// This changes the frequency of segment checks. BE CAREFUL when you change this during
92    /// execution!
93    pub fn with_max_trace_height(mut self, max_trace_height: u32) -> Self {
94        self.segmentation_ctx.set_max_trace_height(max_trace_height);
95        let max_check_freq = (max_trace_height / 2) as u64;
96        if max_check_freq < self.segmentation_ctx.segment_check_insns {
97            self = self.with_segment_check_insns(max_check_freq);
98        }
99        self
100    }
101
102    pub fn with_max_cells(mut self, max_cells: usize) -> Self {
103        self.segmentation_ctx.set_max_cells(max_cells);
104        self
105    }
106
107    pub fn with_max_interactions(mut self, max_interactions: usize) -> Self {
108        self.segmentation_ctx.set_max_interactions(max_interactions);
109        self
110    }
111
112    pub fn with_segment_check_insns(mut self, segment_check_insns: u64) -> Self {
113        self.segmentation_ctx.segment_check_insns = segment_check_insns;
114        self.segmentation_ctx.instrets_until_check = segment_check_insns;
115
116        // Update memory context with new segment check instructions
117        let page_indices_since_checkpoint_cap =
118            MemoryCtx::<PAGE_BITS>::calculate_checkpoint_capacity(segment_check_insns);
119
120        self.memory_ctx.page_indices_since_checkpoint =
121            vec![0; page_indices_since_checkpoint_cap].into_boxed_slice();
122        self.memory_ctx.page_indices_since_checkpoint_len = 0;
123        self
124    }
125
126    pub fn segments(&self) -> &[Segment] {
127        &self.segmentation_ctx.segments
128    }
129
130    pub fn into_segments(self) -> Vec<Segment> {
131        self.segmentation_ctx.segments
132    }
133
134    #[inline(always)]
135    pub fn check_and_segment(&mut self) -> bool {
136        // We track the segmentation check by instrets_until_check instead of instret in order to
137        // save a register in AOT mode.
138        if self.segmentation_ctx.instrets_until_check > 0 {
139            return false;
140        }
141        self.segmentation_ctx.instrets_until_check = self.segmentation_ctx.segment_check_insns;
142        self.segmentation_ctx.instret += self.segmentation_ctx.segment_check_insns;
143
144        self.memory_ctx
145            .lazy_update_boundary_heights(&mut self.trace_heights);
146        let did_segment = self.segmentation_ctx.check_and_segment(
147            self.segmentation_ctx.instret,
148            &mut self.trace_heights,
149            &self.is_trace_height_constant,
150        );
151
152        if did_segment {
153            // Initialize contexts for new segment
154            self.segmentation_ctx
155                .initialize_segment(&mut self.trace_heights, &self.is_trace_height_constant);
156            self.memory_ctx.initialize_segment(&mut self.trace_heights);
157
158            // Check if the new segment is within limits
159            if self.segmentation_ctx.should_segment(
160                self.segmentation_ctx.instret,
161                &self.trace_heights,
162                &self.is_trace_height_constant,
163            ) {
164                let trace_heights_str = self
165                    .trace_heights
166                    .iter()
167                    .zip(self.segmentation_ctx.air_names.iter())
168                    .filter(|(&height, _)| height > 0)
169                    .map(|(&height, name)| format!("  {name} = {height}"))
170                    .collect::<Vec<_>>()
171                    .join("\n");
172                tracing::warn!(
173                    "Segment initialized with heights that exceed limits\n\
174                     instret={}\n\
175                     trace_heights=[\n{}\n]",
176                    self.segmentation_ctx.instret,
177                    trace_heights_str
178                );
179            }
180        }
181
182        // Update checkpoints
183        self.segmentation_ctx
184            .update_checkpoint(self.segmentation_ctx.instret, &self.trace_heights);
185        self.memory_ctx.update_checkpoint();
186
187        did_segment
188    }
189
190    #[allow(dead_code)]
191    pub fn print_segment(&self) {
192        println!("{}", "-".repeat(80));
193        println!("Segment {}", self.segmentation_ctx.segments.len() - 1);
194        println!("{}", "-".repeat(80));
195        println!("{:>10} {:>10} {:<30}", "Width", "Height", "Air Name");
196        println!("{}", "-".repeat(80));
197        for ((&width, &height), air_name) in self
198            .segmentation_ctx
199            .widths
200            .iter()
201            .zip_eq(self.trace_heights.iter())
202            .zip_eq(self.segmentation_ctx.air_names.iter())
203        {
204            println!("{:>10} {:>10} {:<30}", width, height, air_name.as_str());
205        }
206    }
207}
208
209impl<const PAGE_BITS: usize> ExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
210    #[inline(always)]
211    fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) {
212        debug_assert!(
213            address_space != RV32_IMM_AS,
214            "address space must not be immediate"
215        );
216        debug_assert!(size > 0, "size must be greater than 0, got {size}");
217        debug_assert!(
218            size.is_power_of_two(),
219            "size must be a power of 2, got {size}"
220        );
221
222        // Handle access adapter updates
223        // SAFETY: size passed is always a non-zero power of 2
224        let size_bits = unsafe { NonZero::new_unchecked(size).ilog2() };
225        self.memory_ctx
226            .update_adapter_heights(&mut self.trace_heights, address_space, size_bits);
227
228        // Handle merkle tree updates
229        if address_space != RV32_REGISTER_AS {
230            self.memory_ctx
231                .update_boundary_merkle_heights(address_space, ptr, size);
232        }
233    }
234
235    #[inline(always)]
236    fn should_suspend<F>(exec_state: &mut VmExecState<F, GuestMemory, Self>) -> bool {
237        // ATTENTION: Please make sure to update the corresponding logic in the
238        // `asm_bridge` crate and `aot.rs`` when you change this function.
239        // If `segment_suspend` is set, suspend when a segment is determined (but the VM state might
240        // be after the segment boundary because the segment happens in the previous checkpoint).
241        // Otherwise, execute until termination.
242        if exec_state.ctx.check_and_segment() && exec_state.ctx.suspend_on_segment {
243            true
244        } else {
245            exec_state.ctx.segmentation_ctx.instrets_until_check -= 1;
246            false
247        }
248    }
249
250    #[inline(always)]
251    fn on_terminate<F>(exec_state: &mut VmExecState<F, GuestMemory, Self>) {
252        exec_state
253            .ctx
254            .memory_ctx
255            .lazy_update_boundary_heights(&mut exec_state.ctx.trace_heights);
256        exec_state
257            .ctx
258            .segmentation_ctx
259            .create_final_segment(&exec_state.ctx.trace_heights);
260    }
261}
262
263impl<const PAGE_BITS: usize> MeteredExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
264    #[inline(always)]
265    fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) {
266        debug_assert!(
267            chip_idx < self.trace_heights.len(),
268            "chip_idx out of bounds"
269        );
270        // SAFETY: chip_idx is created in executor_idx_to_air_idx and is always within bounds
271        unsafe {
272            *self.trace_heights.get_unchecked_mut(chip_idx) = self
273                .trace_heights
274                .get_unchecked(chip_idx)
275                .wrapping_add(height_delta);
276        }
277    }
278}