openvm_circuit/arch/execution_mode/metered/memory_ctx.rs

use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS};

use crate::{arch::SystemConfig, system::memory::dimensions::MemoryDimensions};

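/// A fixed-capacity bit set backed by a boxed slice of `u64` words.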
#[derive(Clone, Debug)]
pub struct BitSet {
    words: Box<[u64]>,
}

impl BitSet {
    pub fn new(num_bits: usize) -> Self {
        Self {
            words: vec![0; num_bits.div_ceil(u64::BITS as usize)].into_boxed_slice(),
        }
    }

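    /// Sets the bit at `index`. Returns `true` if the bit was previously unset.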
    #[inline(always)]
    pub fn insert(&mut self, index: usize) -> bool {
        let word_index = index >> 6;
        let bit_index = index & 63;
        let mask = 1u64 << bit_index;

        debug_assert!(word_index < self.words.len(), "BitSet index out of bounds");

        let word = unsafe { self.words.get_unchecked_mut(word_index) };
        let was_set = (*word & mask) != 0;
        *word |= mask;
        !was_set
    }

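    /// Sets every bit in the half-open range `[start, end)`. Returns the number
    /// of bits that were newly set (i.e. flipped from 0 to 1).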
    #[inline(always)]
    pub fn insert_range(&mut self, start: usize, end: usize) -> usize {
        debug_assert!(start < end);
        debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds");

        let mut ret = 0;
        let start_word_index = start >> 6;
        let end_word_index = (end - 1) >> 6;
        let start_bit = (start & 63) as u32;

        if start_word_index == end_word_index {
            // The whole range lies within a single word.
            let end_bit = ((end - 1) & 63) as u32 + 1;
            let mask_bits = end_bit - start_bit;
            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
            let word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*word & mask).count_ones();
            *word |= mask;
        } else {
            // Partial first word: bits from `start_bit` to the end of the word.
            let end_bit = (end & 63) as u32;
            let mask_bits = 64 - start_bit;
            let mask = u64::MAX << start_bit;
            let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*start_word & mask).count_ones();
            *start_word |= mask;

            // Partial last word: bits below `end_bit` (empty when `end` is word-aligned).
            let mask_bits = end_bit;
            let mask = if end_bit == 0 {
                0
            } else {
                u64::MAX >> (64 - end_bit)
            };
            let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) };
            ret += mask_bits - (*end_word & mask).count_ones();
            *end_word |= mask;
        }

        // Full words strictly between the first and last are set wholesale
        // (this range is empty when the two partial words are adjacent or equal).
        for i in (start_word_index + 1)..end_word_index {
            let word = unsafe { self.words.get_unchecked_mut(i) };
            ret += word.count_zeros();
            *word = u64::MAX;
        }
        ret as usize
    }

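    /// Resets all bits to zero.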
    #[inline(always)]
    pub fn clear(&mut self) {
        unsafe {
            std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len());
        }
    }
}

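/// Tracks which memory pages are touched during metered execution and
/// accumulates the resulting boundary, Merkle, Poseidon2, and access-adapter
/// trace-height contributions.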
#[derive(Clone, Debug)]
pub struct MemoryCtx<const PAGE_BITS: usize> {
    pub page_indices: BitSet,
    memory_dimensions: MemoryDimensions,
    min_block_size_bits: Vec<u8>,
    pub boundary_idx: usize,
    pub merkle_tree_index: Option<usize>,
    pub adapter_offset: usize,
    continuations_enabled: bool,
    chunk: u32,
    chunk_bits: u32,
    page_access_count: usize,
    addr_space_access_count: Vec<usize>,
}

impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
    pub fn new(config: &SystemConfig) -> Self {
        let chunk = config.initial_block_size() as u32;
        let chunk_bits = chunk.ilog2();

        let memory_dimensions = config.memory_config.memory_dimensions();
        let merkle_height = memory_dimensions.overall_height();

        Self {
            page_indices: BitSet::new(1 << (merkle_height.saturating_sub(PAGE_BITS))),
            min_block_size_bits: config.memory_config.min_block_size_bits(),
            boundary_idx: config.memory_boundary_air_id(),
            merkle_tree_index: config.memory_merkle_air_id(),
            adapter_offset: config.access_adapter_air_id_offset(),
            chunk,
            chunk_bits,
            memory_dimensions,
            continuations_enabled: config.continuation_enabled,
            page_access_count: 0,
            addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1],
        }
    }

    #[inline(always)]
    pub fn clear(&mut self) {
        self.page_indices.clear();
    }

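    /// Marks the entire RV32 register file as accessed when continuations are
    /// enabled.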
    #[inline(always)]
    pub(crate) fn add_register_merkle_heights(&mut self) {
        if self.continuations_enabled {
            self.update_boundary_merkle_heights(
                RV32_REGISTER_AS,
                0,
                (RV32_NUM_REGISTERS * RV32_REGISTER_NUM_LIMBS) as u32,
            );
        }
    }

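    /// Marks the pages touched by an access of `size` cells at `ptr` in
    /// `address_space`; each newly touched page is counted toward the deferred
    /// boundary and Merkle height updates.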
    #[inline(always)]
    pub(crate) fn update_boundary_merkle_heights(
        &mut self,
        address_space: u32,
        ptr: u32,
        size: u32,
    ) {
        debug_assert!((address_space as usize) < self.addr_space_access_count.len());

        let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
        let start_chunk_id = ptr >> self.chunk_bits;
        let start_block_id = if self.chunk == 1 {
            start_chunk_id
        } else {
            self.memory_dimensions
                .label_to_index((address_space, start_chunk_id)) as u32
        };
        let end_block_id = start_block_id + num_blocks;
        let start_page_id = start_block_id >> PAGE_BITS;
        let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;

        for page_id in start_page_id..end_page_id {
            if self.page_indices.insert(page_id as usize) {
                self.page_access_count += 1;
                unsafe {
                    *self
                        .addr_space_access_count
                        .get_unchecked_mut(address_space as usize) += 1;
                }
            }
        }
    }

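    /// Records a single memory access of size `2^size_bits` against the
    /// access-adapter heights.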
    #[inline(always)]
    pub fn update_adapter_heights(
        &mut self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
    ) {
        self.update_adapter_heights_batch(trace_heights, address_space, size_bits, 1);
    }

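    /// Records `num` memory accesses of size `2^size_bits` in `address_space`,
    /// adding rows for every access-adapter level between the address space's
    /// minimum block size and `size_bits`.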
    #[inline(always)]
    pub fn update_adapter_heights_batch(
        &self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
        num: u32,
    ) {
        debug_assert!((address_space as usize) < self.min_block_size_bits.len());

        let align_bits = unsafe {
            *self
                .min_block_size_bits
                .get_unchecked(address_space as usize)
        };
        debug_assert!(
            align_bits as u32 <= size_bits,
            "align_bits ({}) must be <= size_bits ({})",
            align_bits,
            size_bits
        );

        for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
            let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
            debug_assert!(adapter_idx < trace_heights.len());
            unsafe {
                *trace_heights.get_unchecked_mut(adapter_idx) +=
                    num << (size_bits - adapter_bits + 1);
            }
        }
    }

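    /// Flushes the pending page-access counts into the boundary, Merkle,
    /// Poseidon2, and access-adapter trace heights, then resets the counters.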
    #[inline(always)]
    pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
        debug_assert!(self.boundary_idx < trace_heights.len());

        let leaves = (self.page_access_count << PAGE_BITS) as u32;
        unsafe {
            *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
        }

        if let Some(merkle_tree_idx) = self.merkle_tree_index {
            debug_assert!(merkle_tree_idx < trace_heights.len());
            debug_assert!(trace_heights.len() >= 2);

            // The Poseidon2 chip is taken to be the second-to-last AIR.
            let poseidon2_idx = trace_heights.len() - 2;
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2;
            }

            let merkle_height = self.memory_dimensions.overall_height();
            let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2;
                *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2;
            }
        }
        self.page_access_count = 0;

        for address_space in 0..self.addr_space_access_count.len() {
            let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) };
            if x > 0 {
                self.update_adapter_heights_batch(
                    trace_heights,
                    address_space as u32,
                    self.chunk_bits,
                    (x << PAGE_BITS) as u32,
                );
                unsafe {
                    *self
                        .addr_space_access_count
                        .get_unchecked_mut(address_space) = 0;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitset_insert_range() {
        let mut bit_set = BitSet::new(8 * 64 + 1);

        let num_flips = bit_set.insert_range(2, 29);
        assert_eq!(num_flips, 27);
        let num_flips = bit_set.insert_range(1, 31);
        assert_eq!(num_flips, 3);

        let num_flips = bit_set.insert_range(32, 65);
        assert_eq!(num_flips, 33);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 3);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 0);

        let num_flips = bit_set.insert_range(256, 320);
        assert_eq!(num_flips, 64);
        let num_flips = bit_set.insert_range(256, 377);
        assert_eq!(num_flips, 57);
        let num_flips = bit_set.insert_range(100, 513);
        assert_eq!(num_flips, 413 - 121);
    }
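
    // A minimal sketch of the single-bit API: `insert` reports whether the bit
    // was newly set, and `clear` resets every bit.
    #[test]
    fn test_bitset_insert_and_clear() {
        let mut bit_set = BitSet::new(64);
        // First insertion of a bit reports it as newly set.
        assert!(bit_set.insert(3));
        // Re-inserting the same bit reports no change.
        assert!(!bit_set.insert(3));
        bit_set.clear();
        // After clearing, the bit can be newly set again.
        assert!(bit_set.insert(3));
    }
}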
319}