openvm_circuit/system/memory/adapter/mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4    ptr::copy_nonoverlapping,
5    sync::Arc,
6};
7
8pub use air::*;
9pub use columns::*;
10use enum_dispatch::enum_dispatch;
11use getset::Setters;
12use openvm_circuit_primitives::{
13    is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero,
14    var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator,
15};
16use openvm_stark_backend::{
17    config::{Domain, StarkGenericConfig},
18    p3_air::BaseAir,
19    p3_commit::PolynomialSpace,
20    p3_field::PrimeField32,
21    p3_matrix::dense::RowMajorMatrix,
22    p3_util::log2_strict_usize,
23    prover::{cpu::CpuBackend, types::AirProvingContext},
24};
25
26use crate::{
27    arch::{
28        AddressSpaceHostConfig, AddressSpaceHostLayout, CustomBorrow, DenseRecordArena,
29        MemoryCellType, MemoryConfig, SizedRecord,
30    },
31    system::memory::{
32        adapter::records::{
33            arena_size_bound, AccessLayout, AccessRecordHeader, AccessRecordMut,
34            MERGE_AND_NOT_SPLIT_FLAG,
35        },
36        offline_checker::MemoryBus,
37        MemoryAddress,
38    },
39};
40
41mod air;
42mod columns;
43pub mod records;
44
/// Owns every access-adapter chip together with the arena holding their access records.
#[derive(Setters)]
pub struct AccessAdapterInventory<F> {
    /// Memory configuration the chips were built from; its `addr_spaces` supply the
    /// per-address-space cell layout used when filling trace rows.
    pub(super) memory_config: MemoryConfig,
    /// One chip per enabled adapter size, ordered by ascending `N` (2, 4, 8, ...).
    chips: Vec<GenericAccessAdapterChip<F>>,
    /// Serialized access records from which trace heights and trace rows are derived.
    #[getset(set = "pub")]
    arena: DenseRecordArena,
    /// Padded (power-of-two or zero) trace heights from the last proving-context generation.
    #[cfg(feature = "metrics")]
    pub(crate) trace_heights: Vec<usize>,
}
54
55impl<F: Clone + Send + Sync> AccessAdapterInventory<F> {
56    pub fn new(
57        range_checker: SharedVariableRangeCheckerChip,
58        memory_bus: MemoryBus,
59        memory_config: MemoryConfig,
60    ) -> Self {
61        let rc = range_checker;
62        let mb = memory_bus;
63        let tmb = memory_config.timestamp_max_bits;
64        let maan = memory_config.max_access_adapter_n;
65        assert!(matches!(maan, 2 | 4 | 8 | 16 | 32));
66        let chips: Vec<_> = [
67            Self::create_access_adapter_chip::<2>(rc.clone(), mb, tmb, maan),
68            Self::create_access_adapter_chip::<4>(rc.clone(), mb, tmb, maan),
69            Self::create_access_adapter_chip::<8>(rc.clone(), mb, tmb, maan),
70            Self::create_access_adapter_chip::<16>(rc.clone(), mb, tmb, maan),
71            Self::create_access_adapter_chip::<32>(rc.clone(), mb, tmb, maan),
72        ]
73        .into_iter()
74        .flatten()
75        .collect();
76        Self {
77            memory_config,
78            chips,
79            arena: DenseRecordArena::with_byte_capacity(0),
80            #[cfg(feature = "metrics")]
81            trace_heights: Vec::new(),
82        }
83    }
84
85    pub fn num_access_adapters(&self) -> usize {
86        self.chips.len()
87    }
88
89    pub(super) fn set_override_trace_heights(&mut self, overridden_heights: Vec<usize>) {
90        self.set_arena_from_trace_heights(
91            &overridden_heights
92                .iter()
93                .map(|&h| h as u32)
94                .collect::<Vec<_>>(),
95        );
96        for (chip, oh) in self.chips.iter_mut().zip(overridden_heights) {
97            chip.set_override_trace_height(oh);
98        }
99    }
100
101    pub(super) fn set_arena_from_trace_heights(&mut self, trace_heights: &[u32]) {
102        assert_eq!(trace_heights.len(), self.chips.len());
103        let size_bound = arena_size_bound(trace_heights);
104        tracing::debug!(
105            "Allocating {} bytes for memory adapters arena from heights {:?}",
106            size_bound,
107            trace_heights
108        );
109        self.arena.set_byte_capacity(size_bound);
110    }
111
112    pub fn get_widths(&self) -> Vec<usize> {
113        self.chips
114            .iter()
115            .map(|chip: &GenericAccessAdapterChip<F>| chip.trace_width())
116            .collect()
117    }
118
119    /// `heights` should have length equal to the number of access adapter chips.
120    pub(crate) fn compute_heights_from_arena(arena: &DenseRecordArena, heights: &mut [usize]) {
121        let bytes = arena.allocated();
122        tracing::debug!(
123            "Computing heights from memory adapters arena: used {} bytes",
124            bytes.len()
125        );
126        let mut ptr = 0;
127        while ptr < bytes.len() {
128            let bytes_slice = &bytes[ptr..];
129            let header: &AccessRecordHeader = bytes_slice.borrow();
130            // SAFETY:
131            // - bytes[ptr..] is a valid starting pointer to a previously allocated record
132            // - The record contains self-describing layout information
133            let layout: AccessLayout = unsafe { bytes_slice.extract_layout() };
134            ptr += <AccessRecordMut<'_> as SizedRecord<AccessLayout>>::size(&layout);
135
136            let log_max_block_size = log2_strict_usize(header.block_size as usize);
137            for (i, h) in heights
138                .iter_mut()
139                .enumerate()
140                .take(log_max_block_size)
141                .skip(log2_strict_usize(header.lowest_block_size as usize))
142            {
143                *h += 1 << (log_max_block_size - i - 1);
144            }
145        }
146        tracing::debug!("Computed heights from memory adapters arena: {:?}", heights);
147    }
148
149    fn apply_overridden_heights(&mut self, heights: &mut [usize]) {
150        for (i, h) in heights.iter_mut().enumerate() {
151            if let Some(oh) = self.chips[i].overridden_trace_height() {
152                assert!(
153                    oh >= *h,
154                    "Overridden height {oh} is less than the required height {}",
155                    *h
156                );
157                *h = oh;
158            }
159            *h = next_power_of_two_or_zero(*h);
160        }
161    }
162
163    pub fn generate_proving_ctx<SC: StarkGenericConfig>(
164        &mut self,
165    ) -> Vec<AirProvingContext<CpuBackend<SC>>>
166    where
167        F: PrimeField32,
168        Domain<SC>: PolynomialSpace<Val = F>,
169    {
170        let num_adapters = self.chips.len();
171
172        let mut heights = vec![0; num_adapters];
173        Self::compute_heights_from_arena(&self.arena, &mut heights);
174        self.apply_overridden_heights(&mut heights);
175
176        let widths = self
177            .chips
178            .iter()
179            .map(|chip| chip.trace_width())
180            .collect::<Vec<_>>();
181        let mut traces = widths
182            .iter()
183            .zip(heights.iter())
184            .map(|(&width, &height)| RowMajorMatrix::new(vec![F::ZERO; width * height], width))
185            .collect::<Vec<_>>();
186        #[cfg(feature = "metrics")]
187        {
188            self.trace_heights = heights;
189        }
190
191        let mut trace_ptrs = vec![0; num_adapters];
192
193        let bytes = self.arena.allocated_mut();
194        let mut ptr = 0;
195        while ptr < bytes.len() {
196            let bytes_slice = &mut bytes[ptr..];
197            // SAFETY:
198            // - bytes[ptr..] is a valid starting pointer to a previously allocated record
199            // - The record contains self-describing layout information
200            let layout: AccessLayout = unsafe { bytes_slice.extract_layout() };
201            let record: AccessRecordMut<'_> = bytes_slice.custom_borrow(layout.clone());
202            ptr += <AccessRecordMut<'_> as SizedRecord<AccessLayout>>::size(&layout);
203
204            let log_min_block_size = log2_strict_usize(record.header.lowest_block_size as usize);
205            let log_max_block_size = log2_strict_usize(record.header.block_size as usize);
206
207            if record.header.timestamp_and_mask & MERGE_AND_NOT_SPLIT_FLAG != 0 {
208                for i in log_min_block_size..log_max_block_size {
209                    let data_len = layout.type_size << i;
210                    let ts_len = 1 << (i - log_min_block_size);
211                    for j in 0..record.data.len() / (2 * data_len) {
212                        let row_slice =
213                            &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]];
214                        trace_ptrs[i] += widths[i];
215                        self.chips[i].fill_trace_row(
216                            &self.memory_config.addr_spaces,
217                            row_slice,
218                            false,
219                            MemoryAddress::new(
220                                record.header.address_space,
221                                record.header.pointer + (j << (i + 1)) as u32,
222                            ),
223                            &record.data[j * 2 * data_len..(j + 1) * 2 * data_len],
224                            *record.timestamps[2 * j * ts_len..(2 * j + 1) * ts_len]
225                                .iter()
226                                .max()
227                                .unwrap(),
228                            *record.timestamps[(2 * j + 1) * ts_len..(2 * j + 2) * ts_len]
229                                .iter()
230                                .max()
231                                .unwrap(),
232                        );
233                    }
234                }
235            } else {
236                let timestamp = record.header.timestamp_and_mask;
237                for i in log_min_block_size..log_max_block_size {
238                    let data_len = layout.type_size << i;
239                    for j in 0..record.data.len() / (2 * data_len) {
240                        let row_slice =
241                            &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]];
242                        trace_ptrs[i] += widths[i];
243                        self.chips[i].fill_trace_row(
244                            &self.memory_config.addr_spaces,
245                            row_slice,
246                            true,
247                            MemoryAddress::new(
248                                record.header.address_space,
249                                record.header.pointer + (j << (i + 1)) as u32,
250                            ),
251                            &record.data[j * 2 * data_len..(j + 1) * 2 * data_len],
252                            timestamp,
253                            timestamp,
254                        );
255                    }
256                }
257            }
258        }
259        traces
260            .into_iter()
261            .map(|trace| AirProvingContext::simple_no_pis(Arc::new(trace)))
262            .collect()
263    }
264
265    fn create_access_adapter_chip<const N: usize>(
266        range_checker: SharedVariableRangeCheckerChip,
267        memory_bus: MemoryBus,
268        timestamp_max_bits: usize,
269        max_access_adapter_n: usize,
270    ) -> Option<GenericAccessAdapterChip<F>>
271    where
272        F: Clone + Send + Sync,
273    {
274        if N <= max_access_adapter_n {
275            Some(GenericAccessAdapterChip::new::<N>(
276                range_checker,
277                memory_bus,
278                timestamp_max_bits,
279            ))
280        } else {
281            None
282        }
283    }
284}
285
/// Operations shared by every access-adapter chip size. `enum_dispatch` forwards calls on
/// [`GenericAccessAdapterChip`] to the wrapped `AccessAdapterChip<F, N>`.
#[enum_dispatch]
pub(crate) trait GenericAccessAdapterChipTrait<F> {
    /// Number of columns in this chip's trace.
    fn trace_width(&self) -> usize;
    /// Stores a trace height to use in place of the height computed from recorded accesses.
    fn set_override_trace_height(&mut self, overridden_height: usize);
    /// The height stored by `set_override_trace_height`, or `None` if none was set.
    fn overridden_trace_height(&self) -> Option<usize>;

    /// Fills `row` (one trace row of `trace_width()` field elements) for a single merge
    /// (`is_split == false`) or split (`is_split == true`) of the cells starting at `address`.
    /// `values` holds the raw bytes of those cells. `left_timestamp` and `right_timestamp`
    /// are the timestamps of the two halves.
    // More parameters than clippy's `too_many_arguments` lint allows by default.
    #[allow(clippy::too_many_arguments)]
    fn fill_trace_row(
        &self,
        addr_spaces: &[AddressSpaceHostConfig],
        row: &mut [F],
        is_split: bool,
        address: MemoryAddress<u32, u32>,
        values: &[u8],
        left_timestamp: u32,
        right_timestamp: u32,
    ) where
        F: PrimeField32;
}
305
/// Closed set of supported adapter sizes. `enum_dispatch` implements
/// [`GenericAccessAdapterChipTrait`] by delegating to the wrapped chip.
#[enum_dispatch(GenericAccessAdapterChipTrait<F>)]
enum GenericAccessAdapterChip<F> {
    /// Adapter whose rows hold `N = 2` cell values.
    N2(AccessAdapterChip<F, 2>),
    /// Adapter whose rows hold `N = 4` cell values.
    N4(AccessAdapterChip<F, 4>),
    /// Adapter whose rows hold `N = 8` cell values.
    N8(AccessAdapterChip<F, 8>),
    /// Adapter whose rows hold `N = 16` cell values.
    N16(AccessAdapterChip<F, 16>),
    /// Adapter whose rows hold `N = 32` cell values.
    N32(AccessAdapterChip<F, 32>),
}
314
315impl<F: Clone + Send + Sync> GenericAccessAdapterChip<F> {
316    fn new<const N: usize>(
317        range_checker: SharedVariableRangeCheckerChip,
318        memory_bus: MemoryBus,
319        timestamp_max_bits: usize,
320    ) -> Self {
321        let rc = range_checker;
322        let mb = memory_bus;
323        let cmb = timestamp_max_bits;
324        match N {
325            2 => GenericAccessAdapterChip::N2(AccessAdapterChip::new(rc, mb, cmb)),
326            4 => GenericAccessAdapterChip::N4(AccessAdapterChip::new(rc, mb, cmb)),
327            8 => GenericAccessAdapterChip::N8(AccessAdapterChip::new(rc, mb, cmb)),
328            16 => GenericAccessAdapterChip::N16(AccessAdapterChip::new(rc, mb, cmb)),
329            32 => GenericAccessAdapterChip::N32(AccessAdapterChip::new(rc, mb, cmb)),
330            _ => panic!("Only supports N in (2, 4, 8, 16, 32)"),
331        }
332    }
333}
334
/// Trace generator for the access adapter whose rows hold `N` cell values.
pub(crate) struct AccessAdapterChip<F, const N: usize> {
    /// AIR for this adapter size. Its `lt_air` also fills the timestamp-comparison columns.
    air: AccessAdapterAir<N>,
    /// Range checker handed to the `lt_air` subrow generator while filling rows.
    range_checker: SharedVariableRangeCheckerChip,
    /// Optional fixed trace height set via `set_override_trace_height`.
    overridden_height: Option<usize>,
    /// Ties the field type `F` to the chip without storing a value of it.
    _marker: PhantomData<F>,
}
341
342impl<F: Clone + Send + Sync, const N: usize> AccessAdapterChip<F, N> {
343    pub fn new(
344        range_checker: SharedVariableRangeCheckerChip,
345        memory_bus: MemoryBus,
346        timestamp_max_bits: usize,
347    ) -> Self {
348        let lt_air = IsLtSubAir::new(range_checker.bus(), timestamp_max_bits);
349        Self {
350            air: AccessAdapterAir::<N> { memory_bus, lt_air },
351            range_checker,
352            overridden_height: None,
353            _marker: PhantomData,
354        }
355    }
356}
357impl<F, const N: usize> GenericAccessAdapterChipTrait<F> for AccessAdapterChip<F, N> {
358    fn trace_width(&self) -> usize {
359        BaseAir::<F>::width(&self.air)
360    }
361
362    fn set_override_trace_height(&mut self, overridden_height: usize) {
363        self.overridden_height = Some(overridden_height);
364    }
365
366    fn overridden_trace_height(&self) -> Option<usize> {
367        self.overridden_height
368    }
369
370    fn fill_trace_row(
371        &self,
372        addr_spaces: &[AddressSpaceHostConfig],
373        row: &mut [F],
374        is_split: bool,
375        address: MemoryAddress<u32, u32>,
376        values: &[u8],
377        left_timestamp: u32,
378        right_timestamp: u32,
379    ) where
380        F: PrimeField32,
381    {
382        let row: &mut AccessAdapterCols<F, N> = row.borrow_mut();
383        row.is_valid = F::ONE;
384        row.is_split = F::from_bool(is_split);
385        row.address = MemoryAddress::new(
386            F::from_canonical_u32(address.address_space),
387            F::from_canonical_u32(address.pointer),
388        );
389        let addr_space_layout = addr_spaces[address.address_space as usize].layout;
390        // SAFETY: values will be a slice of the cell type
391        unsafe {
392            match addr_space_layout {
393                MemoryCellType::Native { .. } => {
394                    copy_nonoverlapping(
395                        values.as_ptr(),
396                        row.values.as_mut_ptr() as *mut u8,
397                        N * size_of::<F>(),
398                    );
399                }
400                _ => {
401                    for (dst, src) in row
402                        .values
403                        .iter_mut()
404                        .zip(values.chunks_exact(addr_space_layout.size()))
405                    {
406                        *dst = addr_space_layout.to_field(src);
407                    }
408                }
409            }
410        }
411        row.left_timestamp = F::from_canonical_u32(left_timestamp);
412        row.right_timestamp = F::from_canonical_u32(right_timestamp);
413        self.air.lt_air.generate_subrow(
414            (self.range_checker.as_ref(), left_timestamp, right_timestamp),
415            (&mut row.lt_aux, &mut row.is_right_larger),
416        );
417    }
418}