openvm_circuit/arch/execution_mode/metered/
ctx.rs

1use std::num::NonZero;
2
3use openvm_instructions::riscv::{RV32_IMM_AS, RV32_REGISTER_AS};
4
5use super::{
6    memory_ctx::MemoryCtx,
7    segment_ctx::{Segment, SegmentationCtx},
8};
9use crate::{
10    arch::{
11        execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait},
12        SystemConfig, VmExecState,
13    },
14    system::memory::online::GuestMemory,
15};
16
/// Default value of the `PAGE_BITS` const parameter of [`MeteredCtx`], which is forwarded to
/// its `MemoryCtx<PAGE_BITS>` (presumably log2 of the page granularity used for memory
/// bookkeeping — NOTE(review): confirm against `MemoryCtx`).
pub const DEFAULT_PAGE_BITS: usize = 6;
18
/// Execution context for metered execution.
///
/// While the program runs, it tracks the trace height of every AIR. It hands those heights to
/// [`SegmentationCtx`] to decide where execution is split into segments.
#[derive(Clone, Debug)]
pub struct MeteredCtx<const PAGE_BITS: usize = DEFAULT_PAGE_BITS> {
    /// Current trace height of each AIR, indexed by AIR index.
    pub trace_heights: Vec<u32>,
    /// `true` for AIRs with a fixed (constant) height. Their entries in `trace_heights` are
    /// not reset when a new segment starts.
    pub is_trace_height_constant: Vec<bool>,
    /// Memory bookkeeping: access-adapter, boundary and merkle AIR height contributions.
    pub memory_ctx: MemoryCtx<PAGE_BITS>,
    /// Segmentation limits, per-AIR metadata and the segments produced so far.
    pub segmentation_ctx: SegmentationCtx,
}
26
27impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
28    // Note[jpw]: prefer to use `build_metered_ctx` in `VmExecutor` or `VirtualMachine`.
29    pub fn new(
30        constant_trace_heights: Vec<Option<usize>>,
31        air_names: Vec<String>,
32        widths: Vec<usize>,
33        interactions: Vec<usize>,
34        config: &SystemConfig,
35    ) -> Self {
36        let (trace_heights, is_trace_height_constant): (Vec<u32>, Vec<bool>) =
37            constant_trace_heights
38                .iter()
39                .map(|&constant_height| {
40                    if let Some(height) = constant_height {
41                        (height as u32, true)
42                    } else {
43                        (0, false)
44                    }
45                })
46                .unzip();
47
48        let memory_ctx = MemoryCtx::new(config);
49
50        // Assert that the indices are correct
51        debug_assert!(
52            air_names[memory_ctx.boundary_idx].contains("Boundary"),
53            "air_name={}",
54            air_names[memory_ctx.boundary_idx]
55        );
56        if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index {
57            debug_assert!(
58                air_names[merkle_tree_index].contains("Merkle"),
59                "air_name={}",
60                air_names[merkle_tree_index]
61            );
62        }
63        debug_assert!(
64            air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
65            "air_name={}",
66            air_names[memory_ctx.adapter_offset]
67        );
68
69        let segmentation_ctx =
70            SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);
71
72        let mut ctx = Self {
73            trace_heights,
74            is_trace_height_constant,
75            memory_ctx,
76            segmentation_ctx,
77        };
78        if !config.continuation_enabled {
79            // force single segment
80            ctx.segmentation_ctx.segment_check_insns = u64::MAX;
81        }
82
83        // Add merkle height contributions for all registers
84        ctx.memory_ctx.add_register_merkle_heights();
85
86        ctx
87    }
88
89    pub fn with_max_trace_height(mut self, max_trace_height: u32) -> Self {
90        self.segmentation_ctx.set_max_trace_height(max_trace_height);
91        let max_check_freq = (max_trace_height / 2) as u64;
92        if max_check_freq < self.segmentation_ctx.segment_check_insns {
93            self.segmentation_ctx.segment_check_insns = max_check_freq;
94        }
95        self
96    }
97
98    pub fn with_max_cells(mut self, max_cells: usize) -> Self {
99        self.segmentation_ctx.set_max_cells(max_cells);
100        self
101    }
102
103    pub fn with_max_interactions(mut self, max_interactions: usize) -> Self {
104        self.segmentation_ctx.set_max_interactions(max_interactions);
105        self
106    }
107
108    pub fn segments(&self) -> &[Segment] {
109        &self.segmentation_ctx.segments
110    }
111
112    pub fn into_segments(self) -> Vec<Segment> {
113        self.segmentation_ctx.segments
114    }
115
116    fn reset_segment(&mut self) {
117        self.memory_ctx.clear();
118        for (i, &is_constant) in self.is_trace_height_constant.iter().enumerate() {
119            if !is_constant {
120                self.trace_heights[i] = 0;
121            }
122        }
123        // Add merkle height contributions for all registers
124        self.memory_ctx.add_register_merkle_heights();
125    }
126
127    #[inline(always)]
128    pub fn check_and_segment(&mut self, instret: u64) {
129        let threshold = self
130            .segmentation_ctx
131            .instret_last_segment_check
132            .wrapping_add(self.segmentation_ctx.segment_check_insns);
133        debug_assert!(
134            threshold >= self.segmentation_ctx.instret_last_segment_check,
135            "overflow in segment check threshold calculation"
136        );
137        if instret < threshold {
138            return;
139        }
140
141        self.memory_ctx
142            .lazy_update_boundary_heights(&mut self.trace_heights);
143        let did_segment = self.segmentation_ctx.check_and_segment(
144            instret,
145            &self.trace_heights,
146            &self.is_trace_height_constant,
147        );
148
149        if did_segment {
150            self.reset_segment();
151        }
152    }
153
154    #[allow(dead_code)]
155    pub fn print_heights(&self) {
156        println!("{:>10} {:<30}", "Height", "Air Name");
157        println!("{}", "-".repeat(42));
158        for (i, height) in self.trace_heights.iter().enumerate() {
159            let air_name = self
160                .segmentation_ctx
161                .air_names
162                .get(i)
163                .map(|s| s.as_str())
164                .unwrap_or("Unknown");
165            println!("{:>10} {:<30}", height, air_name);
166        }
167    }
168}
169
170impl<const PAGE_BITS: usize> ExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
171    #[inline(always)]
172    fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) {
173        debug_assert!(
174            address_space != RV32_IMM_AS,
175            "address space must not be immediate"
176        );
177        debug_assert!(size > 0, "size must be greater than 0, got {}", size);
178        debug_assert!(
179            size.is_power_of_two(),
180            "size must be a power of 2, got {}",
181            size
182        );
183
184        // Handle access adapter updates
185        // SAFETY: size passed is always a non-zero power of 2
186        let size_bits = unsafe { NonZero::new_unchecked(size).ilog2() };
187        self.memory_ctx
188            .update_adapter_heights(&mut self.trace_heights, address_space, size_bits);
189
190        // Handle merkle tree updates
191        if address_space != RV32_REGISTER_AS {
192            self.memory_ctx
193                .update_boundary_merkle_heights(address_space, ptr, size);
194        }
195    }
196
197    #[inline(always)]
198    fn should_suspend<F>(vm_state: &mut VmExecState<F, GuestMemory, Self>) -> bool {
199        // E2 always runs until termination. Here we use the function as a hook called every
200        // instruction.
201        vm_state.ctx.check_and_segment(vm_state.instret);
202        false
203    }
204
205    #[inline(always)]
206    fn on_terminate<F>(vm_state: &mut VmExecState<F, GuestMemory, Self>) {
207        vm_state
208            .ctx
209            .memory_ctx
210            .lazy_update_boundary_heights(&mut vm_state.ctx.trace_heights);
211        vm_state
212            .ctx
213            .segmentation_ctx
214            .segment(vm_state.instret, &vm_state.ctx.trace_heights);
215    }
216}
217
218impl<const PAGE_BITS: usize> MeteredExecutionCtxTrait for MeteredCtx<PAGE_BITS> {
219    #[inline(always)]
220    fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) {
221        debug_assert!(
222            chip_idx < self.trace_heights.len(),
223            "chip_idx out of bounds"
224        );
225        // SAFETY: chip_idx is created in executor_idx_to_air_idx and is always within bounds
226        unsafe {
227            *self.trace_heights.get_unchecked_mut(chip_idx) = self
228                .trace_heights
229                .get_unchecked(chip_idx)
230                .wrapping_add(height_delta);
231        }
232    }
233}