openvm_circuit/arch/testing/memory/
cuda.rs1use std::{collections::HashMap, sync::Arc};
2
3use openvm_circuit::{
4 arch::{
5 testing::memory::air::{MemoryDummyAir, MemoryDummyChip},
6 MemoryConfig,
7 },
8 system::memory::{
9 offline_checker::{MemoryBridge, MemoryBus},
10 online::TracingMemory,
11 },
12};
13use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChipGPU};
14use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
15use openvm_cuda_common::copy::MemCopyH2D;
16use openvm_stark_backend::{
17 p3_field::{FieldAlgebra, PrimeField32},
18 prover::types::AirProvingContext,
19 Chip, ChipUsageGetter,
20};
21
22use crate::{
23 cuda_abi::memory_testing,
24 system::cuda::{memory::MemoryInventoryGPU, poseidon2::Poseidon2PeripheryChipGPU},
25};
26
27pub struct DeviceMemoryTester {
28 pub chip_for_block: HashMap<usize, FixedSizeMemoryTester>,
29 pub memory: TracingMemory,
30 pub inventory: MemoryInventoryGPU,
31 pub hasher_chip: Option<Arc<Poseidon2PeripheryChipGPU>>,
32
33 pub config: MemoryConfig,
35 pub mem_bus: MemoryBus,
36 pub range_bus: VariableRangeCheckerBus,
37}
38
39impl DeviceMemoryTester {
40 pub fn volatile(
41 memory: TracingMemory,
42 mem_bus: MemoryBus,
43 mem_config: MemoryConfig,
44 range_checker: Arc<VariableRangeCheckerChipGPU>,
45 ) -> Self {
46 let mut chip_for_block = HashMap::new();
47 for log_block_size in 0..6 {
48 let block_size = 1 << log_block_size;
49 chip_for_block.insert(block_size, FixedSizeMemoryTester::new(mem_bus, block_size));
50 }
51 let range_bus = range_checker.cpu_chip.as_ref().unwrap().bus();
52 Self {
53 chip_for_block,
54 memory,
55 inventory: MemoryInventoryGPU::volatile(mem_config.clone(), range_checker),
56 hasher_chip: None,
57 config: mem_config,
58 mem_bus,
59 range_bus,
60 }
61 }
62
63 pub fn persistent(
64 memory: TracingMemory,
65 mem_bus: MemoryBus,
66 mem_config: MemoryConfig,
67 range_checker: Arc<VariableRangeCheckerChipGPU>,
68 ) -> Self {
69 let mut chip_for_block = HashMap::new();
70 for log_block_size in 0..6 {
71 let block_size = 1 << log_block_size;
72 chip_for_block.insert(block_size, FixedSizeMemoryTester::new(mem_bus, block_size));
73 }
74 let range_bus = range_checker.cpu_chip.as_ref().unwrap().bus();
75 let sbox_regs = 1;
76 let poseidon2_periphery = Arc::new(Poseidon2PeripheryChipGPU::new(
77 1 << 20, sbox_regs,
79 ));
80 let mut inventory = MemoryInventoryGPU::persistent(
81 mem_config.clone(),
82 range_checker,
83 poseidon2_periphery.clone(),
84 );
85 inventory.set_initial_memory(&memory.data.memory);
86 Self {
87 chip_for_block,
88 memory,
89 inventory,
90 hasher_chip: Some(poseidon2_periphery),
91 config: mem_config,
92 mem_bus,
93 range_bus,
94 }
95 }
96
97 pub fn memory_bridge(&self) -> MemoryBridge {
98 MemoryBridge::new(self.mem_bus, self.config.timestamp_max_bits, self.range_bus)
99 }
100
101 pub fn read<const N: usize>(&mut self, addr_space: usize, ptr: usize) -> [F; N] {
102 let t = self.memory.timestamp();
103 let (t_prev, data) = if addr_space <= 3 {
104 let (t_prev, data) =
105 unsafe { self.memory.read::<u8, N, 4>(addr_space as u32, ptr as u32) };
106 (t_prev, data.map(F::from_canonical_u8))
107 } else {
108 unsafe { self.memory.read::<F, N, 1>(addr_space as u32, ptr as u32) }
109 };
110 self.chip_for_block.get_mut(&N).unwrap().receive(
111 addr_space as u32,
112 ptr as u32,
113 &data,
114 t_prev,
115 );
116 self.chip_for_block
117 .get_mut(&N)
118 .unwrap()
119 .send(addr_space as u32, ptr as u32, &data, t);
120 data
121 }
122
123 pub fn write<const N: usize>(&mut self, addr_space: usize, ptr: usize, data: [F; N]) {
124 let t = self.memory.timestamp();
125 let (t_prev, data_prev) = if addr_space <= 3 {
126 let (t_prev, data_prev) = unsafe {
127 self.memory.write::<u8, N, 4>(
128 addr_space as u32,
129 ptr as u32,
130 data.map(|x| x.as_canonical_u32() as u8),
131 )
132 };
133 (t_prev, data_prev.map(F::from_canonical_u8))
134 } else {
135 unsafe {
136 self.memory
137 .write::<F, N, 1>(addr_space as u32, ptr as u32, data)
138 }
139 };
140 self.chip_for_block.get_mut(&N).unwrap().receive(
141 addr_space as u32,
142 ptr as u32,
143 &data_prev,
144 t_prev,
145 );
146 self.chip_for_block
147 .get_mut(&N)
148 .unwrap()
149 .send(addr_space as u32, ptr as u32, &data, t);
150 }
151}
152
153pub struct FixedSizeMemoryTester(pub(crate) MemoryDummyChip<F>);
154
155impl FixedSizeMemoryTester {
156 pub fn new(bus: MemoryBus, block_size: usize) -> Self {
157 Self(MemoryDummyChip::new(MemoryDummyAir::new(bus, block_size)))
158 }
159
160 pub fn send(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) {
161 self.0.send(addr_space, ptr, data, timestamp);
162 }
163
164 pub fn receive(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) {
165 self.0.receive(addr_space, ptr, data, timestamp);
166 }
167
168 pub fn push(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32, count: F) {
169 self.0.push(addr_space, ptr, data, timestamp, count);
170 }
171}
172
173impl ChipUsageGetter for FixedSizeMemoryTester {
174 fn air_name(&self) -> String {
175 self.0.air_name()
176 }
177
178 fn current_trace_height(&self) -> usize {
179 self.0.current_trace_height()
180 }
181
182 fn trace_width(&self) -> usize {
183 self.0.trace_width()
184 }
185}
186
187impl<RA> Chip<RA, GpuBackend> for FixedSizeMemoryTester {
188 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
189 let height = self.0.current_trace_height().next_power_of_two();
190 let width = self.0.trace_width();
191
192 let mut records = self.0.trace.clone();
193 records.resize(height * width, F::ZERO);
194 let num_records = height;
195
196 let trace = DeviceMatrix::<F>::with_capacity(height, width);
197 unsafe {
198 memory_testing::tracegen(
199 trace.buffer(),
200 height,
201 width,
202 &records.to_device().unwrap(),
203 num_records,
204 self.0.air.block_size,
205 )
206 .unwrap();
207 }
208 AirProvingContext::simple_no_pis(trace)
209 }
210}