use std::{
    borrow::{Borrow, BorrowMut},
    marker::PhantomData,
    ptr::copy_nonoverlapping,
    sync::Arc,
};

pub use air::*;
pub use columns::*;
use enum_dispatch::enum_dispatch;
use getset::Setters;
use openvm_circuit_primitives::{
    is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero,
    var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator,
};
use openvm_stark_backend::{
    config::{Domain, StarkGenericConfig},
    p3_air::BaseAir,
    p3_commit::PolynomialSpace,
    p3_field::PrimeField32,
    p3_matrix::dense::RowMajorMatrix,
    p3_util::log2_strict_usize,
    prover::{cpu::CpuBackend, types::AirProvingContext},
};

use crate::{
    arch::{
        AddressSpaceHostConfig, AddressSpaceHostLayout, CustomBorrow, DenseRecordArena,
        MemoryCellType, MemoryConfig, SizedRecord,
    },
    system::memory::{
        adapter::records::{
            arena_size_bound, AccessLayout, AccessRecordHeader, AccessRecordMut,
            MERGE_AND_NOT_SPLIT_FLAG,
        },
        offline_checker::MemoryBus,
        MemoryAddress,
    },
};

mod air;
mod columns;
pub mod records;

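/// Inventory of memory access adapter chips, one per supported block size `N` in
/// `{2, 4, 8, 16, 32}` up to `MemoryConfig::max_access_adapter_n`, together with the
/// [`DenseRecordArena`] that collects the split/merge access records generated during
/// execution.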
#[derive(Setters)]
pub struct AccessAdapterInventory<F> {
    pub(super) memory_config: MemoryConfig,
    chips: Vec<GenericAccessAdapterChip<F>>,
    #[getset(set = "pub")]
    arena: DenseRecordArena,
    #[cfg(feature = "metrics")]
    pub(crate) trace_heights: Vec<usize>,
}

impl<F: Clone + Send + Sync> AccessAdapterInventory<F> {
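    /// Creates an access adapter chip for each power-of-two block size `N` in `{2, 4, 8, 16,
    /// 32}` with `N <= memory_config.max_access_adapter_n`. The record arena starts with zero
    /// capacity and is sized later via [`Self::set_arena_from_trace_heights`].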
    pub fn new(
        range_checker: SharedVariableRangeCheckerChip,
        memory_bus: MemoryBus,
        memory_config: MemoryConfig,
    ) -> Self {
        let rc = range_checker;
        let mb = memory_bus;
        let tmb = memory_config.timestamp_max_bits;
        let maan = memory_config.max_access_adapter_n;
        assert!(matches!(maan, 2 | 4 | 8 | 16 | 32));
        let chips: Vec<_> = [
            Self::create_access_adapter_chip::<2>(rc.clone(), mb, tmb, maan),
            Self::create_access_adapter_chip::<4>(rc.clone(), mb, tmb, maan),
            Self::create_access_adapter_chip::<8>(rc.clone(), mb, tmb, maan),
            Self::create_access_adapter_chip::<16>(rc.clone(), mb, tmb, maan),
            Self::create_access_adapter_chip::<32>(rc.clone(), mb, tmb, maan),
        ]
        .into_iter()
        .flatten()
        .collect();
        Self {
            memory_config,
            chips,
            arena: DenseRecordArena::with_byte_capacity(0),
            #[cfg(feature = "metrics")]
            trace_heights: Vec::new(),
        }
    }

    pub fn num_access_adapters(&self) -> usize {
        self.chips.len()
    }

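    /// Pins each adapter chip's trace height to the given value and resizes the record arena
    /// to the capacity implied by those heights.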
    pub(super) fn set_override_trace_heights(&mut self, overridden_heights: Vec<usize>) {
        self.set_arena_from_trace_heights(
            &overridden_heights
                .iter()
                .map(|&h| h as u32)
                .collect::<Vec<_>>(),
        );
        for (chip, oh) in self.chips.iter_mut().zip(overridden_heights) {
            chip.set_override_trace_height(oh);
        }
    }

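    /// Reallocates the record arena with a byte capacity sufficient for the records implied by
    /// the given per-adapter trace heights (see [`arena_size_bound`]).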
    pub(super) fn set_arena_from_trace_heights(&mut self, trace_heights: &[u32]) {
        assert_eq!(trace_heights.len(), self.chips.len());
        let size_bound = arena_size_bound(trace_heights);
        tracing::debug!(
            "Allocating {} bytes for memory adapters arena from heights {:?}",
            size_bound,
            trace_heights
        );
        self.arena.set_byte_capacity(size_bound);
    }

    pub fn get_widths(&self) -> Vec<usize> {
        self.chips
            .iter()
            .map(|chip: &GenericAccessAdapterChip<F>| chip.trace_width())
            .collect()
    }

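    /// Scans the records in `arena` and adds, for each adapter chip, the number of trace rows
    /// needed to prove them; `heights[i]` corresponds to the adapter with block size `2^(i+1)`.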
    pub(crate) fn compute_heights_from_arena(arena: &DenseRecordArena, heights: &mut [usize]) {
        let bytes = arena.allocated();
        tracing::debug!(
            "Computing heights from memory adapters arena: used {} bytes",
            bytes.len()
        );
        let mut ptr = 0;
        while ptr < bytes.len() {
            let bytes_slice = &bytes[ptr..];
            let header: &AccessRecordHeader = bytes_slice.borrow();
            let layout: AccessLayout = unsafe { bytes_slice.extract_layout() };
            ptr += <AccessRecordMut<'_> as SizedRecord<AccessLayout>>::size(&layout);

            let log_max_block_size = log2_strict_usize(header.block_size as usize);
            for (i, h) in heights
                .iter_mut()
                .enumerate()
                .take(log_max_block_size)
                .skip(log2_strict_usize(header.lowest_block_size as usize))
            {
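                // A record spanning `block_size` cells performs 2^(log_max_block_size - i - 1)
                // split/merge operations at adapter level `i` (block size 2^(i+1)). For example,
                // block_size = 32 with lowest_block_size = 4 adds 4 rows to the N=8 chip, 2 to
                // N=16, and 1 to N=32.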
                *h += 1 << (log_max_block_size - i - 1);
            }
        }
        tracing::debug!("Computed heights from memory adapters arena: {:?}", heights);
    }

    fn apply_overridden_heights(&mut self, heights: &mut [usize]) {
        for (i, h) in heights.iter_mut().enumerate() {
            if let Some(oh) = self.chips[i].overridden_trace_height() {
                assert!(
                    oh >= *h,
                    "Overridden height {oh} is less than the required height {}",
                    *h
                );
                *h = oh;
            }
            *h = next_power_of_two_or_zero(*h);
        }
    }

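    /// Generates one proving context per adapter chip: computes the required trace heights from
    /// the arena, applies any overrides and pads to powers of two, allocates zero-initialized
    /// trace matrices, then replays every record in the arena to fill the adapter trace rows.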
    pub fn generate_proving_ctx<SC: StarkGenericConfig>(
        &mut self,
    ) -> Vec<AirProvingContext<CpuBackend<SC>>>
    where
        F: PrimeField32,
        Domain<SC>: PolynomialSpace<Val = F>,
    {
        let num_adapters = self.chips.len();

        let mut heights = vec![0; num_adapters];
        Self::compute_heights_from_arena(&self.arena, &mut heights);
        self.apply_overridden_heights(&mut heights);

        let widths = self
            .chips
            .iter()
            .map(|chip| chip.trace_width())
            .collect::<Vec<_>>();
        let mut traces = widths
            .iter()
            .zip(heights.iter())
            .map(|(&width, &height)| RowMajorMatrix::new(vec![F::ZERO; width * height], width))
            .collect::<Vec<_>>();
        #[cfg(feature = "metrics")]
        {
            self.trace_heights = heights;
        }

        let mut trace_ptrs = vec![0; num_adapters];

        let bytes = self.arena.allocated_mut();
        let mut ptr = 0;
        while ptr < bytes.len() {
            let bytes_slice = &mut bytes[ptr..];
            let layout: AccessLayout = unsafe { bytes_slice.extract_layout() };
            let record: AccessRecordMut<'_> = bytes_slice.custom_borrow(layout.clone());
            ptr += <AccessRecordMut<'_> as SizedRecord<AccessLayout>>::size(&layout);

            let log_min_block_size = log2_strict_usize(record.header.lowest_block_size as usize);
            let log_max_block_size = log2_strict_usize(record.header.block_size as usize);

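            // Merge records store one timestamp per lowest-size block; each merge row takes the
            // max timestamp over its left and right halves. Split records (the `else` branch)
            // carry a single timestamp that is reused for both halves of every row.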
            if record.header.timestamp_and_mask & MERGE_AND_NOT_SPLIT_FLAG != 0 {
                for i in log_min_block_size..log_max_block_size {
                    let data_len = layout.type_size << i;
                    let ts_len = 1 << (i - log_min_block_size);
                    for j in 0..record.data.len() / (2 * data_len) {
                        let row_slice =
                            &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]];
                        trace_ptrs[i] += widths[i];
                        self.chips[i].fill_trace_row(
                            &self.memory_config.addr_spaces,
                            row_slice,
                            false,
                            MemoryAddress::new(
                                record.header.address_space,
                                record.header.pointer + (j << (i + 1)) as u32,
                            ),
                            &record.data[j * 2 * data_len..(j + 1) * 2 * data_len],
                            *record.timestamps[2 * j * ts_len..(2 * j + 1) * ts_len]
                                .iter()
                                .max()
                                .unwrap(),
                            *record.timestamps[(2 * j + 1) * ts_len..(2 * j + 2) * ts_len]
                                .iter()
                                .max()
                                .unwrap(),
                        );
                    }
                }
            } else {
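                // Split: every row at every level reuses the record's single timestamp, and
                // `timestamp_and_mask` holds it directly since the merge flag is not set.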
                let timestamp = record.header.timestamp_and_mask;
                for i in log_min_block_size..log_max_block_size {
                    let data_len = layout.type_size << i;
                    for j in 0..record.data.len() / (2 * data_len) {
                        let row_slice =
                            &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]];
                        trace_ptrs[i] += widths[i];
                        self.chips[i].fill_trace_row(
                            &self.memory_config.addr_spaces,
                            row_slice,
                            true,
                            MemoryAddress::new(
                                record.header.address_space,
                                record.header.pointer + (j << (i + 1)) as u32,
                            ),
                            &record.data[j * 2 * data_len..(j + 1) * 2 * data_len],
                            timestamp,
                            timestamp,
                        );
                    }
                }
            }
        }
        traces
            .into_iter()
            .map(|trace| AirProvingContext::simple_no_pis(Arc::new(trace)))
            .collect()
    }

    fn create_access_adapter_chip<const N: usize>(
        range_checker: SharedVariableRangeCheckerChip,
        memory_bus: MemoryBus,
        timestamp_max_bits: usize,
        max_access_adapter_n: usize,
    ) -> Option<GenericAccessAdapterChip<F>>
    where
        F: Clone + Send + Sync,
    {
        if N <= max_access_adapter_n {
            Some(GenericAccessAdapterChip::new::<N>(
                range_checker,
                memory_bus,
                timestamp_max_bits,
            ))
        } else {
            None
        }
    }
}

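/// Interface shared by all access adapter chips, statically dispatched over the supported
/// block sizes via `enum_dispatch`.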
#[enum_dispatch]
pub(crate) trait GenericAccessAdapterChipTrait<F> {
    fn trace_width(&self) -> usize;
    fn set_override_trace_height(&mut self, overridden_height: usize);
    fn overridden_trace_height(&self) -> Option<usize>;

    #[allow(clippy::too_many_arguments)]
    fn fill_trace_row(
        &self,
        addr_spaces: &[AddressSpaceHostConfig],
        row: &mut [F],
        is_split: bool,
        address: MemoryAddress<u32, u32>,
        values: &[u8],
        left_timestamp: u32,
        right_timestamp: u32,
    ) where
        F: PrimeField32;
}

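/// Wrapper over [`AccessAdapterChip`] for each supported block size `N`, so that the inventory
/// can hold chips of different sizes in a single `Vec`.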
#[enum_dispatch(GenericAccessAdapterChipTrait<F>)]
enum GenericAccessAdapterChip<F> {
    N2(AccessAdapterChip<F, 2>),
    N4(AccessAdapterChip<F, 4>),
    N8(AccessAdapterChip<F, 8>),
    N16(AccessAdapterChip<F, 16>),
    N32(AccessAdapterChip<F, 32>),
}

impl<F: Clone + Send + Sync> GenericAccessAdapterChip<F> {
    fn new<const N: usize>(
        range_checker: SharedVariableRangeCheckerChip,
        memory_bus: MemoryBus,
        timestamp_max_bits: usize,
    ) -> Self {
        let rc = range_checker;
        let mb = memory_bus;
        let cmb = timestamp_max_bits;
        match N {
            2 => GenericAccessAdapterChip::N2(AccessAdapterChip::new(rc, mb, cmb)),
            4 => GenericAccessAdapterChip::N4(AccessAdapterChip::new(rc, mb, cmb)),
            8 => GenericAccessAdapterChip::N8(AccessAdapterChip::new(rc, mb, cmb)),
            16 => GenericAccessAdapterChip::N16(AccessAdapterChip::new(rc, mb, cmb)),
            32 => GenericAccessAdapterChip::N32(AccessAdapterChip::new(rc, mb, cmb)),
            _ => panic!("Only supports N in (2, 4, 8, 16, 32)"),
        }
    }
}

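/// Access adapter chip for a single block size `N`: each trace row records one split of an
/// `N`-cell block into two `N/2`-cell halves, or one merge in the opposite direction, at a pair
/// of timestamps.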
pub(crate) struct AccessAdapterChip<F, const N: usize> {
    air: AccessAdapterAir<N>,
    range_checker: SharedVariableRangeCheckerChip,
    overridden_height: Option<usize>,
    _marker: PhantomData<F>,
}

impl<F: Clone + Send + Sync, const N: usize> AccessAdapterChip<F, N> {
    pub fn new(
        range_checker: SharedVariableRangeCheckerChip,
        memory_bus: MemoryBus,
        timestamp_max_bits: usize,
    ) -> Self {
        let lt_air = IsLtSubAir::new(range_checker.bus(), timestamp_max_bits);
        Self {
            air: AccessAdapterAir::<N> { memory_bus, lt_air },
            range_checker,
            overridden_height: None,
            _marker: PhantomData,
        }
    }
}
impl<F, const N: usize> GenericAccessAdapterChipTrait<F> for AccessAdapterChip<F, N> {
    fn trace_width(&self) -> usize {
        BaseAir::<F>::width(&self.air)
    }

    fn set_override_trace_height(&mut self, overridden_height: usize) {
        self.overridden_height = Some(overridden_height);
    }

    fn overridden_trace_height(&self) -> Option<usize> {
        self.overridden_height
    }

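    /// Fills a single [`AccessAdapterCols`] row: the valid/split flags, the address, the `N`
    /// cell values converted to field elements, both timestamps, and the `IsLtSubAir` aux
    /// columns together with the `is_right_larger` flag comparing the two timestamps.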
    fn fill_trace_row(
        &self,
        addr_spaces: &[AddressSpaceHostConfig],
        row: &mut [F],
        is_split: bool,
        address: MemoryAddress<u32, u32>,
        values: &[u8],
        left_timestamp: u32,
        right_timestamp: u32,
    ) where
        F: PrimeField32,
    {
        let row: &mut AccessAdapterCols<F, N> = row.borrow_mut();
        row.is_valid = F::ONE;
        row.is_split = F::from_bool(is_split);
        row.address = MemoryAddress::new(
            F::from_canonical_u32(address.address_space),
            F::from_canonical_u32(address.pointer),
        );
        let addr_space_layout = addr_spaces[address.address_space as usize].layout;
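        // Native address spaces store raw field elements, so the record bytes can be copied
        // into the row verbatim; all other layouts convert each cell to a field element via the
        // address space's `to_field` conversion.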
        unsafe {
            match addr_space_layout {
                MemoryCellType::Native { .. } => {
                    copy_nonoverlapping(
                        values.as_ptr(),
                        row.values.as_mut_ptr() as *mut u8,
                        N * size_of::<F>(),
                    );
                }
                _ => {
                    for (dst, src) in row
                        .values
                        .iter_mut()
                        .zip(values.chunks_exact(addr_space_layout.size()))
                    {
                        *dst = addr_space_layout.to_field(src);
                    }
                }
            }
        }
        row.left_timestamp = F::from_canonical_u32(left_timestamp);
        row.right_timestamp = F::from_canonical_u32(right_timestamp);
        self.air.lt_air.generate_subrow(
            (self.range_checker.as_ref(), left_timestamp, right_timestamp),
            (&mut row.lt_aux, &mut row.is_right_larger),
        );
    }
418}