// openvm_circuit/system/cuda/access_adapters.rs

use std::{ptr::null_mut, sync::Arc};

use openvm_circuit::{
    arch::{CustomBorrow, DenseRecordArena, SizedRecord},
    system::memory::adapter::{
        records::{AccessLayout, AccessRecordMut},
        AccessAdapterCols,
    },
    utils::next_power_of_two_or_zero,
};
use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
use openvm_cuda_common::copy::MemCopyH2D;
use openvm_stark_backend::prover::types::AirProvingContext;

use crate::cuda_abi::access_adapters::tracegen;
17
/// Number of access-adapter AIRs supported, one per block size 2, 4, 8, 16, 32
/// (see the `widths` table in `generate_traces_from_records`).
pub(crate) const NUM_ADAPTERS: usize = 5;
19
/// GPU-side inventory for the memory access-adapter chips: holds the
/// configuration and shared range checker needed to generate the adapter
/// traces on device.
pub struct AccessAdapterInventoryGPU {
    // Largest block size handled by any adapter; asserted to be a power of two
    // in `generate_traces_from_records`.
    max_access_adapter_n: usize,
    // Bit width used for range-checking timestamps in the tracegen kernel.
    timestamp_max_bits: usize,
    // Shared GPU range checker; its `count` buffer is passed to the tracegen kernel.
    range_checker: Arc<VariableRangeCheckerChipGPU>,
    // Unpadded (pre-power-of-two) trace heights, recorded for metrics collection.
    #[cfg(feature = "metrics")]
    pub(super) unpadded_heights: Vec<usize>,
}
27
/// Per-record metadata handed to the GPU tracegen kernel.
/// `#[repr(C)]` so the layout matches the CUDA-side definition.
#[repr(C)]
pub struct OffsetInfo {
    /// Byte offset of the record within the serialized records buffer.
    pub record_offset: u32,
    /// Row index in each adapter's trace at which this record's rows begin
    /// (running totals accumulated while walking earlier records).
    pub adapter_rows: [u32; NUM_ADAPTERS],
}
33
34impl AccessAdapterInventoryGPU {
35 pub(crate) fn generate_traces_from_records(
36 &mut self,
37 records: &mut [u8],
38 ) -> Vec<Option<DeviceMatrix<F>>> {
39 let max_access_adapter_n = &self.max_access_adapter_n;
40 let timestamp_max_bits = self.timestamp_max_bits;
41 let range_checker = &self.range_checker;
42
43 assert!(max_access_adapter_n.is_power_of_two());
44 let cnt_adapters = max_access_adapter_n.ilog2() as usize;
45 if records.is_empty() {
46 return vec![None; cnt_adapters];
47 }
48
49 let mut offsets = Vec::new();
50 let mut offset = 0;
51 let mut row_ids = [0; NUM_ADAPTERS];
52
53 while offset < records.len() {
54 offsets.push(OffsetInfo {
55 record_offset: offset as u32,
56 adapter_rows: row_ids,
57 });
58 let layout: AccessLayout = unsafe { records[offset..].extract_layout() };
59 let record: AccessRecordMut<'_> = records[offset..].custom_borrow(layout.clone());
60 offset += <AccessRecordMut<'_> as SizedRecord<AccessLayout>>::size(&layout);
61 let bs = record.header.block_size;
62 let lbs = record.header.lowest_block_size;
63 for logn in lbs.ilog2()..bs.ilog2() {
64 row_ids[logn as usize] += bs >> (1 + logn);
65 }
66 }
67
68 let d_records = records.to_device().unwrap();
69 let d_record_offsets = offsets.to_device().unwrap();
70 let widths: [_; NUM_ADAPTERS] = std::array::from_fn(|i| match i {
71 0 => size_of::<AccessAdapterCols<u8, 2>>(),
72 1 => size_of::<AccessAdapterCols<u8, 4>>(),
73 2 => size_of::<AccessAdapterCols<u8, 8>>(),
74 3 => size_of::<AccessAdapterCols<u8, 16>>(),
75 4 => size_of::<AccessAdapterCols<u8, 32>>(),
76 _ => panic!(),
77 });
78 let unpadded_heights = row_ids
79 .iter()
80 .take(cnt_adapters)
81 .map(|&x| x as usize)
82 .collect::<Vec<_>>();
83 let traces = unpadded_heights
84 .iter()
85 .enumerate()
86 .map(|(i, &h)| match h {
87 0 => None,
88 h => Some(DeviceMatrix::<F>::with_capacity(
89 next_power_of_two_or_zero(h),
90 widths[i],
91 )),
92 })
93 .collect::<Vec<_>>();
94 let trace_ptrs = traces
95 .iter()
96 .map(|trace| {
97 trace
98 .as_ref()
99 .map_or_else(null_mut, |t| t.buffer().as_mut_raw_ptr())
100 })
101 .collect::<Vec<_>>();
102 let d_trace_ptrs = trace_ptrs.to_device().unwrap();
103 let d_unpadded_heights = unpadded_heights.to_device().unwrap();
104 let d_widths = widths.to_device().unwrap();
105
106 unsafe {
107 tracegen(
108 &d_trace_ptrs,
109 &d_unpadded_heights,
110 &d_widths,
111 offsets.len(),
112 &d_records,
113 &d_record_offsets,
114 &range_checker.count,
115 timestamp_max_bits,
116 )
117 .unwrap();
118 }
119 #[cfg(feature = "metrics")]
120 {
121 self.unpadded_heights = unpadded_heights;
122 }
123
124 traces
125 }
126
127 pub fn new(
128 range_checker: Arc<VariableRangeCheckerChipGPU>,
129 max_access_adapter_n: usize,
130 timestamp_max_bits: usize,
131 ) -> Self {
132 Self {
133 range_checker,
134 max_access_adapter_n,
135 timestamp_max_bits,
136 #[cfg(feature = "metrics")]
137 unpadded_heights: Vec::new(),
138 }
139 }
140
141 pub fn generate_air_proving_ctxs(
143 &mut self,
144 mut arena: DenseRecordArena,
145 ) -> Vec<AirProvingContext<GpuBackend>> {
146 let records = arena.allocated_mut();
147 self.generate_traces_from_records(records)
148 .into_iter()
149 .map(|trace| AirProvingContext {
150 cached_mains: vec![],
151 common_main: trace,
152 public_values: vec![],
153 })
154 .collect()
155 }
156}
157
#[cfg(test)]
mod tests {
    use std::array;

    use openvm_circuit::{
        arch::{
            testing::{MEMORY_BUS, RANGE_CHECKER_BUS},
            MemoryConfig,
        },
        system::memory::{offline_checker::MemoryBus, MemoryController},
    };
    use openvm_circuit_primitives::var_range::VariableRangeCheckerBus;
    use openvm_cuda_backend::{data_transporter::assert_eq_host_and_device_matrix, prelude::SC};
    use openvm_stark_backend::{p3_field::FieldAlgebra, prover::hal::MatrixDimensions};
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;
    use crate::arch::testing::{GpuChipTestBuilder, TestBuilder};

    /// Performs randomized memory writes of mixed block sizes, then checks
    /// that the GPU-generated access-adapter traces match the CPU-generated
    /// ones matrix-for-matrix.
    #[test]
    fn test_cuda_access_adapters_cpu_gpu_equivalence() {
        let mem_config = MemoryConfig::default();

        // Fixed seed keeps the test deterministic.
        let mut rng = StdRng::seed_from_u64(42);
        let decomp = mem_config.decomp;
        let mut tester = GpuChipTestBuilder::volatile(
            mem_config.clone(),
            VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, decomp),
        );

        // Per-address-space alignment and value bounds; address space 4 allows
        // byte alignment and large values.
        let max_ptr = 20;
        let aligns = [4, 4, 4, 1];
        let value_bounds = [256, 256, 256, (1 << 30)];
        let max_log_block_size = 4;
        let its = 1000;
        for _ in 0..its {
            let addr_sp = rng.gen_range(1..=aligns.len());
            let align: usize = aligns[addr_sp - 1];
            let value_bound: u32 = value_bounds[addr_sp - 1];
            // Pointers are kept aligned to the address space's alignment.
            let ptr = rng.gen_range(0..max_ptr / align) * align;
            let log_len = rng.gen_range(align.trailing_zeros()..=max_log_block_size);
            // Block size is 2^log_len; the write width is a const generic, so
            // each size needs its own call.
            match log_len {
                0 => tester.write::<1>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                1 => tester.write::<2>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                2 => tester.write::<4>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                3 => tester.write::<8>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                4 => tester.write::<16>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                _ => unreachable!(),
            }
        }

        // GPU path: generate adapter traces from the recorded accesses.
        let touched = tester.memory.memory.finalize(false);
        let mut access_adapter_inv = AccessAdapterInventoryGPU::new(
            tester.range_checker(),
            mem_config.max_access_adapter_n,
            mem_config.timestamp_max_bits,
        );
        let allocated = tester.memory.memory.access_adapter_records.allocated_mut();
        let gpu_traces = access_adapter_inv
            .generate_traces_from_records(allocated)
            .into_iter()
            .map(|trace| trace.unwrap_or_else(DeviceMatrix::dummy))
            .collect::<Vec<_>>();

        // CPU reference path over the same records.
        let mut controller = MemoryController::with_volatile_memory(
            MemoryBus::new(MEMORY_BUS),
            mem_config,
            tester.cpu_range_checker(),
        );
        let all_memory_traces = controller
            .generate_proving_ctx::<SC>(tester.memory.memory.access_adapter_records, touched)
            .into_iter()
            .map(|ctx| ctx.common_main.unwrap())
            .collect::<Vec<_>>();
        // The controller emits other memory traces first; the adapter traces
        // are the last NUM_ADAPTERS entries.
        let num_memory_traces = all_memory_traces.len();
        let cpu_traces: Vec<_> = all_memory_traces
            .into_iter()
            .skip(num_memory_traces - NUM_ADAPTERS)
            .collect::<Vec<_>>();

        for (cpu_trace, gpu_trace) in cpu_traces.into_iter().zip(gpu_traces.iter()) {
            // Either both sides are empty or both are non-empty.
            assert_eq!(
                cpu_trace.height() == 0,
                gpu_trace.height() == 0,
                "Exactly one of CPU and GPU traces is empty"
            );
            if cpu_trace.height() != 0 {
                assert_eq_host_and_device_matrix(cpu_trace, gpu_trace);
            }
        }
    }
}