openvm_circuit/arch/execution_mode/metered/memory_ctx.rs

use abi_stable::std_types::RVec;
use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS};

use crate::{arch::SystemConfig, system::memory::dimensions::MemoryDimensions};

/// Maximum number of memory pages that a single instruction may access.
pub const MAX_MEM_PAGE_OPS_PER_INSN: usize = 1 << 16;

#[derive(Clone, Debug)]
pub struct BitSet {
    words: Box<[u64]>,
}

impl BitSet {
    pub fn new(num_bits: usize) -> Self {
        Self {
            words: vec![0; num_bits.div_ceil(u64::BITS as usize)].into_boxed_slice(),
        }
    }

    /// Sets the bit at `index`. Returns `true` if the bit was not already set.
    #[inline(always)]
    pub fn insert(&mut self, index: usize) -> bool {
        let word_index = index >> 6;
        let bit_index = index & 63;
        let mask = 1u64 << bit_index;

        debug_assert!(word_index < self.words.len(), "BitSet index out of bounds");

        // SAFETY: `word_index` is in bounds per the debug_assert above; callers must
        // pass an index smaller than the capacity given to `new`.
        let word = unsafe { self.words.get_unchecked_mut(word_index) };
        let was_set = (*word & mask) != 0;
        *word |= mask;
        !was_set
    }

    /// Sets all bits in the range `[start, end)`. Returns the number of bits that
    /// were newly set.
    #[inline(always)]
    pub fn insert_range(&mut self, start: usize, end: usize) -> usize {
        debug_assert!(start < end);
        debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds");

        let mut ret = 0;
        let start_word_index = start >> 6;
        let end_word_index = (end - 1) >> 6;
        let start_bit = (start & 63) as u32;

        if start_word_index == end_word_index {
            // The whole range lies within a single word.
            let end_bit = ((end - 1) & 63) as u32 + 1;
            let mask_bits = end_bit - start_bit;
            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
            // SAFETY: `start_word_index < self.words.len()` per the debug_assert on `end`.
            let word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*word & mask).count_ones();
            *word |= mask;
        } else {
            // The range spans multiple words: set the partial first and last words here;
            // fully covered words in between are handled by the loop below.
            let end_bit = (end & 63) as u32;
            let mask_bits = 64 - start_bit;
            let mask = u64::MAX << start_bit;
            // SAFETY: `start_word_index < end_word_index < self.words.len()`.
            let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) };
            ret += mask_bits - (*start_word & mask).count_ones();
            *start_word |= mask;

            // When `end` falls exactly on a word boundary (`end_bit == 0`), the last
            // word of the range is fully covered.
            let mask_bits = if end_bit == 0 { 64 } else { end_bit };
            let mask = if end_bit == 0 {
                u64::MAX
            } else {
                u64::MAX >> (64 - end_bit)
            };
            // SAFETY: `end_word_index < self.words.len()` per the debug_assert on `end`.
            let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) };
            ret += mask_bits - (*end_word & mask).count_ones();
            *end_word |= mask;
        }

        if start_word_index + 1 < end_word_index {
            for i in (start_word_index + 1)..end_word_index {
                // SAFETY: `i < end_word_index < self.words.len()`.
                let word = unsafe { self.words.get_unchecked_mut(i) };
                ret += word.count_zeros();
                *word = u64::MAX;
            }
        }
        ret as usize
    }

    /// Clears all bits.
    #[inline(always)]
    pub fn clear(&mut self) {
        // SAFETY: writes zeros over exactly `self.words.len()` elements of the owned
        // buffer; `u64` is valid when zero-initialized.
        unsafe {
            std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len());
        }
    }
}

#[derive(Clone, Debug)]
pub struct MemoryCtx<const PAGE_BITS: usize> {
    memory_dimensions: MemoryDimensions,
    min_block_size_bits: Vec<u8>,
    pub boundary_idx: usize,
    pub merkle_tree_index: Option<usize>,
    pub adapter_offset: usize,
    continuations_enabled: bool,
    chunk: u32,
    chunk_bits: u32,
    pub page_indices: BitSet,
    pub addr_space_access_count: RVec<u32>,
    pub page_indices_since_checkpoint: Box<[u32]>,
    pub page_indices_since_checkpoint_len: usize,
}

impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
    pub fn new(config: &SystemConfig, segment_check_insns: u64) -> Self {
        let chunk = config.initial_block_size() as u32;
        let chunk_bits = chunk.ilog2();

        let memory_dimensions = config.memory_config.memory_dimensions();
        let merkle_height = memory_dimensions.overall_height();

        let bitset_size = 1 << (merkle_height.saturating_sub(PAGE_BITS));
        let addr_space_size = (1 << memory_dimensions.addr_space_height) + 1;
        let page_indices_since_checkpoint_cap =
            Self::calculate_checkpoint_capacity(segment_check_insns);

        Self {
            min_block_size_bits: config.memory_config.min_block_size_bits(),
            boundary_idx: config.memory_boundary_air_id(),
            merkle_tree_index: config.memory_merkle_air_id(),
            adapter_offset: config.access_adapter_air_id_offset(),
            chunk,
            chunk_bits,
            memory_dimensions,
            continuations_enabled: config.continuation_enabled,
            page_indices: BitSet::new(bitset_size),
            addr_space_access_count: vec![0; addr_space_size].into(),
            page_indices_since_checkpoint: vec![0; page_indices_since_checkpoint_cap]
                .into_boxed_slice(),
            page_indices_since_checkpoint_len: 0,
        }
    }

    #[inline(always)]
    pub(super) fn calculate_checkpoint_capacity(segment_check_insns: u64) -> usize {
        segment_check_insns as usize * MAX_MEM_PAGE_OPS_PER_INSN
    }

    #[inline(always)]
    pub(crate) fn add_register_merkle_heights(&mut self) {
        if self.continuations_enabled {
            self.update_boundary_merkle_heights(
                RV32_REGISTER_AS,
                0,
                (RV32_NUM_REGISTERS * RV32_REGISTER_NUM_LIMBS) as u32,
            );
        }
    }

    #[inline(always)]
    pub(crate) fn update_boundary_merkle_heights(
        &mut self,
        address_space: u32,
        ptr: u32,
        size: u32,
    ) {
        debug_assert!((address_space as usize) < self.addr_space_access_count.len());

        let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
        let start_chunk_id = ptr >> self.chunk_bits;
        let start_block_id = if self.chunk == 1 {
            start_chunk_id
        } else {
            self.memory_dimensions
                .label_to_index((address_space, start_chunk_id)) as u32
        };
        let end_block_id = start_block_id + num_blocks;

        let start_page_id = start_block_id >> PAGE_BITS;
        let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
        assert!(
            self.page_indices_since_checkpoint_len + (end_page_id - start_page_id) as usize
                <= self.page_indices_since_checkpoint.len(),
            "more than {MAX_MEM_PAGE_OPS_PER_INSN} memory pages accessed in a single instruction"
        );

        for page_id in start_page_id..end_page_id {
            let len = self.page_indices_since_checkpoint_len;
            debug_assert!(len < self.page_indices_since_checkpoint.len());
            // SAFETY: `len` is in bounds per the assert and debug_assert above.
            unsafe {
                *self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id;
            }
            self.page_indices_since_checkpoint_len = len + 1;

            if self.page_indices.insert(page_id as usize) {
                // SAFETY: `address_space` is in bounds per the debug_assert at the top.
                unsafe {
                    *self
                        .addr_space_access_count
                        .get_unchecked_mut(address_space as usize) += 1;
                }
            }
        }
    }

    #[inline(always)]
    pub fn update_adapter_heights(
        &mut self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
    ) {
        self.update_adapter_heights_batch(trace_heights, address_space, size_bits, 1);
    }

    #[inline(always)]
    pub fn update_adapter_heights_batch(
        &self,
        trace_heights: &mut [u32],
        address_space: u32,
        size_bits: u32,
        num: u32,
    ) {
        debug_assert!((address_space as usize) < self.min_block_size_bits.len());

        // SAFETY: `address_space` is in bounds per the debug_assert above.
        let align_bits = unsafe {
            *self
                .min_block_size_bits
                .get_unchecked(address_space as usize)
        };
        debug_assert!(
            align_bits as u32 <= size_bits,
            "align_bits ({align_bits}) must be <= size_bits ({size_bits})"
        );

        for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
            let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
            debug_assert!(adapter_idx < trace_heights.len());
            // SAFETY: `adapter_idx` is in bounds per the debug_assert above.
            unsafe {
                *trace_heights.get_unchecked_mut(adapter_idx) +=
                    num << (size_bits - adapter_bits + 1);
            }
        }
    }

    #[inline(always)]
    pub(crate) fn initialize_segment(&mut self, trace_heights: &mut [u32]) {
        self.page_indices.clear();

        // SAFETY: `boundary_idx`, the merkle tree index, and the Poseidon2 index are
        // valid AIR indices into `trace_heights` by construction.
        unsafe {
            *trace_heights.get_unchecked_mut(self.boundary_idx) = 0;
        }
        if let Some(merkle_tree_idx) = self.merkle_tree_index {
            unsafe {
                *trace_heights.get_unchecked_mut(merkle_tree_idx) = 0;
            }
            let poseidon2_idx = trace_heights.len() - 2;
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) = 0;
            }
        }

        // Replay the pages touched since the last checkpoint to rebuild the
        // per-address-space access counts for the new segment.
        let mut addr_space_access_count = vec![0; self.addr_space_access_count.len()];
        let pages_len = self.page_indices_since_checkpoint_len;
        for i in 0..pages_len {
            // SAFETY: `i < pages_len <= self.page_indices_since_checkpoint.len()`.
            let page_id = unsafe { *self.page_indices_since_checkpoint.get_unchecked(i) } as usize;
            if self.page_indices.insert(page_id) {
                let (addr_space, _) = self
                    .memory_dimensions
                    .index_to_label((page_id as u64) << PAGE_BITS);
                let addr_space_idx = addr_space as usize;
                debug_assert!(addr_space_idx < addr_space_access_count.len());
                // SAFETY: `addr_space_idx` is in bounds per the debug_assert above.
                unsafe {
                    *addr_space_access_count.get_unchecked_mut(addr_space_idx) += 1;
                }
            }
        }
        self.apply_height_updates(trace_heights, &addr_space_access_count);

        self.add_register_merkle_heights();
        self.lazy_update_boundary_heights(trace_heights);
    }

    /// Starts a new checkpoint by clearing the list of pages recorded since the
    /// previous one.
    #[inline(always)]
    pub(crate) fn update_checkpoint(&mut self) {
        self.page_indices_since_checkpoint_len = 0;
    }

    #[inline(always)]
    fn apply_height_updates(&self, trace_heights: &mut [u32], addr_space_access_count: &[u32]) {
        let page_access_count: u32 = addr_space_access_count.iter().sum();

        let leaves = page_access_count << PAGE_BITS;
        // SAFETY: `boundary_idx` is a valid AIR index into `trace_heights` by construction.
        unsafe {
            *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
        }

        if let Some(merkle_tree_idx) = self.merkle_tree_index {
            debug_assert!(merkle_tree_idx < trace_heights.len());
            debug_assert!(trace_heights.len() >= 2);

            let poseidon2_idx = trace_heights.len() - 2;
            // SAFETY: `poseidon2_idx` and `merkle_tree_idx` are in bounds per the
            // debug_asserts above.
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2;
            }

            let merkle_height = self.memory_dimensions.overall_height();
            let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
            unsafe {
                *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * page_access_count * 2;
                *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * page_access_count * 2;
            }
        }

        for address_space in 0..addr_space_access_count.len() {
            // SAFETY: `address_space` iterates over valid indices of the slice.
            let x = unsafe { *addr_space_access_count.get_unchecked(address_space) };
            if x > 0 {
                self.update_adapter_heights_batch(
                    trace_heights,
                    address_space as u32,
                    self.chunk_bits,
                    x << (PAGE_BITS + 1),
                );
            }
        }
    }

    #[inline(always)]
    pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
        self.apply_height_updates(trace_heights, &self.addr_space_access_count);
        // SAFETY: zeroes exactly `self.addr_space_access_count.len()` elements of the
        // owned buffer; `u32` is valid when zero-initialized.
        unsafe {
            std::ptr::write_bytes(
                self.addr_space_access_count.as_mut_ptr(),
                0,
                self.addr_space_access_count.len(),
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_bitset_insert_range() {
        let mut bit_set = BitSet::new(8 * 64 + 1);

        let num_flips = bit_set.insert_range(2, 29);
        assert_eq!(num_flips, 27);
        let num_flips = bit_set.insert_range(1, 31);
        assert_eq!(num_flips, 3);

        let num_flips = bit_set.insert_range(32, 65);
        assert_eq!(num_flips, 33);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 3);
        let num_flips = bit_set.insert_range(0, 66);
        assert_eq!(num_flips, 0);

        let num_flips = bit_set.insert_range(256, 320);
        assert_eq!(num_flips, 64);
        let num_flips = bit_set.insert_range(256, 377);
        assert_eq!(num_flips, 57);
        let num_flips = bit_set.insert_range(100, 513);
        assert_eq!(num_flips, 413 - 121);
    }
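
    // Sketch of additional coverage (not part of the original module): single-bit
    // `insert`, `clear`, and an `insert_range` that ends exactly on a word boundary.
    #[test]
    fn test_bitset_insert_and_clear() {
        let mut bit_set = BitSet::new(2 * 64);
        assert!(bit_set.insert(3));
        assert!(!bit_set.insert(3));
        assert!(bit_set.insert(64));

        // Range [10, 128) spans 118 bits; bit 64 is already set and bit 3 lies outside it.
        let num_flips = bit_set.insert_range(10, 128);
        assert_eq!(num_flips, 117);

        bit_set.clear();
        assert!(bit_set.insert(3));
        assert!(bit_set.insert(64));
    }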
}