openvm_circuit/system/cuda/
boundary.rs1use std::sync::Arc;
2
3use openvm_circuit::{
4 system::memory::{
5 persistent::PersistentBoundaryCols, volatile::VolatileBoundaryCols,
6 TimestampedEquipartition, TimestampedValues,
7 },
8 utils::next_power_of_two_or_zero,
9};
10use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
11use openvm_cuda_backend::{
12 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
13};
14use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
15use openvm_stark_backend::{
16 p3_field::PrimeField32,
17 p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator},
18 prover::{hal::MatrixDimensions, types::AirProvingContext},
19 Chip,
20};
21
22use super::{merkle_tree::TIMESTAMPED_BLOCK_WIDTH, poseidon2::SharedBuffer};
23use crate::cuda_abi::boundary::{persistent_boundary_tracegen, volatile_boundary_tracegen};
24
25pub struct PersistentBoundary {
26 pub poseidon2_buffer: SharedBuffer<F>,
27 pub initial_leaves: Vec<*const std::ffi::c_void>,
31 pub touched_blocks: Option<DeviceBuffer<u32>>,
32}
33
34pub struct VolatileBoundary {
35 pub range_checker: Arc<VariableRangeCheckerChipGPU>,
36 pub as_max_bits: usize,
37 pub ptr_max_bits: usize,
38 pub records: Option<Vec<u32>>,
39}
40
41pub enum BoundaryFields {
42 Persistent(PersistentBoundary),
43 Volatile(VolatileBoundary),
44}
45
46pub struct BoundaryChipGPU {
47 pub fields: BoundaryFields,
48 pub num_records: Option<usize>,
49 pub trace_width: Option<usize>,
50}
51
52impl BoundaryChipGPU {
53 pub fn persistent(poseidon2_buffer: SharedBuffer<F>) -> Self {
54 Self {
55 fields: BoundaryFields::Persistent(PersistentBoundary {
56 poseidon2_buffer,
57 initial_leaves: Vec::new(),
58 touched_blocks: None,
59 }),
60 num_records: None,
61 trace_width: None,
62 }
63 }
64
65 pub fn volatile(
66 range_checker: Arc<VariableRangeCheckerChipGPU>,
67 as_max_bits: usize,
68 ptr_max_bits: usize,
69 ) -> Self {
70 Self {
71 fields: BoundaryFields::Volatile(VolatileBoundary {
72 range_checker,
73 as_max_bits,
74 ptr_max_bits,
75 records: None,
76 }),
77 num_records: None,
78 trace_width: None,
79 }
80 }
81
82 pub fn finalize_records_volatile<const CHUNK: usize>(
85 &mut self,
86 final_memory: TimestampedEquipartition<F, CHUNK>,
87 ) {
88 match &mut self.fields {
89 BoundaryFields::Persistent(_) => panic!("call `finalize_records_persistent`"),
90 BoundaryFields::Volatile(fields) => {
91 self.num_records = Some(final_memory.len());
92 self.trace_width = Some(VolatileBoundaryCols::<F>::width());
93 let records: Vec<_> = final_memory
94 .par_iter()
95 .flat_map(|&((addr_space, ptr), ts_values)| {
96 let TimestampedValues { timestamp, values } = ts_values;
97 let mut record = vec![addr_space, ptr, timestamp];
98 record.extend_from_slice(&values.map(|x| x.as_canonical_u32()));
99 record
100 })
101 .collect();
102 fields.records = Some(records);
103 }
104 }
105 }
106
107 pub fn finalize_records_persistent<const CHUNK: usize>(
108 &mut self,
109 touched_blocks: DeviceBuffer<u32>,
110 ) {
111 match &mut self.fields {
112 BoundaryFields::Volatile(_) => panic!("call `finalize_records_volatile`"),
113 BoundaryFields::Persistent(fields) => {
114 self.num_records = Some(touched_blocks.len() / TIMESTAMPED_BLOCK_WIDTH);
115 self.trace_width = Some(PersistentBoundaryCols::<F, CHUNK>::width());
116 fields.touched_blocks = Some(touched_blocks);
117 }
118 }
119 }
120
121 pub fn trace_width(&self) -> usize {
122 self.trace_width.expect("Finalize records to get width")
123 }
124}
125
126impl<RA> Chip<RA, GpuBackend> for BoundaryChipGPU {
127 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
128 let num_records = self.num_records.unwrap();
129 if num_records == 0 {
130 return get_empty_air_proving_ctx();
131 }
132 let unpadded_height = match &self.fields {
133 BoundaryFields::Persistent(_) => 2 * num_records,
134 BoundaryFields::Volatile(_) => num_records,
135 };
136 let trace_height = next_power_of_two_or_zero(unpadded_height);
137 let trace = DeviceMatrix::<F>::with_capacity(trace_height, self.trace_width());
138 match &self.fields {
139 BoundaryFields::Persistent(boundary) => {
140 let mem_ptrs = boundary.initial_leaves.to_device().unwrap();
141 unsafe {
142 persistent_boundary_tracegen(
143 trace.buffer(),
144 trace.height(),
145 trace.width(),
146 &mem_ptrs,
147 boundary.touched_blocks.as_ref().unwrap(),
148 num_records,
149 &boundary.poseidon2_buffer.buffer,
150 &boundary.poseidon2_buffer.idx,
151 )
152 .expect("Failed to generate persistent boundary trace");
153 }
154 }
155 BoundaryFields::Volatile(boundary) => unsafe {
156 let records = boundary
157 .records
158 .as_ref()
159 .expect("Records must be finalized before generating trace");
160 let records = records.to_device().unwrap();
161 volatile_boundary_tracegen(
162 trace.buffer(),
163 trace.height(),
164 trace.width(),
165 &records,
166 num_records,
167 &boundary.range_checker.count,
168 boundary.as_max_bits,
169 boundary.ptr_max_bits,
170 )
171 .expect("Failed to generate volatile boundary trace");
172 },
173 }
174 AirProvingContext::simple_no_pis(trace)
175 }
176}
177
#[cfg(test)]
mod tests {
    use std::{collections::BTreeSet, sync::Arc};

    use openvm_circuit::{
        arch::{testing::MEMORY_BUS, MemoryConfig, ADDR_SPACE_OFFSET},
        system::memory::{
            offline_checker::MemoryBus, volatile::VolatileBoundaryChip, TimestampedEquipartition,
            TimestampedValues,
        },
    };
    use openvm_circuit_primitives::var_range::VariableRangeCheckerChip;
    use openvm_cuda_backend::{
        data_transporter::assert_eq_host_and_device_matrix,
        prelude::{F, SC},
        prover_backend::GpuBackend,
    };
    use openvm_stark_backend::{
        p3_util::log2_ceil_usize,
        prover::{cpu::CpuBackend, types::AirProvingContext},
        Chip,
    };
    use openvm_stark_sdk::utils::create_seeded_rng;
    use p3_field::FieldAlgebra;
    use rand::Rng;

    use super::{BoundaryChipGPU, VariableRangeCheckerChipGPU};
    use crate::arch::testing::default_var_range_checker_bus;

    const MAX_ADDRESS_SPACE: u32 = 4;
    const LIMB_BITS: usize = 15;

    /// Feeds the same final memory to the CPU and GPU volatile boundary chips and
    /// checks that both produce the same trace matrix.
    #[test]
    fn test_cuda_volatile_boundary_tracegen() {
        const NUM_ADDRESSES: usize = 10;
        let mut rng = create_seeded_rng();

        // A `BTreeSet` (rather than a `HashSet`) keeps iteration order independent of
        // the per-process hash seed, so the values drawn from `rng` below are paired
        // with the same addresses on every run.
        let mut distinct_addresses = BTreeSet::new();
        while distinct_addresses.len() < NUM_ADDRESSES {
            let addr_space = rng.gen_range(0..MAX_ADDRESS_SPACE);
            let pointer = rng.gen_range(0..(1 << LIMB_BITS));
            distinct_addresses.insert((addr_space, pointer));
        }

        // Iterating a `BTreeSet` yields addresses in ascending `(addr_space, pointer)`
        // order, so `final_memory` is built already sorted by key.
        let mut final_memory = TimestampedEquipartition::<F, 1>::new();
        for (addr_space, pointer) in distinct_addresses {
            let final_data = F::from_canonical_u32(rng.gen_range(0..(1 << LIMB_BITS)));
            let final_clk = rng.gen_range(1..(1 << LIMB_BITS)) as u32;

            final_memory.push((
                (addr_space, pointer),
                TimestampedValues {
                    values: [final_data],
                    timestamp: final_clk,
                },
            ));
        }

        let mem_config = MemoryConfig::default();
        let addr_space_max_bits = log2_ceil_usize(
            (ADDR_SPACE_OFFSET + 2u32.pow(mem_config.addr_space_height as u32)) as usize,
        );
        let cpu_rc = Arc::new(VariableRangeCheckerChip::new(
            default_var_range_checker_bus(),
        ));

        let mut gpu_boundary = BoundaryChipGPU::volatile(
            Arc::new(VariableRangeCheckerChipGPU::hybrid(cpu_rc.clone())),
            addr_space_max_bits,
            mem_config.pointer_max_bits,
        );
        let mut cpu_boundary: VolatileBoundaryChip<F> = VolatileBoundaryChip::new(
            MemoryBus::new(MEMORY_BUS),
            addr_space_max_bits,
            mem_config.pointer_max_bits,
            cpu_rc,
        );
        gpu_boundary.finalize_records_volatile(final_memory.clone());
        cpu_boundary.finalize(final_memory);
        let gpu_ctx: AirProvingContext<GpuBackend> = gpu_boundary.generate_proving_ctx(());
        let cpu_ctx: AirProvingContext<CpuBackend<SC>> = cpu_boundary.generate_proving_ctx(());
        assert_eq_host_and_device_matrix(
            cpu_ctx.common_main.unwrap(),
            &gpu_ctx.common_main.unwrap(),
        );
    }
}