openvm_circuit/system/memory/paged_vec.rs

1use std::{mem::MaybeUninit, ops::Range, ptr};
2
3use serde::{Deserialize, Serialize};
4
5use crate::arch::MemoryConfig;
6
/// A memory location, written as `(address_space, pointer)`.
pub type Address = (u32, u32);
/// Page size of 2^12 = 4096 elements. The paged containers below take their
/// own `PAGE_SIZE` const parameter, which shadows this constant inside them.
pub const PAGE_SIZE: usize = 4096;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PagedVec<T, const PAGE_SIZE: usize> {
13    pub pages: Vec<Option<Vec<T>>>,
14}
15
16// ------------------------------------------------------------------
17// Common Helper Functions
18// These functions encapsulate the common logic for copying ranges
19// across pages, both for read-only and read-write (set) cases.
20impl<T: Default + Clone, const PAGE_SIZE: usize> PagedVec<T, PAGE_SIZE> {
21    // Copies a range of length `len` starting at index `start`
22    // into the memory pointed to by `dst`. If the relevant page is not
23    // initialized, fills that portion with T::default().
24    fn read_range_generic(&self, start: usize, len: usize, dst: *mut T) {
25        let start_page = start / PAGE_SIZE;
26        let end_page = (start + len - 1) / PAGE_SIZE;
27        unsafe {
28            if start_page == end_page {
29                let offset = start % PAGE_SIZE;
30                if let Some(page) = self.pages[start_page].as_ref() {
31                    ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len);
32                } else {
33                    std::slice::from_raw_parts_mut(dst, len).fill(T::default());
34                }
35            } else {
36                let offset = start % PAGE_SIZE;
37                let first_part = PAGE_SIZE - offset;
38                if let Some(page) = self.pages[start_page].as_ref() {
39                    ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part);
40                } else {
41                    std::slice::from_raw_parts_mut(dst, first_part).fill(T::default());
42                }
43                let second_part = len - first_part;
44                if let Some(page) = self.pages[end_page].as_ref() {
45                    ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part);
46                } else {
47                    std::slice::from_raw_parts_mut(dst.add(first_part), second_part)
48                        .fill(T::default());
49                }
50            }
51        }
52    }
53
54    // Updates a range of length `len` starting at index `start` with new values.
55    // It copies the current values into the memory pointed to by `dst`
56    // and then writes the new values into the underlying pages,
57    // allocating pages (with defaults) if necessary.
58    fn set_range_generic(&mut self, start: usize, len: usize, new: *const T, dst: *mut T) {
59        let start_page = start / PAGE_SIZE;
60        let end_page = (start + len - 1) / PAGE_SIZE;
61        unsafe {
62            if start_page == end_page {
63                let offset = start % PAGE_SIZE;
64                let page =
65                    self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]);
66                ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len);
67                ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), len);
68            } else {
69                let offset = start % PAGE_SIZE;
70                let first_part = PAGE_SIZE - offset;
71                {
72                    let page =
73                        self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]);
74                    ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part);
75                    ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), first_part);
76                }
77                let second_part = len - first_part;
78                {
79                    let page =
80                        self.pages[end_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]);
81                    ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part);
82                    ptr::copy_nonoverlapping(new.add(first_part), page.as_mut_ptr(), second_part);
83                }
84            }
85        }
86    }
87}
88
89// ------------------------------------------------------------------
90// Implementation for types requiring Default + Clone
91impl<T: Default + Clone, const PAGE_SIZE: usize> PagedVec<T, PAGE_SIZE> {
92    pub fn new(num_pages: usize) -> Self {
93        Self {
94            pages: vec![None; num_pages],
95        }
96    }
97
98    pub fn get(&self, index: usize) -> Option<&T> {
99        let page_idx = index / PAGE_SIZE;
100        self.pages[page_idx]
101            .as_ref()
102            .map(|page| &page[index % PAGE_SIZE])
103    }
104
105    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
106        let page_idx = index / PAGE_SIZE;
107        self.pages[page_idx]
108            .as_mut()
109            .map(|page| &mut page[index % PAGE_SIZE])
110    }
111
112    pub fn set(&mut self, index: usize, value: T) -> Option<T> {
113        let page_idx = index / PAGE_SIZE;
114        if let Some(page) = self.pages[page_idx].as_mut() {
115            Some(std::mem::replace(&mut page[index % PAGE_SIZE], value))
116        } else {
117            let page = self.pages[page_idx].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]);
118            page[index % PAGE_SIZE] = value;
119            None
120        }
121    }
122
123    #[inline(always)]
124    pub fn range_vec(&self, range: Range<usize>) -> Vec<T> {
125        let len = range.end - range.start;
126        // Create a vector for uninitialized values.
127        let mut result: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
128        // SAFETY: We set the length and then initialize every element via read_range_generic.
129        unsafe {
130            result.set_len(len);
131            self.read_range_generic(range.start, len, result.as_mut_ptr() as *mut T);
132            std::mem::transmute::<Vec<MaybeUninit<T>>, Vec<T>>(result)
133        }
134    }
135
136    pub fn set_range(&mut self, range: Range<usize>, values: &[T]) -> Vec<T> {
137        let len = range.end - range.start;
138        assert_eq!(values.len(), len);
139        let mut result: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
140        // SAFETY: We will write to every element in result via set_range_generic.
141        unsafe {
142            result.set_len(len);
143            self.set_range_generic(
144                range.start,
145                len,
146                values.as_ptr(),
147                result.as_mut_ptr() as *mut T,
148            );
149            std::mem::transmute::<Vec<MaybeUninit<T>>, Vec<T>>(result)
150        }
151    }
152
153    pub fn memory_size(&self) -> usize {
154        self.pages.len() * PAGE_SIZE
155    }
156
157    pub fn is_empty(&self) -> bool {
158        self.pages.iter().all(|page| page.is_none())
159    }
160}
161
162// ------------------------------------------------------------------
163// Implementation for types requiring Default + Copy
164impl<T: Default + Copy, const PAGE_SIZE: usize> PagedVec<T, PAGE_SIZE> {
165    #[inline(always)]
166    pub fn range_array<const N: usize>(&self, from: usize) -> [T; N] {
167        // Create an uninitialized array of MaybeUninit<T>
168        let mut result: [MaybeUninit<T>; N] = unsafe {
169            // SAFETY: An uninitialized `[MaybeUninit<T>; N]` is valid.
170            MaybeUninit::uninit().assume_init()
171        };
172        self.read_range_generic(from, N, result.as_mut_ptr() as *mut T);
173        // SAFETY: All elements have been initialized.
174        unsafe { ptr::read(&result as *const _ as *const [T; N]) }
175    }
176
177    #[inline(always)]
178    pub fn set_range_array<const N: usize>(&mut self, from: usize, values: &[T; N]) -> [T; N] {
179        // Create an uninitialized array for old values.
180        let mut result: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };
181        self.set_range_generic(from, N, values.as_ptr(), result.as_mut_ptr() as *mut T);
182        unsafe { ptr::read(&result as *const _ as *const [T; N]) }
183    }
184}
185
186impl<T, const PAGE_SIZE: usize> PagedVec<T, PAGE_SIZE> {
187    pub fn iter(&self) -> PagedVecIter<'_, T, PAGE_SIZE> {
188        PagedVecIter {
189            vec: self,
190            current_page: 0,
191            current_index_in_page: 0,
192        }
193    }
194}
195
196pub struct PagedVecIter<'a, T, const PAGE_SIZE: usize> {
197    vec: &'a PagedVec<T, PAGE_SIZE>,
198    current_page: usize,
199    current_index_in_page: usize,
200}
201
202impl<T: Clone, const PAGE_SIZE: usize> Iterator for PagedVecIter<'_, T, PAGE_SIZE> {
203    type Item = (usize, T);
204
205    fn next(&mut self) -> Option<Self::Item> {
206        while self.current_page < self.vec.pages.len()
207            && self.vec.pages[self.current_page].is_none()
208        {
209            self.current_page += 1;
210            debug_assert_eq!(self.current_index_in_page, 0);
211            self.current_index_in_page = 0;
212        }
213        if self.current_page >= self.vec.pages.len() {
214            return None;
215        }
216        let global_index = self.current_page * PAGE_SIZE + self.current_index_in_page;
217
218        let page = self.vec.pages[self.current_page].as_ref()?;
219        let value = page[self.current_index_in_page].clone();
220
221        self.current_index_in_page += 1;
222        if self.current_index_in_page == PAGE_SIZE {
223            self.current_page += 1;
224            self.current_index_in_page = 0;
225        }
226        Some((global_index, value))
227    }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct AddressMap<T, const PAGE_SIZE: usize> {
232    pub paged_vecs: Vec<PagedVec<T, PAGE_SIZE>>,
233    pub as_offset: u32,
234}
235
236impl<T: Clone + Default, const PAGE_SIZE: usize> Default for AddressMap<T, PAGE_SIZE> {
237    fn default() -> Self {
238        Self::from_mem_config(&MemoryConfig::default())
239    }
240}
241
242impl<T: Clone + Default, const PAGE_SIZE: usize> AddressMap<T, PAGE_SIZE> {
243    pub fn new(as_offset: u32, as_cnt: usize, mem_size: usize) -> Self {
244        Self {
245            paged_vecs: vec![PagedVec::new(mem_size.div_ceil(PAGE_SIZE)); as_cnt],
246            as_offset,
247        }
248    }
249    pub fn from_mem_config(mem_config: &MemoryConfig) -> Self {
250        Self::new(
251            mem_config.as_offset,
252            1 << mem_config.as_height,
253            1 << mem_config.pointer_max_bits,
254        )
255    }
256    pub fn items(&self) -> impl Iterator<Item = (Address, T)> + '_ {
257        self.paged_vecs
258            .iter()
259            .enumerate()
260            .flat_map(move |(as_idx, page)| {
261                page.iter()
262                    .map(move |(ptr_idx, x)| ((as_idx as u32 + self.as_offset, ptr_idx as u32), x))
263            })
264    }
265    pub fn get(&self, address: &Address) -> Option<&T> {
266        self.paged_vecs[(address.0 - self.as_offset) as usize].get(address.1 as usize)
267    }
268    pub fn get_mut(&mut self, address: &Address) -> Option<&mut T> {
269        self.paged_vecs[(address.0 - self.as_offset) as usize].get_mut(address.1 as usize)
270    }
271    pub fn insert(&mut self, address: &Address, data: T) -> Option<T> {
272        self.paged_vecs[(address.0 - self.as_offset) as usize].set(address.1 as usize, data)
273    }
274    pub fn is_empty(&self) -> bool {
275        self.paged_vecs.iter().all(|page| page.is_empty())
276    }
277
278    pub fn from_iter(
279        as_offset: u32,
280        as_cnt: usize,
281        mem_size: usize,
282        iter: impl IntoIterator<Item = (Address, T)>,
283    ) -> Self {
284        let mut vec = Self::new(as_offset, as_cnt, mem_size);
285        for (address, data) in iter {
286            vec.insert(&address, data);
287        }
288        vec
289    }
290}
291
292impl<T: Copy + Default, const PAGE_SIZE: usize> AddressMap<T, PAGE_SIZE> {
293    pub fn get_range<const N: usize>(&self, address: &Address) -> [T; N] {
294        self.paged_vecs[(address.0 - self.as_offset) as usize].range_array(address.1 as usize)
295    }
296    pub fn set_range<const N: usize>(&mut self, address: &Address, values: &[T; N]) -> [T; N] {
297        self.paged_vecs[(address.0 - self.as_offset) as usize]
298            .set_range_array(address.1 as usize, values)
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_get_set() {
        let mut v = PagedVec::<_, 4>::new(3);
        assert_eq!(v.get(0), None);
        v.set(0, 42);
        assert_eq!(v.get(0), Some(&42));
    }

    #[test]
    fn test_cross_page_operations() {
        let mut v = PagedVec::<_, 4>::new(3);
        v.set(3, 10); // Last element of first page
        v.set(4, 20); // First element of second page
        assert_eq!(v.get(3), Some(&10));
        assert_eq!(v.get(4), Some(&20));
    }

    #[test]
    fn test_page_boundaries() {
        let mut v = PagedVec::<_, 4>::new(2);
        // Fill first page
        v.set(0, 1);
        v.set(1, 2);
        v.set(2, 3);
        v.set(3, 4);
        // Fill second page
        v.set(4, 5);
        v.set(5, 6);
        v.set(6, 7);
        v.set(7, 8);

        // Verify all values
        assert_eq!(v.range_vec(0..8), [1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_range_cross_page_boundary() {
        let mut v = PagedVec::<_, 4>::new(2);
        v.set_range(2..8, &[10, 11, 12, 13, 14, 15]);
        assert_eq!(v.range_vec(2..8), [10, 11, 12, 13, 14, 15]);
    }

    #[test]
    fn test_large_indices() {
        let mut v = PagedVec::<_, 4>::new(100);
        let large_index = 399;
        v.set(large_index, 42);
        assert_eq!(v.get(large_index), Some(&42));
    }

    #[test]
    fn test_range_operations_with_defaults() {
        let mut v = PagedVec::<_, 4>::new(3);
        v.set(2, 5);
        v.set(5, 10);

        // Should include both set values and defaults
        assert_eq!(v.range_vec(1..7), [0, 5, 0, 0, 10, 0]);
    }

    #[test]
    fn test_non_zero_default_type() {
        let mut v: PagedVec<bool, 4> = PagedVec::new(2);
        assert_eq!(v.get(0), None); // page 0 is not allocated yet
        v.set(0, true);
        assert_eq!(v.get(0), Some(&true));
        assert_eq!(v.get(1), Some(&false)); // because we created the page
    }

    #[test]
    fn test_set_range_overlapping_pages() {
        let mut v = PagedVec::<_, 4>::new(3);
        let test_data = [1, 2, 3, 4, 5, 6];
        v.set_range(2..8, &test_data);

        // Verify first page
        assert_eq!(v.get(2), Some(&1));
        assert_eq!(v.get(3), Some(&2));

        // Verify second page
        assert_eq!(v.get(4), Some(&3));
        assert_eq!(v.get(5), Some(&4));
        assert_eq!(v.get(6), Some(&5));
        assert_eq!(v.get(7), Some(&6));
    }

    #[test]
    fn test_overlapping_set_ranges() {
        let mut v = PagedVec::<_, 4>::new(3);

        // Initial set_range
        v.set_range(0..5, &[1, 2, 3, 4, 5]);
        assert_eq!(v.range_vec(0..5), [1, 2, 3, 4, 5]);

        // Overlap from beginning
        v.set_range(0..3, &[10, 20, 30]);
        assert_eq!(v.range_vec(0..5), [10, 20, 30, 4, 5]);

        // Overlap in middle
        v.set_range(2..4, &[42, 43]);
        assert_eq!(v.range_vec(0..5), [10, 20, 42, 43, 5]);

        // Overlap at end
        v.set_range(4..6, &[91, 92]);
        assert_eq!(v.range_vec(0..6), [10, 20, 42, 43, 91, 92]);
    }

    #[test]
    fn test_overlapping_set_ranges_cross_pages() {
        let mut v = PagedVec::<_, 4>::new(3);

        // Fill across first two pages
        v.set_range(0..8, &[1, 2, 3, 4, 5, 6, 7, 8]);

        // Overlap end of first page and start of second
        v.set_range(2..6, &[21, 22, 23, 24]);
        assert_eq!(v.range_vec(0..8), [1, 2, 21, 22, 23, 24, 7, 8]);

        // Overlap multiple pages
        v.set_range(1..7, &[31, 32, 33, 34, 35, 36]);
        assert_eq!(v.range_vec(0..8), [1, 31, 32, 33, 34, 35, 36, 8]);
    }

    #[test]
    fn test_iterator() {
        let mut v = PagedVec::<_, 4>::new(3);

        v.set_range(4..10, &[1, 2, 3, 4, 5, 6]);
        let contents: Vec<_> = v.iter().collect();
        assert_eq!(contents.len(), 8); // two pages

        contents
            .iter()
            .take(6)
            .enumerate()
            .for_each(|(i, &(idx, val))| {
                assert_eq!((idx, val), (4 + i, 1 + i));
            });
        assert_eq!(contents[6], (10, 0));
        assert_eq!(contents[7], (11, 0));
    }

    #[test]
    fn test_set_returns_previous_value() {
        let mut v = PagedVec::<u32, 4>::new(2);
        assert_eq!(v.set(5, 3), None); // page was unallocated
        assert_eq!(v.set(5, 4), Some(3));
        assert_eq!(v.set(6, 1), Some(0)); // page now allocated; slot held the default
        if let Some(x) = v.get_mut(6) {
            *x += 1;
        }
        assert_eq!(v.get(6), Some(&2));
        assert_eq!(v.memory_size(), 8);
        assert!(!v.is_empty());
    }

    #[test]
    fn test_address_map_round_trip() {
        let mut map = AddressMap::<u32, 4>::new(1, 2, 16);
        assert!(map.is_empty());
        assert_eq!(map.insert(&(1, 3), 7), None);
        assert_eq!(map.get(&(1, 3)), Some(&7));
        assert_eq!(map.get(&(1, 8)), None);

        assert_eq!(map.set_range(&(2, 3), &[1, 2]), [0, 0]);
        assert_eq!(map.get_range::<2>(&(2, 3)), [1, 2]);
        assert!(!map.is_empty());

        let items: Vec<_> = map.items().collect();
        // One allocated page in address space 1, two in address space 2.
        assert_eq!(items.len(), 12);
        assert!(items.contains(&((1, 3), 7)));
        assert!(items.contains(&((2, 4), 2)));
    }

    #[test]
    fn test_address_map_from_iter() {
        let map = AddressMap::<u32, 4>::from_iter(0, 1, 8, [((0, 2), 5), ((0, 2), 6)]);
        assert_eq!(map.get(&(0, 2)), Some(&6)); // later pairs overwrite earlier ones
        assert_eq!(map.get(&(0, 4)), None);
    }
}