openvm_circuit/arch/testing/memory/
cuda.rs

1use std::{collections::HashMap, sync::Arc};
2
3use openvm_circuit::{
4    arch::{
5        testing::memory::air::{MemoryDummyAir, MemoryDummyChip},
6        MemoryConfig,
7    },
8    system::memory::{
9        offline_checker::{MemoryBridge, MemoryBus},
10        online::TracingMemory,
11    },
12};
13use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChipGPU};
14use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
15use openvm_cuda_common::copy::MemCopyH2D;
16use openvm_stark_backend::{
17    p3_field::{FieldAlgebra, PrimeField32},
18    prover::types::AirProvingContext,
19    Chip, ChipUsageGetter,
20};
21
22use crate::{
23    cuda_abi::memory_testing,
24    system::cuda::{memory::MemoryInventoryGPU, poseidon2::Poseidon2PeripheryChipGPU},
25};
26
27pub struct DeviceMemoryTester {
28    pub chip_for_block: HashMap<usize, FixedSizeMemoryTester>,
29    pub memory: TracingMemory,
30    pub inventory: MemoryInventoryGPU,
31    pub hasher_chip: Option<Arc<Poseidon2PeripheryChipGPU>>,
32
33    // Convenience fields, so we don't have to keep unwrapping
34    pub config: MemoryConfig,
35    pub mem_bus: MemoryBus,
36    pub range_bus: VariableRangeCheckerBus,
37}
38
39impl DeviceMemoryTester {
40    pub fn volatile(
41        memory: TracingMemory,
42        mem_bus: MemoryBus,
43        mem_config: MemoryConfig,
44        range_checker: Arc<VariableRangeCheckerChipGPU>,
45    ) -> Self {
46        let mut chip_for_block = HashMap::new();
47        for log_block_size in 0..6 {
48            let block_size = 1 << log_block_size;
49            chip_for_block.insert(block_size, FixedSizeMemoryTester::new(mem_bus, block_size));
50        }
51        let range_bus = range_checker.cpu_chip.as_ref().unwrap().bus();
52        Self {
53            chip_for_block,
54            memory,
55            inventory: MemoryInventoryGPU::volatile(mem_config.clone(), range_checker),
56            hasher_chip: None,
57            config: mem_config,
58            mem_bus,
59            range_bus,
60        }
61    }
62
63    pub fn persistent(
64        memory: TracingMemory,
65        mem_bus: MemoryBus,
66        mem_config: MemoryConfig,
67        range_checker: Arc<VariableRangeCheckerChipGPU>,
68    ) -> Self {
69        let mut chip_for_block = HashMap::new();
70        for log_block_size in 0..6 {
71            let block_size = 1 << log_block_size;
72            chip_for_block.insert(block_size, FixedSizeMemoryTester::new(mem_bus, block_size));
73        }
74        let range_bus = range_checker.cpu_chip.as_ref().unwrap().bus();
75        let sbox_regs = 1;
76        let poseidon2_periphery = Arc::new(Poseidon2PeripheryChipGPU::new(
77            1 << 20, // probably enough for our tests
78            sbox_regs,
79        ));
80        let mut inventory = MemoryInventoryGPU::persistent(
81            mem_config.clone(),
82            range_checker,
83            poseidon2_periphery.clone(),
84        );
85        inventory.set_initial_memory(&memory.data.memory);
86        Self {
87            chip_for_block,
88            memory,
89            inventory,
90            hasher_chip: Some(poseidon2_periphery),
91            config: mem_config,
92            mem_bus,
93            range_bus,
94        }
95    }
96
97    pub fn memory_bridge(&self) -> MemoryBridge {
98        MemoryBridge::new(self.mem_bus, self.config.timestamp_max_bits, self.range_bus)
99    }
100
101    pub fn read<const N: usize>(&mut self, addr_space: usize, ptr: usize) -> [F; N] {
102        let t = self.memory.timestamp();
103        let (t_prev, data) = if addr_space <= 3 {
104            let (t_prev, data) =
105                unsafe { self.memory.read::<u8, N, 4>(addr_space as u32, ptr as u32) };
106            (t_prev, data.map(F::from_canonical_u8))
107        } else {
108            unsafe { self.memory.read::<F, N, 1>(addr_space as u32, ptr as u32) }
109        };
110        self.chip_for_block.get_mut(&N).unwrap().receive(
111            addr_space as u32,
112            ptr as u32,
113            &data,
114            t_prev,
115        );
116        self.chip_for_block
117            .get_mut(&N)
118            .unwrap()
119            .send(addr_space as u32, ptr as u32, &data, t);
120        data
121    }
122
123    pub fn write<const N: usize>(&mut self, addr_space: usize, ptr: usize, data: [F; N]) {
124        let t = self.memory.timestamp();
125        let (t_prev, data_prev) = if addr_space <= 3 {
126            let (t_prev, data_prev) = unsafe {
127                self.memory.write::<u8, N, 4>(
128                    addr_space as u32,
129                    ptr as u32,
130                    data.map(|x| x.as_canonical_u32() as u8),
131                )
132            };
133            (t_prev, data_prev.map(F::from_canonical_u8))
134        } else {
135            unsafe {
136                self.memory
137                    .write::<F, N, 1>(addr_space as u32, ptr as u32, data)
138            }
139        };
140        self.chip_for_block.get_mut(&N).unwrap().receive(
141            addr_space as u32,
142            ptr as u32,
143            &data_prev,
144            t_prev,
145        );
146        self.chip_for_block
147            .get_mut(&N)
148            .unwrap()
149            .send(addr_space as u32, ptr as u32, &data, t);
150    }
151}
152
153pub struct FixedSizeMemoryTester(pub(crate) MemoryDummyChip<F>);
154
155impl FixedSizeMemoryTester {
156    pub fn new(bus: MemoryBus, block_size: usize) -> Self {
157        Self(MemoryDummyChip::new(MemoryDummyAir::new(bus, block_size)))
158    }
159
160    pub fn send(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) {
161        self.0.send(addr_space, ptr, data, timestamp);
162    }
163
164    pub fn receive(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) {
165        self.0.receive(addr_space, ptr, data, timestamp);
166    }
167
168    pub fn push(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32, count: F) {
169        self.0.push(addr_space, ptr, data, timestamp, count);
170    }
171}
172
173impl ChipUsageGetter for FixedSizeMemoryTester {
174    fn air_name(&self) -> String {
175        self.0.air_name()
176    }
177
178    fn current_trace_height(&self) -> usize {
179        self.0.current_trace_height()
180    }
181
182    fn trace_width(&self) -> usize {
183        self.0.trace_width()
184    }
185}
186
187impl<RA> Chip<RA, GpuBackend> for FixedSizeMemoryTester {
188    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
189        let height = self.0.current_trace_height().next_power_of_two();
190        let width = self.0.trace_width();
191
192        let mut records = self.0.trace.clone();
193        records.resize(height * width, F::ZERO);
194        let num_records = height;
195
196        let trace = DeviceMatrix::<F>::with_capacity(height, width);
197        unsafe {
198            memory_testing::tracegen(
199                trace.buffer(),
200                height,
201                width,
202                &records.to_device().unwrap(),
203                num_records,
204                self.0.air.block_size,
205            )
206            .unwrap();
207        }
208        AirProvingContext::simple_no_pis(trace)
209    }
210}