openvm_circuit/arch/testing/memory/mod.rs
1use std::{array::from_fn, borrow::BorrowMut as _, cell::RefCell, mem::size_of, rc::Rc, sync::Arc};
2
3use air::{DummyMemoryInteractionCols, MemoryDummyAir};
4use openvm_circuit::system::memory::MemoryController;
5use openvm_stark_backend::{
6 config::{StarkGenericConfig, Val},
7 p3_field::{FieldAlgebra, PrimeField32},
8 p3_matrix::dense::RowMajorMatrix,
9 prover::types::AirProofInput,
10 AirRef, Chip, ChipUsageGetter,
11};
12use rand::{seq::SliceRandom, Rng};
13
14use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress, RecordId};
15
pub mod air;

/// Width (in field elements) used to instantiate the dummy memory AIR and its
/// column struct in this module. The tester only issues single-cell reads and
/// writes, so each logged access carries exactly one element of data.
const WORD_SIZE: usize = 1;
19
/// Test helper that issues single-cell memory accesses through a shared
/// [`MemoryController`] and logs the resulting record ids, which are later
/// turned into a trace for `MemoryDummyAir`.
pub struct MemoryTester<F> {
    /// Memory bus used by the dummy AIR; copied from the controller in `new`.
    pub bus: MemoryBus,
    /// Shared memory controller through which every read/write is issued.
    pub controller: Rc<RefCell<MemoryController<F>>>,
    /// Ids of every access performed through this tester, in issue order.
    pub records: Vec<RecordId>,
}
30
31impl<F: PrimeField32> MemoryTester<F> {
32 pub fn new(controller: Rc<RefCell<MemoryController<F>>>) -> Self {
33 let bus = controller.borrow().memory_bus;
34 Self {
35 bus,
36 controller,
37 records: Vec::new(),
38 }
39 }
40
41 pub fn read_cell(&mut self, address_space: usize, pointer: usize) -> F {
43 let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize);
44 let (record_id, value) =
46 RefCell::borrow_mut(&self.controller).read_cell(addr_space, pointer);
47 self.records.push(record_id);
48 value
49 }
50
51 pub fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) {
52 let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize);
53 let (record_id, _) =
54 RefCell::borrow_mut(&self.controller).write_cell(addr_space, pointer, value);
55 self.records.push(record_id);
56 }
57
58 pub fn read<const N: usize>(&mut self, address_space: usize, pointer: usize) -> [F; N] {
59 from_fn(|i| self.read_cell(address_space, pointer + i))
60 }
61
62 pub fn write<const N: usize>(
63 &mut self,
64 address_space: usize,
65 mut pointer: usize,
66 cells: [F; N],
67 ) {
68 for cell in cells {
69 self.write_cell(address_space, pointer, cell);
70 pointer += 1;
71 }
72 }
73}
74
75impl<SC: StarkGenericConfig> Chip<SC> for MemoryTester<Val<SC>>
76where
77 Val<SC>: PrimeField32,
78{
79 fn air(&self) -> AirRef<SC> {
80 Arc::new(MemoryDummyAir::<WORD_SIZE>::new(self.bus))
81 }
82
83 fn generate_air_proof_input(self) -> AirProofInput<SC> {
84 let offline_memory = self.controller.borrow().offline_memory();
85 let offline_memory = offline_memory.lock().unwrap();
86
87 let height = self.records.len().next_power_of_two();
88 let width = self.trace_width();
89 let mut values = Val::<SC>::zero_vec(2 * height * width);
90 for (row, id) in values.chunks_mut(2 * width).zip(self.records) {
93 let (first, second) = row.split_at_mut(width);
94 let row: &mut DummyMemoryInteractionCols<Val<SC>, WORD_SIZE> = first.borrow_mut();
95 let record = offline_memory.record_by_id(id);
96 row.address = MemoryAddress {
97 address_space: record.address_space,
98 pointer: record.pointer,
99 };
100 row.data
101 .copy_from_slice(record.prev_data_slice().unwrap_or(record.data_slice()));
102 row.timestamp = Val::<SC>::from_canonical_u32(record.prev_timestamp);
103 row.count = -Val::<SC>::ONE;
104
105 let row: &mut DummyMemoryInteractionCols<Val<SC>, WORD_SIZE> = second.borrow_mut();
106 row.address = MemoryAddress {
107 address_space: record.address_space,
108 pointer: record.pointer,
109 };
110 row.data.copy_from_slice(record.data_slice());
111 row.timestamp = Val::<SC>::from_canonical_u32(record.timestamp);
112 row.count = Val::<SC>::ONE;
113 }
114 AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width))
115 }
116}
117
118impl<F: PrimeField32> ChipUsageGetter for MemoryTester<F> {
119 fn air_name(&self) -> String {
120 "MemoryDummyAir".to_string()
121 }
122 fn current_trace_height(&self) -> usize {
123 self.records.len()
124 }
125
126 fn trace_width(&self) -> usize {
127 size_of::<DummyMemoryInteractionCols<u8, WORD_SIZE>>()
128 }
129}
130
131pub fn gen_address_space<R>(rng: &mut R) -> usize
132where
133 R: Rng + ?Sized,
134{
135 *[1, 2].choose(rng).unwrap()
136}
137
138pub fn gen_pointer<R>(rng: &mut R, len: usize) -> usize
139where
140 R: Rng + ?Sized,
141{
142 const MAX_MEMORY: usize = 1 << 29;
143 rng.gen_range(0..MAX_MEMORY - len) / len * len
144}