1use std::{borrow::BorrowMut, cmp::max, sync::Arc};
2
3pub use air::*;
4pub use columns::*;
5use enum_dispatch::enum_dispatch;
6use openvm_circuit_primitives::{
7 is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero,
8 var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator,
9};
10use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
11use openvm_stark_backend::{
12 config::{Domain, StarkGenericConfig, Val},
13 p3_air::BaseAir,
14 p3_commit::PolynomialSpace,
15 p3_field::PrimeField32,
16 p3_matrix::dense::RowMajorMatrix,
17 p3_maybe_rayon::prelude::*,
18 p3_util::log2_strict_usize,
19 prover::types::AirProofInput,
20 AirRef, Chip, ChipUsageGetter,
21};
22
23use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress};
24
25mod air;
26mod columns;
27#[cfg(test)]
28mod tests;
29
30pub struct AccessAdapterInventory<F> {
31 chips: Vec<GenericAccessAdapterChip<F>>,
32 air_names: Vec<String>,
33}
34
35impl<F> AccessAdapterInventory<F> {
36 pub fn new(
37 range_checker: SharedVariableRangeCheckerChip,
38 memory_bus: MemoryBus,
39 clk_max_bits: usize,
40 max_access_adapter_n: usize,
41 ) -> Self {
42 let rc = range_checker;
43 let mb = memory_bus;
44 let cmb = clk_max_bits;
45 let maan = max_access_adapter_n;
46 assert!(matches!(maan, 2 | 4 | 8 | 16 | 32));
47 let chips: Vec<_> = [
48 Self::create_access_adapter_chip::<2>(rc.clone(), mb, cmb, maan),
49 Self::create_access_adapter_chip::<4>(rc.clone(), mb, cmb, maan),
50 Self::create_access_adapter_chip::<8>(rc.clone(), mb, cmb, maan),
51 Self::create_access_adapter_chip::<16>(rc.clone(), mb, cmb, maan),
52 Self::create_access_adapter_chip::<32>(rc.clone(), mb, cmb, maan),
53 ]
54 .into_iter()
55 .flatten()
56 .collect();
57 let air_names = (0..chips.len()).map(|i| air_name(1 << (i + 1))).collect();
58 Self { chips, air_names }
59 }
60 pub fn num_access_adapters(&self) -> usize {
61 self.chips.len()
62 }
63 pub fn set_override_trace_heights(&mut self, overridden_heights: Vec<usize>) {
64 assert_eq!(overridden_heights.len(), self.chips.len());
65 for (chip, oh) in self.chips.iter_mut().zip(overridden_heights) {
66 chip.set_override_trace_heights(oh);
67 }
68 }
69 pub fn add_record(&mut self, record: AccessAdapterRecord<F>) {
70 let n = record.data.len();
71 let idx = log2_strict_usize(n) - 1;
72 let chip = &mut self.chips[idx];
73 debug_assert!(chip.n() == n);
74 chip.add_record(record);
75 }
76
77 pub fn extend_records(&mut self, records: Vec<AccessAdapterRecord<F>>) {
78 for record in records {
79 self.add_record(record);
80 }
81 }
82
83 #[cfg(test)]
84 pub fn records_for_n(&self, n: usize) -> &[AccessAdapterRecord<F>] {
85 let idx = log2_strict_usize(n) - 1;
86 let chip = &self.chips[idx];
87 chip.records()
88 }
89
90 #[cfg(test)]
91 pub fn total_records(&self) -> usize {
92 self.chips.iter().map(|chip| chip.records().len()).sum()
93 }
94
95 pub fn get_heights(&self) -> Vec<usize> {
96 self.chips
97 .iter()
98 .map(|chip| chip.current_trace_height())
99 .collect()
100 }
101 #[allow(dead_code)]
102 pub fn get_widths(&self) -> Vec<usize> {
103 self.chips.iter().map(|chip| chip.trace_width()).collect()
104 }
105 pub fn get_cells(&self) -> Vec<usize> {
106 self.chips
107 .iter()
108 .map(|chip| chip.current_trace_cells())
109 .collect()
110 }
111 pub fn airs<SC: StarkGenericConfig>(&self) -> Vec<AirRef<SC>>
112 where
113 F: PrimeField32,
114 Domain<SC>: PolynomialSpace<Val = F>,
115 {
116 self.chips.iter().map(|chip| chip.air()).collect()
117 }
118 pub fn air_names(&self) -> Vec<String> {
119 self.air_names.clone()
120 }
121 pub fn generate_air_proof_inputs<SC: StarkGenericConfig>(self) -> Vec<AirProofInput<SC>>
122 where
123 F: PrimeField32,
124 Domain<SC>: PolynomialSpace<Val = F>,
125 {
126 self.chips
127 .into_iter()
128 .map(|chip| chip.generate_air_proof_input())
129 .collect()
130 }
131
132 fn create_access_adapter_chip<const N: usize>(
133 range_checker: SharedVariableRangeCheckerChip,
134 memory_bus: MemoryBus,
135 clk_max_bits: usize,
136 max_access_adapter_n: usize,
137 ) -> Option<GenericAccessAdapterChip<F>> {
138 if N <= max_access_adapter_n {
139 Some(GenericAccessAdapterChip::new::<N>(
140 range_checker,
141 memory_bus,
142 clk_max_bits,
143 ))
144 } else {
145 None
146 }
147 }
148}
149
/// Whether an access adapter record splits a block or merges two halves into one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessAdapterRecordKind {
    /// Split a block; both halves use the record's own timestamp.
    Split,
    /// Merge two halves, each carrying its own timestamp.
    Merge {
        /// Timestamp associated with the left half.
        left_timestamp: u32,
        /// Timestamp associated with the right half.
        right_timestamp: u32,
    },
}
158
159#[derive(Debug, Clone, PartialEq, Eq)]
160pub struct AccessAdapterRecord<T> {
161 pub timestamp: u32,
162 pub address_space: T,
163 pub start_index: T,
164 pub data: Vec<T>,
165 pub kind: AccessAdapterRecordKind,
166}
167
168#[enum_dispatch]
169pub trait GenericAccessAdapterChipTrait<F> {
170 fn set_override_trace_heights(&mut self, overridden_height: usize);
171 fn add_record(&mut self, record: AccessAdapterRecord<F>);
172 fn n(&self) -> usize;
173 fn generate_trace(self) -> RowMajorMatrix<F>
174 where
175 F: PrimeField32;
176}
177
/// Size-erased wrapper over `AccessAdapterChip<F, N>` for each supported block size `N`.
///
/// `Chip` / `ChipUsageGetter` are derived and `GenericAccessAdapterChipTrait` is forwarded to the
/// wrapped chip by `enum_dispatch`.
#[derive(Chip, ChipUsageGetter)]
#[enum_dispatch(GenericAccessAdapterChipTrait<F>)]
#[chip(where = "F: PrimeField32")]
enum GenericAccessAdapterChip<F> {
    // One variant per supported block size.
    N2(AccessAdapterChip<F, 2>),
    N4(AccessAdapterChip<F, 4>),
    N8(AccessAdapterChip<F, 8>),
    N16(AccessAdapterChip<F, 16>),
    N32(AccessAdapterChip<F, 32>),
}
188
189impl<F> GenericAccessAdapterChip<F> {
190 fn new<const N: usize>(
191 range_checker: SharedVariableRangeCheckerChip,
192 memory_bus: MemoryBus,
193 clk_max_bits: usize,
194 ) -> Self {
195 let rc = range_checker;
196 let mb = memory_bus;
197 let cmb = clk_max_bits;
198 match N {
199 2 => GenericAccessAdapterChip::N2(AccessAdapterChip::new(rc, mb, cmb)),
200 4 => GenericAccessAdapterChip::N4(AccessAdapterChip::new(rc, mb, cmb)),
201 8 => GenericAccessAdapterChip::N8(AccessAdapterChip::new(rc, mb, cmb)),
202 16 => GenericAccessAdapterChip::N16(AccessAdapterChip::new(rc, mb, cmb)),
203 32 => GenericAccessAdapterChip::N32(AccessAdapterChip::new(rc, mb, cmb)),
204 _ => panic!("Only supports N in (2, 4, 8, 16, 32)"),
205 }
206 }
207
208 #[cfg(test)]
209 fn records(&self) -> &[AccessAdapterRecord<F>] {
210 match &self {
211 GenericAccessAdapterChip::N2(chip) => &chip.records,
212 GenericAccessAdapterChip::N4(chip) => &chip.records,
213 GenericAccessAdapterChip::N8(chip) => &chip.records,
214 GenericAccessAdapterChip::N16(chip) => &chip.records,
215 GenericAccessAdapterChip::N32(chip) => &chip.records,
216 }
217 }
218}
219pub struct AccessAdapterChip<F, const N: usize> {
220 air: AccessAdapterAir<N>,
221 range_checker: SharedVariableRangeCheckerChip,
222 pub records: Vec<AccessAdapterRecord<F>>,
223 overridden_height: Option<usize>,
224}
225impl<F, const N: usize> AccessAdapterChip<F, N> {
226 pub fn new(
227 range_checker: SharedVariableRangeCheckerChip,
228 memory_bus: MemoryBus,
229 clk_max_bits: usize,
230 ) -> Self {
231 let lt_air = IsLtSubAir::new(range_checker.bus(), clk_max_bits);
232 Self {
233 air: AccessAdapterAir::<N> { memory_bus, lt_air },
234 range_checker,
235 records: vec![],
236 overridden_height: None,
237 }
238 }
239}
240impl<F, const N: usize> GenericAccessAdapterChipTrait<F> for AccessAdapterChip<F, N> {
241 fn set_override_trace_heights(&mut self, overridden_height: usize) {
242 self.overridden_height = Some(overridden_height);
243 }
244 fn add_record(&mut self, record: AccessAdapterRecord<F>) {
245 self.records.push(record);
246 }
247 fn n(&self) -> usize {
248 N
249 }
250 fn generate_trace(self) -> RowMajorMatrix<F>
251 where
252 F: PrimeField32,
253 {
254 let width = BaseAir::<F>::width(&self.air);
255 let height = if let Some(oh) = self.overridden_height {
256 assert!(
257 oh >= self.records.len(),
258 "Overridden height is less than the required height"
259 );
260 oh
261 } else {
262 self.records.len()
263 };
264 let height = next_power_of_two_or_zero(height);
265 let mut values = F::zero_vec(height * width);
266
267 values
268 .par_chunks_mut(width)
269 .zip(self.records.into_par_iter())
270 .for_each(|(row, record)| {
271 let row: &mut AccessAdapterCols<F, N> = row.borrow_mut();
272
273 row.is_valid = F::ONE;
274 row.values = record.data.try_into().unwrap();
275 row.address = MemoryAddress::new(record.address_space, record.start_index);
276
277 let (left_timestamp, right_timestamp) = match record.kind {
278 AccessAdapterRecordKind::Split => (record.timestamp, record.timestamp),
279 AccessAdapterRecordKind::Merge {
280 left_timestamp,
281 right_timestamp,
282 } => (left_timestamp, right_timestamp),
283 };
284 debug_assert_eq!(max(left_timestamp, right_timestamp), record.timestamp);
285
286 row.left_timestamp = F::from_canonical_u32(left_timestamp);
287 row.right_timestamp = F::from_canonical_u32(right_timestamp);
288 row.is_split = F::from_bool(record.kind == AccessAdapterRecordKind::Split);
289
290 self.air.lt_air.generate_subrow(
291 (self.range_checker.as_ref(), left_timestamp, right_timestamp),
292 (&mut row.lt_aux, &mut row.is_right_larger),
293 );
294 });
295 RowMajorMatrix::new(values, width)
296 }
297}
298
299impl<SC: StarkGenericConfig, const N: usize> Chip<SC> for AccessAdapterChip<Val<SC>, N>
300where
301 Val<SC>: PrimeField32,
302{
303 fn air(&self) -> AirRef<SC> {
304 Arc::new(self.air.clone())
305 }
306
307 fn generate_air_proof_input(self) -> AirProofInput<SC> {
308 let trace = self.generate_trace();
309 AirProofInput::simple_no_pis(trace)
310 }
311}
312
313impl<F, const N: usize> ChipUsageGetter for AccessAdapterChip<F, N> {
314 fn air_name(&self) -> String {
315 air_name(N)
316 }
317
318 fn current_trace_height(&self) -> usize {
319 self.records.len()
320 }
321
322 fn trace_width(&self) -> usize {
323 BaseAir::<F>::width(&self.air)
324 }
325}
326
/// Display name of the access adapter AIR for block size `n`.
#[inline]
fn air_name(n: usize) -> String {
    format!("AccessAdapter<{n}>")
}