openvm_circuit/system/memory/adapter/records.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::{align_of, size_of},
4};
5
6use openvm_circuit_primitives::AlignedBytesBorrow;
7
8use crate::arch::{CustomBorrow, DenseRecordArena, RecordArena, SizedRecord};
9
10#[repr(C)]
11#[derive(Debug, Clone, Copy, AlignedBytesBorrow, PartialEq, Eq, PartialOrd, Ord)]
12pub struct AccessRecordHeader {
13    /// Iff we need to merge before, this has the `MERGE_AND_NOT_SPLIT_FLAG` bit set
14    pub timestamp_and_mask: u32,
15    pub address_space: u32,
16    pub pointer: u32,
17    // PERF: these three are easily mergeable into a single u32
18    pub block_size: u32,
19    pub lowest_block_size: u32,
20    pub type_size: u32,
21}
22
23#[repr(C)]
24#[derive(Debug)]
25pub struct AccessRecordMut<'a> {
26    pub header: &'a mut AccessRecordHeader,
27    // PERF(AG): optimize with some `Option` serialization stuff
28    pub timestamps: &'a mut [u32], // len is block_size / lowest_block_size
29    pub data: &'a mut [u8],        // len is block_size * type_size
30}
31
/// Shape of one access record: how many elements it covers, the finest granularity it may be
/// split into or merged from, and the width of each element in bytes.
#[derive(Debug, Clone)]
pub struct AccessLayout {
    /// Number of elements in the accessed block.
    pub block_size: usize,
    /// Smallest block size this access may be split into / merged from (usually 1 or 4).
    pub lowest_block_size: usize,
    /// Width of a single element in bytes (1 for `u8`, 4 for `F`).
    pub type_size: usize,
}
41
42impl AccessLayout {
43    pub(crate) fn from_record_header(header: &AccessRecordHeader) -> Self {
44        Self {
45            block_size: header.block_size as usize,
46            lowest_block_size: header.lowest_block_size as usize,
47            type_size: header.type_size as usize,
48        }
49    }
50}
51
/// Top bit (bit 31) of `AccessRecordHeader::timestamp_and_mask`: set iff the access needs a
/// merge before it, as opposed to a split.
pub(crate) const MERGE_AND_NOT_SPLIT_FLAG: u32 = 0x8000_0000;
53
54pub(crate) fn size_by_layout(layout: &AccessLayout) -> usize {
55    size_of::<AccessRecordHeader>() // header struct
56    + (layout.block_size / layout.lowest_block_size) * size_of::<u32>() // timestamps
57    + (layout.block_size * layout.type_size).next_multiple_of(4) // data
58}
59
60impl SizedRecord<AccessLayout> for AccessRecordMut<'_> {
61    fn size(layout: &AccessLayout) -> usize {
62        size_by_layout(layout)
63    }
64
65    fn alignment(_: &AccessLayout) -> usize {
66        align_of::<AccessRecordHeader>()
67    }
68}
69
70impl<'a> CustomBorrow<'a, AccessRecordMut<'a>, AccessLayout> for [u8] {
71    fn custom_borrow(&'a mut self, layout: AccessLayout) -> AccessRecordMut<'a> {
72        // header: AccessRecordHeader
73        // SAFETY: self.len() >= size_of::<AccessRecordHeader>() from size_by_layout()
74        let (header_buf, rest) =
75            unsafe { self.split_at_mut_unchecked(size_of::<AccessRecordHeader>()) };
76        let header = header_buf.borrow_mut();
77
78        let mut offset = 0;
79
80        // timestamps: [u32] (block_size / cell_size * 4 bytes)
81        // SAFETY:
82        // - size: (layout.block_size / layout.lowest_block_size) * size_of::<u32>() from
83        //   size_by_layout()
84        // - alignment: u32 aligned due to AccessRecordHeader alignment
85        let timestamps = unsafe {
86            std::slice::from_raw_parts_mut(
87                rest.as_mut_ptr().add(offset) as *mut u32,
88                layout.block_size / layout.lowest_block_size,
89            )
90        };
91        offset += layout.block_size / layout.lowest_block_size * size_of::<u32>();
92
93        // data: [u8] (block_size * type_size bytes)
94        // SAFETY:
95        // - size: layout.block_size * layout.type_size from size_by_layout()
96        // - offset points past timestamps section
97        let data = unsafe {
98            std::slice::from_raw_parts_mut(
99                rest.as_mut_ptr().add(offset),
100                layout.block_size * layout.type_size,
101            )
102        };
103
104        AccessRecordMut {
105            header,
106            data,
107            timestamps,
108        }
109    }
110
111    unsafe fn extract_layout(&self) -> AccessLayout {
112        let header: &AccessRecordHeader = self.borrow();
113        AccessLayout {
114            block_size: header.block_size as usize,
115            lowest_block_size: header.lowest_block_size as usize,
116            type_size: header.type_size as usize,
117        }
118    }
119}
120
121impl<'a> RecordArena<'a, AccessLayout, AccessRecordMut<'a>> for DenseRecordArena {
122    fn alloc(&'a mut self, layout: AccessLayout) -> AccessRecordMut<'a> {
123        let bytes = self.alloc_bytes(<AccessRecordMut<'a> as SizedRecord<AccessLayout>>::size(
124            &layout,
125        ));
126        <[u8] as CustomBorrow<AccessRecordMut<'a>, AccessLayout>>::custom_borrow(bytes, layout)
127    }
128}
129
130/// `trace_heights[i]` is assumed to correspond to `Adapter< 2^(i+1) >`.
131pub fn arena_size_bound(trace_heights: &[u32]) -> usize {
132    // At the very worst, each row in `Adapter<N>`
133    // corresponds to a unique record of `block_size` being `2 * N`,
134    // and its `lowest_block_size` is at least 1 and `type_size` is at most 4.
135    let size_bound = trace_heights
136        .iter()
137        .enumerate()
138        .map(|(i, &h)| {
139            size_by_layout(&AccessLayout {
140                block_size: 1 << (i + 1),
141                lowest_block_size: 1,
142                type_size: 4,
143            }) * h as usize
144        })
145        .sum::<usize>();
146    tracing::debug!(
147        "Allocating {} bytes for memory adapters arena from heights {:?}",
148        size_bound,
149        trace_heights
150    );
151    size_bound
152}