// openvm_circuit/system/memory/adapter/mod.rs

1use std::{borrow::BorrowMut, cmp::max, sync::Arc};
2
3pub use air::*;
4pub use columns::*;
5use enum_dispatch::enum_dispatch;
6use openvm_circuit_primitives::{
7    is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero,
8    var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator,
9};
10use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
11use openvm_stark_backend::{
12    config::{Domain, StarkGenericConfig, Val},
13    p3_air::BaseAir,
14    p3_commit::PolynomialSpace,
15    p3_field::PrimeField32,
16    p3_matrix::dense::RowMajorMatrix,
17    p3_maybe_rayon::prelude::*,
18    p3_util::log2_strict_usize,
19    prover::types::AirProofInput,
20    AirRef, Chip, ChipUsageGetter,
21};
22
23use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress};
24
25mod air;
26mod columns;
27#[cfg(test)]
28mod tests;
29
30pub struct AccessAdapterInventory<F> {
31    chips: Vec<GenericAccessAdapterChip<F>>,
32    air_names: Vec<String>,
33}
34
35impl<F> AccessAdapterInventory<F> {
36    pub fn new(
37        range_checker: SharedVariableRangeCheckerChip,
38        memory_bus: MemoryBus,
39        clk_max_bits: usize,
40        max_access_adapter_n: usize,
41    ) -> Self {
42        let rc = range_checker;
43        let mb = memory_bus;
44        let cmb = clk_max_bits;
45        let maan = max_access_adapter_n;
46        assert!(matches!(maan, 2 | 4 | 8 | 16 | 32));
47        let chips: Vec<_> = [
48            Self::create_access_adapter_chip::<2>(rc.clone(), mb, cmb, maan),
49            Self::create_access_adapter_chip::<4>(rc.clone(), mb, cmb, maan),
50            Self::create_access_adapter_chip::<8>(rc.clone(), mb, cmb, maan),
51            Self::create_access_adapter_chip::<16>(rc.clone(), mb, cmb, maan),
52            Self::create_access_adapter_chip::<32>(rc.clone(), mb, cmb, maan),
53        ]
54        .into_iter()
55        .flatten()
56        .collect();
57        let air_names = (0..chips.len()).map(|i| air_name(1 << (i + 1))).collect();
58        Self { chips, air_names }
59    }
60    pub fn num_access_adapters(&self) -> usize {
61        self.chips.len()
62    }
63    pub fn set_override_trace_heights(&mut self, overridden_heights: Vec<usize>) {
64        assert_eq!(overridden_heights.len(), self.chips.len());
65        for (chip, oh) in self.chips.iter_mut().zip(overridden_heights) {
66            chip.set_override_trace_heights(oh);
67        }
68    }
69    pub fn add_record(&mut self, record: AccessAdapterRecord<F>) {
70        let n = record.data.len();
71        let idx = log2_strict_usize(n) - 1;
72        let chip = &mut self.chips[idx];
73        debug_assert!(chip.n() == n);
74        chip.add_record(record);
75    }
76
77    pub fn extend_records(&mut self, records: Vec<AccessAdapterRecord<F>>) {
78        for record in records {
79            self.add_record(record);
80        }
81    }
82
83    #[cfg(test)]
84    pub fn records_for_n(&self, n: usize) -> &[AccessAdapterRecord<F>] {
85        let idx = log2_strict_usize(n) - 1;
86        let chip = &self.chips[idx];
87        chip.records()
88    }
89
90    #[cfg(test)]
91    pub fn total_records(&self) -> usize {
92        self.chips.iter().map(|chip| chip.records().len()).sum()
93    }
94
95    pub fn get_heights(&self) -> Vec<usize> {
96        self.chips
97            .iter()
98            .map(|chip| chip.current_trace_height())
99            .collect()
100    }
101    #[allow(dead_code)]
102    pub fn get_widths(&self) -> Vec<usize> {
103        self.chips.iter().map(|chip| chip.trace_width()).collect()
104    }
105    pub fn get_cells(&self) -> Vec<usize> {
106        self.chips
107            .iter()
108            .map(|chip| chip.current_trace_cells())
109            .collect()
110    }
111    pub fn airs<SC: StarkGenericConfig>(&self) -> Vec<AirRef<SC>>
112    where
113        F: PrimeField32,
114        Domain<SC>: PolynomialSpace<Val = F>,
115    {
116        self.chips.iter().map(|chip| chip.air()).collect()
117    }
118    pub fn air_names(&self) -> Vec<String> {
119        self.air_names.clone()
120    }
121    pub fn generate_air_proof_inputs<SC: StarkGenericConfig>(self) -> Vec<AirProofInput<SC>>
122    where
123        F: PrimeField32,
124        Domain<SC>: PolynomialSpace<Val = F>,
125    {
126        self.chips
127            .into_iter()
128            .map(|chip| chip.generate_air_proof_input())
129            .collect()
130    }
131
132    fn create_access_adapter_chip<const N: usize>(
133        range_checker: SharedVariableRangeCheckerChip,
134        memory_bus: MemoryBus,
135        clk_max_bits: usize,
136        max_access_adapter_n: usize,
137    ) -> Option<GenericAccessAdapterChip<F>> {
138        if N <= max_access_adapter_n {
139            Some(GenericAccessAdapterChip::new::<N>(
140                range_checker,
141                memory_bus,
142                clk_max_bits,
143            ))
144        } else {
145            None
146        }
147    }
148}
149
/// Whether an access-adapter record splits a block or merges two parts into one block.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AccessAdapterRecordKind {
    /// Split of a block into its left and right parts; both parts use the record's `timestamp`.
    Split,
    /// Merge of a left part (at `left_timestamp`) and a right part (at `right_timestamp`).
    /// Trace generation checks in debug builds that the larger of the two equals the record's
    /// `timestamp`.
    Merge {
        left_timestamp: u32,
        right_timestamp: u32,
    },
}

/// One access-adapter event. It becomes one trace row of the adapter chip whose block size
/// equals `data.len()`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AccessAdapterRecord<T> {
    /// Timestamp of the operation.
    pub timestamp: u32,
    /// Address space of the block.
    pub address_space: T,
    /// Starting index of the block within `address_space`.
    pub start_index: T,
    /// Block contents; the length selects which adapter chip receives the record.
    pub data: Vec<T>,
    /// Split or merge (with per-part timestamps for merges).
    pub kind: AccessAdapterRecordKind,
}
167
168#[enum_dispatch]
169pub trait GenericAccessAdapterChipTrait<F> {
170    fn set_override_trace_heights(&mut self, overridden_height: usize);
171    fn add_record(&mut self, record: AccessAdapterRecord<F>);
172    fn n(&self) -> usize;
173    fn generate_trace(self) -> RowMajorMatrix<F>
174    where
175        F: PrimeField32;
176}
177
178#[derive(Chip, ChipUsageGetter)]
179#[enum_dispatch(GenericAccessAdapterChipTrait<F>)]
180#[chip(where = "F: PrimeField32")]
181enum GenericAccessAdapterChip<F> {
182    N2(AccessAdapterChip<F, 2>),
183    N4(AccessAdapterChip<F, 4>),
184    N8(AccessAdapterChip<F, 8>),
185    N16(AccessAdapterChip<F, 16>),
186    N32(AccessAdapterChip<F, 32>),
187}
188
189impl<F> GenericAccessAdapterChip<F> {
190    fn new<const N: usize>(
191        range_checker: SharedVariableRangeCheckerChip,
192        memory_bus: MemoryBus,
193        clk_max_bits: usize,
194    ) -> Self {
195        let rc = range_checker;
196        let mb = memory_bus;
197        let cmb = clk_max_bits;
198        match N {
199            2 => GenericAccessAdapterChip::N2(AccessAdapterChip::new(rc, mb, cmb)),
200            4 => GenericAccessAdapterChip::N4(AccessAdapterChip::new(rc, mb, cmb)),
201            8 => GenericAccessAdapterChip::N8(AccessAdapterChip::new(rc, mb, cmb)),
202            16 => GenericAccessAdapterChip::N16(AccessAdapterChip::new(rc, mb, cmb)),
203            32 => GenericAccessAdapterChip::N32(AccessAdapterChip::new(rc, mb, cmb)),
204            _ => panic!("Only supports N in (2, 4, 8, 16, 32)"),
205        }
206    }
207
208    #[cfg(test)]
209    fn records(&self) -> &[AccessAdapterRecord<F>] {
210        match &self {
211            GenericAccessAdapterChip::N2(chip) => &chip.records,
212            GenericAccessAdapterChip::N4(chip) => &chip.records,
213            GenericAccessAdapterChip::N8(chip) => &chip.records,
214            GenericAccessAdapterChip::N16(chip) => &chip.records,
215            GenericAccessAdapterChip::N32(chip) => &chip.records,
216        }
217    }
218}
219pub struct AccessAdapterChip<F, const N: usize> {
220    air: AccessAdapterAir<N>,
221    range_checker: SharedVariableRangeCheckerChip,
222    pub records: Vec<AccessAdapterRecord<F>>,
223    overridden_height: Option<usize>,
224}
225impl<F, const N: usize> AccessAdapterChip<F, N> {
226    pub fn new(
227        range_checker: SharedVariableRangeCheckerChip,
228        memory_bus: MemoryBus,
229        clk_max_bits: usize,
230    ) -> Self {
231        let lt_air = IsLtSubAir::new(range_checker.bus(), clk_max_bits);
232        Self {
233            air: AccessAdapterAir::<N> { memory_bus, lt_air },
234            range_checker,
235            records: vec![],
236            overridden_height: None,
237        }
238    }
239}
240impl<F, const N: usize> GenericAccessAdapterChipTrait<F> for AccessAdapterChip<F, N> {
241    fn set_override_trace_heights(&mut self, overridden_height: usize) {
242        self.overridden_height = Some(overridden_height);
243    }
244    fn add_record(&mut self, record: AccessAdapterRecord<F>) {
245        self.records.push(record);
246    }
247    fn n(&self) -> usize {
248        N
249    }
250    fn generate_trace(self) -> RowMajorMatrix<F>
251    where
252        F: PrimeField32,
253    {
254        let width = BaseAir::<F>::width(&self.air);
255        let height = if let Some(oh) = self.overridden_height {
256            assert!(
257                oh >= self.records.len(),
258                "Overridden height is less than the required height"
259            );
260            oh
261        } else {
262            self.records.len()
263        };
264        let height = next_power_of_two_or_zero(height);
265        let mut values = F::zero_vec(height * width);
266
267        values
268            .par_chunks_mut(width)
269            .zip(self.records.into_par_iter())
270            .for_each(|(row, record)| {
271                let row: &mut AccessAdapterCols<F, N> = row.borrow_mut();
272
273                row.is_valid = F::ONE;
274                row.values = record.data.try_into().unwrap();
275                row.address = MemoryAddress::new(record.address_space, record.start_index);
276
277                let (left_timestamp, right_timestamp) = match record.kind {
278                    AccessAdapterRecordKind::Split => (record.timestamp, record.timestamp),
279                    AccessAdapterRecordKind::Merge {
280                        left_timestamp,
281                        right_timestamp,
282                    } => (left_timestamp, right_timestamp),
283                };
284                debug_assert_eq!(max(left_timestamp, right_timestamp), record.timestamp);
285
286                row.left_timestamp = F::from_canonical_u32(left_timestamp);
287                row.right_timestamp = F::from_canonical_u32(right_timestamp);
288                row.is_split = F::from_bool(record.kind == AccessAdapterRecordKind::Split);
289
290                self.air.lt_air.generate_subrow(
291                    (self.range_checker.as_ref(), left_timestamp, right_timestamp),
292                    (&mut row.lt_aux, &mut row.is_right_larger),
293                );
294            });
295        RowMajorMatrix::new(values, width)
296    }
297}
298
299impl<SC: StarkGenericConfig, const N: usize> Chip<SC> for AccessAdapterChip<Val<SC>, N>
300where
301    Val<SC>: PrimeField32,
302{
303    fn air(&self) -> AirRef<SC> {
304        Arc::new(self.air.clone())
305    }
306
307    fn generate_air_proof_input(self) -> AirProofInput<SC> {
308        let trace = self.generate_trace();
309        AirProofInput::simple_no_pis(trace)
310    }
311}
312
313impl<F, const N: usize> ChipUsageGetter for AccessAdapterChip<F, N> {
314    fn air_name(&self) -> String {
315        air_name(N)
316    }
317
318    fn current_trace_height(&self) -> usize {
319        self.records.len()
320    }
321
322    fn trace_width(&self) -> usize {
323        BaseAir::<F>::width(&self.air)
324    }
325}
326
/// AIR name for the access adapter handling blocks of `n` elements, e.g. `AccessAdapter<8>`.
#[inline]
fn air_name(n: usize) -> String {
    format!("AccessAdapter<{n}>")
}