// openvm_circuit/arch/execution_mode/metered/ctx.rs

1use std::num::NonZero;
2
3use getset::{Getters, Setters, WithSetters};
4use itertools::Itertools;
5use openvm_instructions::riscv::{RV32_IMM_AS, RV32_REGISTER_AS};
6
7use super::{
8    memory_ctx::MemoryCtx,
9    segment_ctx::{Segment, SegmentationCtx},
10};
11use crate::{
12    arch::{
13        execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait},
14        SystemConfig, VmExecState,
15    },
16    system::memory::online::GuestMemory,
17};
18
19pub const DEFAULT_PAGE_BITS: usize = 6;
20
21#[derive(Clone, Debug, Getters, Setters, WithSetters)]
22pub struct MeteredCtx<const PAGE_BITS: usize = DEFAULT_PAGE_BITS> {
23    pub trace_heights: Vec<u32>,
24    pub is_trace_height_constant: Vec<bool>,
25    pub memory_ctx: MemoryCtx<PAGE_BITS>,
26    pub segmentation_ctx: SegmentationCtx,
27    #[getset(get = "pub", set = "pub", set_with = "pub")]
28    suspend_on_segment: bool,
29}
30
31impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
32    // Note[jpw]: prefer to use `build_metered_ctx` in `VmExecutor` or `VirtualMachine`.
33    pub fn new(
34        constant_trace_heights: Vec<Option<usize>>,
35        air_names: Vec<String>,
36        widths: Vec<usize>,
37        interactions: Vec<usize>,
38        config: &SystemConfig,
39    ) -> Self {
40        let (trace_heights, is_trace_height_constant): (Vec<u32>, Vec<bool>) =
41            constant_trace_heights
42                .iter()
43                .map(|&constant_height| {
44                    if let Some(height) = constant_height {
45                        (height as u32, true)
46                    } else {
47                        (0, false)
48                    }
49                })
50                .unzip();
51
52        let memory_ctx = MemoryCtx::new(config);
53
54        // Assert that the indices are correct
55        debug_assert!(
56            air_names[memory_ctx.boundary_idx].contains("Boundary"),
57            "air_name={}",
58            air_names[memory_ctx.boundary_idx]
59        );
60        if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index {
61            debug_assert!(
62                air_names[merkle_tree_index].contains("Merkle"),
63                "air_name={}",
64                air_names[merkle_tree_index]
65            );
66        }
67        debug_assert!(
68            air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
69            "air_name={}",
70            air_names[memory_ctx.adapter_offset]
71        );
72
73        let segmentation_ctx =
74            SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);
75
76        let mut ctx = Self {
77            trace_heights,
78            is_trace_height_constant,
79            memory_ctx,
80            segmentation_ctx,
81            suspend_on_segment: false,
82        };
83        if !config.continuation_enabled {
84            // force single segment
85            ctx.segmentation_ctx.segment_check_insns = u64::MAX;
86        }
87
88        // Add merkle height contributions for all registers
89        ctx.memory_ctx.add_register_merkle_heights();
90
91        ctx
92    }
93
94    pub fn with_max_trace_height(mut self, max_trace_height: u32) -> Self {
95        self.segmentation_ctx.set_max_trace_height(max_trace_height);
96        let max_check_freq = (max_trace_height / 2) as u64;
97        if max_check_freq < self.segmentation_ctx.segment_check_insns {
98            self.segmentation_ctx.segment_check_insns = max_check_freq;
99        }
100        self
101    }
102
103    pub fn with_max_cells(mut self, max_cells: usize) -> Self {
104        self.segmentation_ctx.set_max_cells(max_cells);
105        self
106    }
107
108    pub fn with_max_interactions(mut self, max_interactions: usize) -> Self {
109        self.segmentation_ctx.set_max_interactions(max_interactions);
110        self
111    }
112
113    pub fn segments(&self) -> &[Segment] {
114        &self.segmentation_ctx.segments
115    }
116
117    pub fn into_segments(self) -> Vec<Segment> {
118        self.segmentation_ctx.segments
119    }
120
121    fn reset_segment(&mut self) {
122        self.memory_ctx.clear();
123        // Add merkle height contributions for all registers
124        self.memory_ctx.add_register_merkle_heights();
125    }
126
127    #[inline(always)]
128    pub fn check_and_segment(&mut self, instret: u64, segment_check_insns: u64) -> bool {
129        let threshold = self
130            .segmentation_ctx
131            .instret_last_segment_check
132            .wrapping_add(segment_check_insns);
133        debug_assert!(
134            threshold >= self.segmentation_ctx.instret_last_segment_check,
135            "overflow in segment check threshold calculation"
136        );
137        if instret < threshold {
138            return false;
139        }
140
141        self.memory_ctx
142            .lazy_update_boundary_heights(&mut self.trace_heights);
143        let did_segment = self.segmentation_ctx.check_and_segment(
144            instret,
145            &mut self.trace_heights,
146            &self.is_trace_height_constant,
147        );
148
149        if did_segment {
150            self.reset_segment();
151        }
152        did_segment
153    }
154
155    #[allow(dead_code)]
156    pub fn print_segment(&self) {
157        println!("{}", "-".repeat(80));
158        println!("Segment {}", self.segmentation_ctx.segments.len() - 1);
159        println!("{}", "-".repeat(80));
160        println!("{:>10} {:>10} {:<30}", "Width", "Height", "Air Name");
161        println!("{}", "-".repeat(80));
162        for ((&width, &height), air_name) in self
163            .segmentation_ctx
164            .widths
165            .iter()
166            .zip_eq(self.trace_heights.iter())
167            .zip_eq(self.segmentation_ctx.air_names.iter())
168        {
169            println!("{:>10} {:>10} {:<30}", width, height, air_name.as_str());
170        }
171    }
172}
173
174impl<const PAGE_BITS: usize> ExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
175    #[inline(always)]
176    fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) {
177        debug_assert!(
178            address_space != RV32_IMM_AS,
179            "address space must not be immediate"
180        );
181        debug_assert!(size > 0, "size must be greater than 0, got {}", size);
182        debug_assert!(
183            size.is_power_of_two(),
184            "size must be a power of 2, got {}",
185            size
186        );
187
188        // Handle access adapter updates
189        // SAFETY: size passed is always a non-zero power of 2
190        let size_bits = unsafe { NonZero::new_unchecked(size).ilog2() };
191        self.memory_ctx
192            .update_adapter_heights(&mut self.trace_heights, address_space, size_bits);
193
194        // Handle merkle tree updates
195        if address_space != RV32_REGISTER_AS {
196            self.memory_ctx
197                .update_boundary_merkle_heights(address_space, ptr, size);
198        }
199    }
200
201    #[inline(always)]
202    fn should_suspend<F>(
203        instret: u64,
204        _pc: u32,
205        segment_check_insns: u64,
206        exec_state: &mut VmExecState<F, GuestMemory, Self>,
207    ) -> bool {
208        // If `segment_suspend` is set, suspend when a segment is determined (but the VM state might
209        // be after the segment boundary because the segment happens in the previous checkpoint).
210        // Otherwise, execute until termination.
211        exec_state
212            .ctx
213            .check_and_segment(instret, segment_check_insns)
214            && exec_state.ctx.suspend_on_segment
215    }
216
217    #[inline(always)]
218    fn on_terminate<F>(instret: u64, _pc: u32, exec_state: &mut VmExecState<F, GuestMemory, Self>) {
219        exec_state
220            .ctx
221            .memory_ctx
222            .lazy_update_boundary_heights(&mut exec_state.ctx.trace_heights);
223        exec_state
224            .ctx
225            .segmentation_ctx
226            .create_final_segment(instret, &exec_state.ctx.trace_heights);
227    }
228}
229
230impl<const PAGE_BITS: usize> MeteredExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
231    #[inline(always)]
232    fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) {
233        debug_assert!(
234            chip_idx < self.trace_heights.len(),
235            "chip_idx out of bounds"
236        );
237        // SAFETY: chip_idx is created in executor_idx_to_air_idx and is always within bounds
238        unsafe {
239            *self.trace_heights.get_unchecked_mut(chip_idx) = self
240                .trace_heights
241                .get_unchecked(chip_idx)
242                .wrapping_add(height_delta);
243        }
244    }
245}