openvm_circuit/arch/execution_mode/metered/
memory_ctx.rs

use abi_stable::std_types::RVec;
use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS};

use crate::{arch::SystemConfig, system::memory::dimensions::MemoryDimensions};

/// Upper bound on number of memory pages accessed per instruction. Used for buffer allocation.
pub const MAX_MEM_PAGE_OPS_PER_INSN: usize = 1 << 16;

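/// Fixed-capacity bitset backed by `u64` words; used here to track touched memory pages.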
#[derive(Clone, Debug)]
pub struct BitSet {
    words: Box<[u64]>,
}

impl BitSet {
    pub fn new(num_bits: usize) -> Self {
        Self {
            words: vec![0; num_bits.div_ceil(u64::BITS as usize)].into_boxed_slice(),
        }
    }

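    /// Set the bit at `index`. Returns `true` if the bit was not previously set.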
    #[inline(always)]
    pub fn insert(&mut self, index: usize) -> bool {
        let word_index = index >> 6;
        let bit_index = index & 63;
        let mask = 1u64 << bit_index;

        debug_assert!(word_index < self.words.len(), "BitSet index out of bounds");

        // SAFETY: word_index is derived from a memory address that is bounds-checked
        //         during memory access. The bitset is sized to accommodate all valid
        //         memory addresses, so word_index is always within bounds.
        let word = unsafe { self.words.get_unchecked_mut(word_index) };
        let was_set = (*word & mask) != 0;
        *word |= mask;
        !was_set
    }

    /// Set all bits within [start, end) to 1, return the number of flipped bits.
    /// Assumes start < end and end <= self.words.len() * 64.
    #[inline(always)]
    pub fn insert_range(&mut self, start: usize, end: usize) -> usize {
        debug_assert!(start < end);
        debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds");

        let mut ret = 0;
        let start_word_index = start >> 6;
        let end_word_index = (end - 1) >> 6;
        let start_bit = (start & 63) as u32;

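        // The range is either contained in a single word or spans several words; in the
        // latter case the first and last words are masked partially and any words fully
        // inside the range are filled in the loop below.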
        if start_word_index == end_word_index {
            let end_bit = ((end - 1) & 63) as u32 + 1;
            let mask_bits = end_bit - start_bit;
            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
            // so start_word_index < self.words.len()
            let word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*word & mask).count_ones();
            *word |= mask;
        } else {
            let end_bit = (end & 63) as u32;
            let mask_bits = 64 - start_bit;
            let mask = u64::MAX << start_bit;
            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
            // so start_word_index < self.words.len()
            let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*start_word & mask).count_ones();
            *start_word |= mask;

            let mask_bits = end_bit;
            let mask = if end_bit == 0 {
                0
            } else {
                u64::MAX >> (64 - end_bit)
            };
            // SAFETY: Caller ensures end <= self.words.len() * 64, so
            // end_word_index < self.words.len()
            let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) };
            ret += mask_bits - (*end_word & mask).count_ones();
            *end_word |= mask;
        }

        if start_word_index + 1 < end_word_index {
            for i in (start_word_index + 1)..end_word_index {
                // SAFETY: Caller ensures proper start and end, so i is within bounds
                // of self.words.len()
                let word = unsafe { self.words.get_unchecked_mut(i) };
                ret += word.count_zeros();
                *word = u64::MAX;
            }
        }
        ret as usize
    }

    #[inline(always)]
    pub fn clear(&mut self) {
        // SAFETY: words is valid for self.words.len() elements
        unsafe {
            std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len());
        }
    }
}

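/// Memory bookkeeping for metered execution: tracks which memory pages are touched in the
/// current segment so that trace heights of the boundary, merkle, poseidon2, and access
/// adapter chips can be updated lazily at segment checks.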
#[derive(Clone, Debug)]
pub struct MemoryCtx<const PAGE_BITS: usize> {
    memory_dimensions: MemoryDimensions,
    min_block_size_bits: Vec<u8>,
    pub boundary_idx: usize,
    pub merkle_tree_index: Option<usize>,
    pub adapter_offset: usize,
    continuations_enabled: bool,
    chunk: u32,
    chunk_bits: u32,
    pub page_indices: BitSet,
    pub addr_space_access_count: RVec<u32>,
    pub page_indices_since_checkpoint: Box<[u32]>,
    pub page_indices_since_checkpoint_len: usize,
}

impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
    pub fn new(config: &SystemConfig, segment_check_insns: u64) -> Self {
        let chunk = config.initial_block_size() as u32;
        let chunk_bits = chunk.ilog2();

        let memory_dimensions = config.memory_config.memory_dimensions();
        let merkle_height = memory_dimensions.overall_height();

        let bitset_size = 1 << (merkle_height.saturating_sub(PAGE_BITS));
        let addr_space_size = (1 << memory_dimensions.addr_space_height) + 1;
        let page_indices_since_checkpoint_cap =
            Self::calculate_checkpoint_capacity(segment_check_insns);

        Self {
            min_block_size_bits: config.memory_config.min_block_size_bits(),
            boundary_idx: config.memory_boundary_air_id(),
            merkle_tree_index: config.memory_merkle_air_id(),
            adapter_offset: config.access_adapter_air_id_offset(),
            chunk,
            chunk_bits,
            memory_dimensions,
            continuations_enabled: config.continuation_enabled,
            page_indices: BitSet::new(bitset_size),
            addr_space_access_count: vec![0; addr_space_size].into(),
            page_indices_since_checkpoint: vec![0; page_indices_since_checkpoint_cap]
                .into_boxed_slice(),
            page_indices_since_checkpoint_len: 0,
        }
    }

    #[inline(always)]
    pub(super) fn calculate_checkpoint_capacity(segment_check_insns: u64) -> usize {
        segment_check_insns as usize * MAX_MEM_PAGE_OPS_PER_INSN
    }

    #[inline(always)]
    pub(crate) fn add_register_merkle_heights(&mut self) {
        if self.continuations_enabled {
            self.update_boundary_merkle_heights(
                RV32_REGISTER_AS,
                0,
                (RV32_NUM_REGISTERS * RV32_REGISTER_NUM_LIMBS) as u32,
            );
        }
    }

    /// For each memory access, record the minimal necessary data to update heights of
    /// memory-related chips. The actual height updates happen during segment checks. The
    /// implementation is in `lazy_update_boundary_heights`.
    #[inline(always)]
    pub(crate) fn update_boundary_merkle_heights(
        &mut self,
        address_space: u32,
        ptr: u32,
        size: u32,
    ) {
        debug_assert!((address_space as usize) < self.addr_space_access_count.len());

        let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
        let start_chunk_id = ptr >> self.chunk_bits;
        let start_block_id = if self.chunk == 1 {
            start_chunk_id
        } else {
            self.memory_dimensions
                .label_to_index((address_space, start_chunk_id)) as u32
        };
        // Because `self.chunk == 1 << self.chunk_bits`
        let end_block_id = start_block_id + num_blocks;
        let start_page_id = start_block_id >> PAGE_BITS;
        let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
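        // A page groups (1 << PAGE_BITS) consecutive block ids; start_page_id..end_page_id
        // is the half-open range of pages touched by this access.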
        assert!(
            self.page_indices_since_checkpoint_len + (end_page_id - start_page_id) as usize
                <= self.page_indices_since_checkpoint.len(),
            "more than {MAX_MEM_PAGE_OPS_PER_INSN} memory pages accessed in a single instruction"
        );

        for page_id in start_page_id..end_page_id {
            // Append page_id to page_indices_since_checkpoint
            let len = self.page_indices_since_checkpoint_len;
            debug_assert!(len < self.page_indices_since_checkpoint.len());
            // SAFETY: len is within bounds, and we extend length by 1 after writing.
            unsafe {
                *self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id;
            }
            self.page_indices_since_checkpoint_len = len + 1;

            if self.page_indices.insert(page_id as usize) {
                // SAFETY: address_space passed is usually a hardcoded constant or derived from an
                // Instruction where it is bounds checked before passing
                unsafe {
                    *self
                        .addr_space_access_count
                        .get_unchecked_mut(address_space as usize) += 1;
                }
            }
        }
    }

    #[inline(always)]
    pub fn update_adapter_heights(
        &mut self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
    ) {
        self.update_adapter_heights_batch(trace_heights, address_space, size_bits, 1);
    }

    #[inline(always)]
    pub fn update_adapter_heights_batch(
        &self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
        num: u32,
    ) {
        debug_assert!((address_space as usize) < self.min_block_size_bits.len());

        // SAFETY: address_space passed is usually a hardcoded constant or derived from an
        // Instruction where it is bounds checked before passing
        let align_bits = unsafe {
            *self
                .min_block_size_bits
                .get_unchecked(address_space as usize)
        };
        debug_assert!(
            align_bits as u32 <= size_bits,
            "align_bits ({align_bits}) must be <= size_bits ({size_bits})"
        );

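        // Each access adapter between the access size and the address space's minimum block
        // size gets num << (size_bits - adapter_bits + 1) extra rows. For example
        // (illustrative numbers), with align_bits = 0, size_bits = 2 and num = 1, the
        // adapter at adapter_offset + 1 gets 2 rows and the adapter at adapter_offset
        // gets 4 rows.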
        for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
            let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
            debug_assert!(adapter_idx < trace_heights.len());
            // SAFETY: trace_heights is initialized taking access adapters into account
            unsafe {
                *trace_heights.get_unchecked_mut(adapter_idx) +=
                    num << (size_bits - adapter_bits + 1);
            }
        }
    }

    /// Initialize state for a new segment
    #[inline(always)]
    pub(crate) fn initialize_segment(&mut self, trace_heights: &mut [u32]) {
        // Clear page indices for the new segment
        self.page_indices.clear();

        // Reset trace heights for memory chips to 0
        // SAFETY: boundary_idx is a compile time constant within bounds
        unsafe {
            *trace_heights.get_unchecked_mut(self.boundary_idx) = 0;
        }
        if let Some(merkle_tree_idx) = self.merkle_tree_index {
            // SAFETY: merkle_tree_idx is guaranteed to be in bounds
            unsafe {
                *trace_heights.get_unchecked_mut(merkle_tree_idx) = 0;
            }
            let poseidon2_idx = trace_heights.len() - 2;
            // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) = 0;
            }
        }

        // Apply height updates for all pages accessed since last checkpoint, and
        // initialize page_indices for the new segment.
        let mut addr_space_access_count = vec![0; self.addr_space_access_count.len()];
        let pages_len = self.page_indices_since_checkpoint_len;
        for i in 0..pages_len {
            // SAFETY: i < pages_len, and pages_len never exceeds the slice length.
            let page_id = unsafe { *self.page_indices_since_checkpoint.get_unchecked(i) } as usize;
            if self.page_indices.insert(page_id) {
                let (addr_space, _) = self
                    .memory_dimensions
                    .index_to_label((page_id as u64) << PAGE_BITS);
                let addr_space_idx = addr_space as usize;
                debug_assert!(addr_space_idx < addr_space_access_count.len());
                // SAFETY: addr_space_idx is bounds checked in debug and derived from a valid page
                // id.
                unsafe {
                    *addr_space_access_count.get_unchecked_mut(addr_space_idx) += 1;
                }
            }
        }
        self.apply_height_updates(trace_heights, &addr_space_access_count);

        // Add merkle height contributions for all registers
        self.add_register_merkle_heights();
        self.lazy_update_boundary_heights(trace_heights);
    }

    /// Updates the checkpoint with the current safe state by resetting the buffer of pages
    /// recorded since the last checkpoint.
    #[inline(always)]
    pub(crate) fn update_checkpoint(&mut self) {
        self.page_indices_since_checkpoint_len = 0;
    }

    /// Apply height updates given page counts
    #[inline(always)]
    fn apply_height_updates(&self, trace_heights: &mut [u32], addr_space_access_count: &[u32]) {
        let page_access_count: u32 = addr_space_access_count.iter().sum();

        // On page fault, assume we add all leaves in a page
        let leaves = page_access_count << PAGE_BITS;
        // SAFETY: boundary_idx is a compile time constant within bounds
        unsafe {
            *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
        }

        if let Some(merkle_tree_idx) = self.merkle_tree_index {
            debug_assert!(merkle_tree_idx < trace_heights.len());
            debug_assert!(trace_heights.len() >= 2);

            let poseidon2_idx = trace_heights.len() - 2;
            // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2;
            }

            let merkle_height = self.memory_dimensions.overall_height();
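            // Per touched page: the (1 << PAGE_BITS) - 1 internal nodes of the page's
            // subtree plus the (merkle_height - PAGE_BITS) nodes on the path from the page
            // root up to the overall root.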
            let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
            // SAFETY: merkle_tree_idx is guaranteed to be in bounds
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * page_access_count * 2;
                *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * page_access_count * 2;
            }
        }

        for address_space in 0..addr_space_access_count.len() {
            // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
            let x = unsafe { *addr_space_access_count.get_unchecked(address_space) };
            if x > 0 {
                // Initial **and** final handling of touched pages requires a send (resp.
                // receive) in chunk-sized units for the merkle chip.
                // Corresponds to `handle_uninitialized_memory` and `handle_touched_blocks` in
                // online.rs
                self.update_adapter_heights_batch(
                    trace_heights,
                    address_space as u32,
                    self.chunk_bits,
                    x << (PAGE_BITS + 1),
                );
            }
        }
    }

    /// Resolve all lazily recorded memory-access updates for the memory adapter, poseidon2, and
    /// merkle chips.
    #[inline(always)]
    pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
        self.apply_height_updates(trace_heights, &self.addr_space_access_count);
        // SAFETY: addr_space_access_count is valid for writes of its full length, and zero is
        // a valid u32 value.
        unsafe {
            std::ptr::write_bytes(
                self.addr_space_access_count.as_mut_ptr(),
                0,
                self.addr_space_access_count.len(),
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitset_insert_range() {
        // 513 bits
        let mut bit_set = BitSet::new(8 * 64 + 1);
        let num_flips = bit_set.insert_range(2, 29);
        assert_eq!(num_flips, 27);
        let num_flips = bit_set.insert_range(1, 31);
        assert_eq!(num_flips, 3);

        let num_flips = bit_set.insert_range(32, 65);
        assert_eq!(num_flips, 33);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 3);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 0);

        let num_flips = bit_set.insert_range(256, 320);
        assert_eq!(num_flips, 64);
        let num_flips = bit_set.insert_range(256, 377);
        assert_eq!(num_flips, 57);
        let num_flips = bit_set.insert_range(100, 513);
        assert_eq!(num_flips, 413 - 121);
    }
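
    // Supplementary sketch (not part of the original suite): exercises `insert` and `clear`,
    // checking that `insert` returns true only the first time a bit is set and that `clear`
    // resets every word.
    #[test]
    fn test_bitset_insert_and_clear() {
        let mut bit_set = BitSet::new(128);
        assert!(bit_set.insert(0));
        assert!(!bit_set.insert(0));
        assert!(bit_set.insert(127));
        assert!(!bit_set.insert(127));
        bit_set.clear();
        assert!(bit_set.insert(0));
        assert!(bit_set.insert(127));
    }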
}