openvm_circuit/system/cuda/
access_adapters.rs

use std::{ptr::null_mut, sync::Arc};

use openvm_circuit::{
    arch::{CustomBorrow, DenseRecordArena, SizedRecord},
    system::memory::adapter::{
        records::{AccessLayout, AccessRecordMut},
        AccessAdapterCols,
    },
    utils::next_power_of_two_or_zero,
};
use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
use openvm_cuda_common::copy::MemCopyH2D;
use openvm_stark_backend::prover::types::AirProvingContext;

use crate::cuda_abi::access_adapters::tracegen;

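/// Number of access adapter AIRs; adapter `i` handles block size `2^(i + 1)` (2, 4, 8, 16, 32).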
pub(crate) const NUM_ADAPTERS: usize = 5;

pub struct AccessAdapterInventoryGPU {
    max_access_adapter_n: usize,
    timestamp_max_bits: usize,
    range_checker: Arc<VariableRangeCheckerChipGPU>,
    #[cfg(feature = "metrics")]
    pub(super) unpadded_heights: Vec<usize>,
}

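/// Per-record dispatch info for the GPU kernel: the record's byte offset in the records buffer
/// and, for each adapter, the row index at which that record's rows begin.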
#[repr(C)]
pub struct OffsetInfo {
    pub record_offset: u32,
    pub adapter_rows: [u32; NUM_ADAPTERS],
}

impl AccessAdapterInventoryGPU {
    pub(crate) fn generate_traces_from_records(
        &mut self,
        records: &mut [u8],
    ) -> Vec<Option<DeviceMatrix<F>>> {
        let max_access_adapter_n = self.max_access_adapter_n;
        let timestamp_max_bits = self.timestamp_max_bits;
        let range_checker = &self.range_checker;

        assert!(max_access_adapter_n.is_power_of_two());
        let cnt_adapters = max_access_adapter_n.ilog2() as usize;
        if records.is_empty() {
            return vec![None; cnt_adapters];
        }

        let mut offsets = Vec::new();
        let mut offset = 0;
        let mut row_ids = [0; NUM_ADAPTERS];

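        // Single pass over the serialized records: remember each record's byte offset together
        // with the running per-adapter row counts, so each record knows where its rows start.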
        while offset < records.len() {
            offsets.push(OffsetInfo {
                record_offset: offset as u32,
                adapter_rows: row_ids,
            });
            let layout: AccessLayout = unsafe { records[offset..].extract_layout() };
            let record: AccessRecordMut<'_> = records[offset..].custom_borrow(layout.clone());
            offset += <AccessRecordMut<'_> as SizedRecord<AccessLayout>>::size(&layout);
            let bs = record.header.block_size;
            let lbs = record.header.lowest_block_size;
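            // Splitting a block of size `bs` down to `lbs` uses `bs >> (logn + 1)` rows of the
            // adapter for block size `2^(logn + 1)`.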
            for logn in lbs.ilog2()..bs.ilog2() {
                row_ids[logn as usize] += bs >> (1 + logn);
            }
        }

        let d_records = records.to_device().unwrap();
        let d_record_offsets = offsets.to_device().unwrap();
        let widths: [_; NUM_ADAPTERS] = std::array::from_fn(|i| match i {
            0 => size_of::<AccessAdapterCols<u8, 2>>(),
            1 => size_of::<AccessAdapterCols<u8, 4>>(),
            2 => size_of::<AccessAdapterCols<u8, 8>>(),
            3 => size_of::<AccessAdapterCols<u8, 16>>(),
            4 => size_of::<AccessAdapterCols<u8, 32>>(),
            _ => panic!(),
        });
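        // `AccessAdapterCols<u8, N>` packs one byte per column, so its size equals the trace
        // width. Allocate a device matrix per adapter, padded to a power-of-two height; adapters
        // with no rows stay `None`.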
        let unpadded_heights = row_ids
            .iter()
            .take(cnt_adapters)
            .map(|&x| x as usize)
            .collect::<Vec<_>>();
        let traces = unpadded_heights
            .iter()
            .enumerate()
            .map(|(i, &h)| match h {
                0 => None,
                h => Some(DeviceMatrix::<F>::with_capacity(
                    next_power_of_two_or_zero(h),
                    widths[i],
                )),
            })
            .collect::<Vec<_>>();
        let trace_ptrs = traces
            .iter()
            .map(|trace| {
                trace
                    .as_ref()
                    .map_or_else(null_mut, |t| t.buffer().as_mut_raw_ptr())
            })
            .collect::<Vec<_>>();
        let d_trace_ptrs = trace_ptrs.to_device().unwrap();
        let d_unpadded_heights = unpadded_heights.to_device().unwrap();
        let d_widths = widths.to_device().unwrap();

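        // Generate all adapter traces on the GPU from the shared record buffer; the range
        // checker's count buffer is passed so range checks can be tallied on device.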
        unsafe {
            tracegen(
                &d_trace_ptrs,
                &d_unpadded_heights,
                &d_widths,
                offsets.len(),
                &d_records,
                &d_record_offsets,
                &range_checker.count,
                timestamp_max_bits,
            )
            .unwrap();
        }
        #[cfg(feature = "metrics")]
        {
            self.unpadded_heights = unpadded_heights;
        }

        traces
    }

    pub fn new(
        range_checker: Arc<VariableRangeCheckerChipGPU>,
        max_access_adapter_n: usize,
        timestamp_max_bits: usize,
    ) -> Self {
        Self {
            range_checker,
            max_access_adapter_n,
            timestamp_max_bits,
            #[cfg(feature = "metrics")]
            unpadded_heights: Vec::new(),
        }
    }

    // @dev: mutable borrow is only to update `self.unpadded_heights` for metrics
    pub fn generate_air_proving_ctxs(
        &mut self,
        mut arena: DenseRecordArena,
    ) -> Vec<AirProvingContext<GpuBackend>> {
        let records = arena.allocated_mut();
        self.generate_traces_from_records(records)
            .into_iter()
            .map(|trace| AirProvingContext {
                cached_mains: vec![],
                common_main: trace,
                public_values: vec![],
            })
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use std::array;

    use openvm_circuit::{
        arch::{
            testing::{MEMORY_BUS, RANGE_CHECKER_BUS},
            MemoryConfig,
        },
        system::memory::{offline_checker::MemoryBus, MemoryController},
    };
    use openvm_circuit_primitives::var_range::VariableRangeCheckerBus;
    use openvm_cuda_backend::{data_transporter::assert_eq_host_and_device_matrix, prelude::SC};
    use openvm_stark_backend::{p3_field::FieldAlgebra, prover::hal::MatrixDimensions};
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;
    use crate::arch::testing::{GpuChipTestBuilder, TestBuilder};

    #[test]
    fn test_cuda_access_adapters_cpu_gpu_equivalence() {
        let mem_config = MemoryConfig::default();

        let mut rng = StdRng::seed_from_u64(42);
        let decomp = mem_config.decomp;
        let mut tester = GpuChipTestBuilder::volatile(
            mem_config.clone(),
            VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, decomp),
        );

        let max_ptr = 20;
        let aligns = [4, 4, 4, 1];
        let value_bounds = [256, 256, 256, (1 << 30)];
        let max_log_block_size = 4;
        let its = 1000;
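        // Random writes at mixed alignments and block sizes (1 to 16 cells) force block splits
        // and merges in volatile memory, producing access adapter records.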
        for _ in 0..its {
            let addr_sp = rng.gen_range(1..=aligns.len());
            let align: usize = aligns[addr_sp - 1];
            let value_bound: u32 = value_bounds[addr_sp - 1];
            let ptr = rng.gen_range(0..max_ptr / align) * align;
            let log_len = rng.gen_range(align.trailing_zeros()..=max_log_block_size);
            match log_len {
                0 => tester.write::<1>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                1 => tester.write::<2>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                2 => tester.write::<4>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                3 => tester.write::<8>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                4 => tester.write::<16>(
                    addr_sp,
                    ptr,
                    array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))),
                ),
                _ => unreachable!(),
            }
        }

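        // Finalize memory, then generate the GPU adapter traces directly from the recorded
        // accesses.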
        let touched = tester.memory.memory.finalize(false);
        let mut access_adapter_inv = AccessAdapterInventoryGPU::new(
            tester.range_checker(),
            mem_config.max_access_adapter_n,
            mem_config.timestamp_max_bits,
        );
        let allocated = tester.memory.memory.access_adapter_records.allocated_mut();
        let gpu_traces = access_adapter_inv
            .generate_traces_from_records(allocated)
            .into_iter()
            .map(|trace| trace.unwrap_or_else(DeviceMatrix::dummy))
            .collect::<Vec<_>>();

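        // Regenerate the traces on the CPU via `MemoryController` for comparison.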
        let mut controller = MemoryController::with_volatile_memory(
            MemoryBus::new(MEMORY_BUS),
            mem_config,
            tester.cpu_range_checker(),
        );
        let all_memory_traces = controller
            .generate_proving_ctx::<SC>(tester.memory.memory.access_adapter_records, touched)
            .into_iter()
            .map(|ctx| ctx.common_main.unwrap())
            .collect::<Vec<_>>();
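        // The access adapter traces are the last `NUM_ADAPTERS` entries of the controller's
        // output; earlier entries belong to other memory chips.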
        let num_memory_traces = all_memory_traces.len();
        let cpu_traces: Vec<_> = all_memory_traces
            .into_iter()
            .skip(num_memory_traces - NUM_ADAPTERS)
            .collect::<Vec<_>>();

        for (cpu_trace, gpu_trace) in cpu_traces.into_iter().zip(gpu_traces.iter()) {
            assert_eq!(
                cpu_trace.height() == 0,
                gpu_trace.height() == 0,
                "Exactly one of CPU and GPU traces is empty"
            );
            if cpu_trace.height() != 0 {
                assert_eq_host_and_device_matrix(cpu_trace, gpu_trace);
            }
        }
    }
}