openvm_circuit/arch/execution_mode/metered/memory_ctx.rs

use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS};

use crate::{arch::SystemConfig, system::memory::dimensions::MemoryDimensions};

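/// A fixed-capacity bitset backed by `u64` words.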
#[derive(Clone, Debug)]
pub struct BitSet {
    words: Box<[u64]>,
}

impl BitSet {
    pub fn new(num_bits: usize) -> Self {
        Self {
            words: vec![0; num_bits.div_ceil(u64::BITS as usize)].into_boxed_slice(),
        }
    }

    #[inline(always)]
    pub fn insert(&mut self, index: usize) -> bool {
        let word_index = index >> 6;
        let bit_index = index & 63;
        let mask = 1u64 << bit_index;

        debug_assert!(word_index < self.words.len(), "BitSet index out of bounds");

        // SAFETY: word_index is derived from a memory address that is bounds-checked
        //         during memory access. The bitset is sized to accommodate all valid
        //         memory addresses, so word_index is always within bounds.
        let word = unsafe { self.words.get_unchecked_mut(word_index) };
        let was_set = (*word & mask) != 0;
        *word |= mask;
        !was_set
    }

    /// Set all bits within `[start, end)` to 1, return the number of flipped bits.
    /// Assumes `start < end` and `end <= self.words.len() * 64`.
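    ///
    /// For example, `insert_range(2, 29)` on a fresh bitset stays within word 0: it applies the
    /// mask `(u64::MAX >> 37) << 2` (bits 2..=28) and reports 27 flipped bits.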
    #[inline(always)]
    pub fn insert_range(&mut self, start: usize, end: usize) -> usize {
        debug_assert!(start < end);
        debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds");

        let mut ret = 0;
        let start_word_index = start >> 6;
        let end_word_index = (end - 1) >> 6;
        let start_bit = (start & 63) as u32;

        if start_word_index == end_word_index {
            let end_bit = ((end - 1) & 63) as u32 + 1;
            let mask_bits = end_bit - start_bit;
            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
            // so start_word_index < self.words.len()
            let word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*word & mask).count_ones();
            *word |= mask;
        } else {
            // The end word is the last word inside the range, so a range whose `end` falls
            // exactly on a word boundary must still fill it completely (end_bit == 64).
            let end_bit = ((end - 1) & 63) as u32 + 1;
            let mask_bits = 64 - start_bit;
            let mask = u64::MAX << start_bit;
            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
            // so start_word_index < self.words.len()
            let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*start_word & mask).count_ones();
            *start_word |= mask;

            let mask_bits = end_bit;
            let mask = u64::MAX >> (64 - end_bit);
            // SAFETY: Caller ensures end <= self.words.len() * 64, so
            // end_word_index < self.words.len()
            let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) };
            ret += mask_bits - (*end_word & mask).count_ones();
            *end_word |= mask;
        }

        if start_word_index + 1 < end_word_index {
            for i in (start_word_index + 1)..end_word_index {
                // SAFETY: Caller ensures proper start and end, so i is within bounds
                // of self.words.len()
                let word = unsafe { self.words.get_unchecked_mut(i) };
                ret += word.count_zeros();
                *word = u64::MAX;
            }
        }
        ret as usize
    }

    #[inline(always)]
    pub fn clear(&mut self) {
        // SAFETY: words is valid for self.words.len() elements
        unsafe {
            std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len());
        }
    }
}

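/// Bookkeeping for metered execution: tracks which memory pages and address spaces have been
/// touched so that the heights of the boundary, merkle, poseidon2, and access-adapter chips can
/// be updated lazily during segment checks.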
#[derive(Clone, Debug)]
pub struct MemoryCtx<const PAGE_BITS: usize> {
    pub page_indices: BitSet,
    memory_dimensions: MemoryDimensions,
    min_block_size_bits: Vec<u8>,
    pub boundary_idx: usize,
    pub merkle_tree_index: Option<usize>,
    pub adapter_offset: usize,
    continuations_enabled: bool,
    chunk: u32,
    chunk_bits: u32,
    page_access_count: usize,
    // Note: 32 is the maximum access adapter size.
    addr_space_access_count: Vec<usize>,
}

impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
    pub fn new(config: &SystemConfig) -> Self {
        let chunk = config.initial_block_size() as u32;
        let chunk_bits = chunk.ilog2();

        let memory_dimensions = config.memory_config.memory_dimensions();
        let merkle_height = memory_dimensions.overall_height();

        Self {
            // Address height already considers `chunk_bits`.
            page_indices: BitSet::new(1 << (merkle_height.saturating_sub(PAGE_BITS))),
            min_block_size_bits: config.memory_config.min_block_size_bits(),
            boundary_idx: config.memory_boundary_air_id(),
            merkle_tree_index: config.memory_merkle_air_id(),
            adapter_offset: config.access_adapter_air_id_offset(),
            chunk,
            chunk_bits,
            memory_dimensions,
            continuations_enabled: config.continuation_enabled,
            page_access_count: 0,
            addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1],
        }
    }

    #[inline(always)]
    pub fn clear(&mut self) {
        self.page_indices.clear();
    }

    #[inline(always)]
    pub(crate) fn add_register_merkle_heights(&mut self) {
        if self.continuations_enabled {
            self.update_boundary_merkle_heights(
                RV32_REGISTER_AS,
                0,
                (RV32_NUM_REGISTERS * RV32_REGISTER_NUM_LIMBS) as u32,
            );
        }
    }

    /// For each memory access, record the minimal necessary data to update heights of
    /// memory-related chips. The actual height updates happen during segment checks. The
    /// implementation is in `lazy_update_boundary_heights`.
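    ///
    /// For example, with `chunk = 8` (`chunk_bits = 3`), an access of `size = 32` at `ptr = 40`
    /// covers `num_blocks = 4` chunks starting at chunk id `5`; the loop below then marks every
    /// page (a group of `2^PAGE_BITS` consecutive blocks) overlapping those chunks and bumps the
    /// per-address-space counter once for each newly touched page.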
    #[inline(always)]
    pub(crate) fn update_boundary_merkle_heights(
        &mut self,
        address_space: u32,
        ptr: u32,
        size: u32,
    ) {
        debug_assert!((address_space as usize) < self.addr_space_access_count.len());

        let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
        let start_chunk_id = ptr >> self.chunk_bits;
        let start_block_id = if self.chunk == 1 {
            start_chunk_id
        } else {
            self.memory_dimensions
                .label_to_index((address_space, start_chunk_id)) as u32
        };
        // Because `self.chunk == 1 << self.chunk_bits`
        let end_block_id = start_block_id + num_blocks;
        let start_page_id = start_block_id >> PAGE_BITS;
        let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;

        for page_id in start_page_id..end_page_id {
            if self.page_indices.insert(page_id as usize) {
                self.page_access_count += 1;
                // SAFETY: address_space passed is usually a hardcoded constant or derived from an
                // Instruction where it is bounds checked before passing
                unsafe {
                    *self
                        .addr_space_access_count
                        .get_unchecked_mut(address_space as usize) += 1;
                }
            }
        }
    }

    #[inline(always)]
    pub fn update_adapter_heights(
        &mut self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
    ) {
        self.update_adapter_heights_batch(trace_heights, address_space, size_bits, 1);
    }

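    /// Bumps the access-adapter trace heights for `num` accesses of size `2^size_bits` in
    /// `address_space`. For example, a single access with `size_bits = 4` in an address space
    /// with `align_bits = 2` adds `1 << (4 - 4 + 1) = 2` rows for `adapter_bits = 4` and
    /// `1 << (4 - 3 + 1) = 4` rows for `adapter_bits = 3`.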
    #[inline(always)]
    pub fn update_adapter_heights_batch(
        &self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
        num: u32,
    ) {
        debug_assert!((address_space as usize) < self.min_block_size_bits.len());

        // SAFETY: address_space passed is usually a hardcoded constant or derived from an
        // Instruction where it is bounds checked before passing
        let align_bits = unsafe {
            *self
                .min_block_size_bits
                .get_unchecked(address_space as usize)
        };
        debug_assert!(
            align_bits as u32 <= size_bits,
            "align_bits ({}) must be <= size_bits ({})",
            align_bits,
            size_bits
        );

        for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
            let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
            debug_assert!(adapter_idx < trace_heights.len());
            // SAFETY: trace_heights is initialized taking access adapters into account
            unsafe {
                *trace_heights.get_unchecked_mut(adapter_idx) +=
                    num << (size_bits - adapter_bits + 1);
            }
        }
    }

    /// Resolve all lazy updates of each memory access for the boundary, access-adapter,
    /// poseidon2, and merkle chips.
    #[inline(always)]
    pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
        debug_assert!(self.boundary_idx < trace_heights.len());

        // On page fault, assume we add all leaves in a page
        let leaves = (self.page_access_count << PAGE_BITS) as u32;
        // SAFETY: boundary_idx is assigned by the config, and trace_heights is sized to
        // include all system AIRs, so it is within bounds
        unsafe {
            *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
        }

        if let Some(merkle_tree_idx) = self.merkle_tree_index {
            debug_assert!(merkle_tree_idx < trace_heights.len());
            debug_assert!(trace_heights.len() >= 2);

            let poseidon2_idx = trace_heights.len() - 2;
            // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2;
            }

            let merkle_height = self.memory_dimensions.overall_height();
            // (2^PAGE_BITS - 1) internal nodes of a page subtree plus the
            // (merkle_height - PAGE_BITS) nodes on the path from a page to the root.
            let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
            // SAFETY: merkle_tree_idx is guaranteed to be in bounds
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2;
                *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2;
            }
        }
        self.page_access_count = 0;

        for address_space in 0..self.addr_space_access_count.len() {
            // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
            let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) };
            if x > 0 {
                // After finalize, we'll need to read it in chunk-sized units for the merkle chip
                self.update_adapter_heights_batch(
                    trace_heights,
                    address_space as u32,
                    self.chunk_bits,
                    (x << PAGE_BITS) as u32,
                );
                // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
                unsafe {
                    *self
                        .addr_space_access_count
                        .get_unchecked_mut(address_space) = 0;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_bitset_insert_range() {
        // 513 bits
        let mut bit_set = BitSet::new(8 * 64 + 1);
        let num_flips = bit_set.insert_range(2, 29);
        assert_eq!(num_flips, 27);
        let num_flips = bit_set.insert_range(1, 31);
        assert_eq!(num_flips, 3);

        let num_flips = bit_set.insert_range(32, 65);
        assert_eq!(num_flips, 33);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 3);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 0);

        let num_flips = bit_set.insert_range(256, 320);
        assert_eq!(num_flips, 64);
        let num_flips = bit_set.insert_range(256, 377);
        assert_eq!(num_flips, 57);
        let num_flips = bit_set.insert_range(100, 513);
        assert_eq!(num_flips, 413 - 121);
    }
318    }
319}