openvm_circuit/system/cuda/boundary.rs

use std::sync::Arc;

use openvm_circuit::{
    system::memory::{
        persistent::PersistentBoundaryCols, volatile::VolatileBoundaryCols,
        TimestampedEquipartition, TimestampedValues,
    },
    utils::next_power_of_two_or_zero,
};
use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
use openvm_cuda_backend::{
    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
};
use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
use openvm_stark_backend::{
    p3_field::PrimeField32,
    p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator},
    prover::{hal::MatrixDimensions, types::AirProvingContext},
    Chip,
};

use super::{merkle_tree::TIMESTAMPED_BLOCK_WIDTH, poseidon2::SharedBuffer};
use crate::cuda_abi::boundary::{persistent_boundary_tracegen, volatile_boundary_tracegen};

pub struct PersistentBoundary {
    pub poseidon2_buffer: SharedBuffer<F>,
    /// A `Vec` of pointers to the copied guest memory on device.
    /// This struct cannot own the device memory, hence we take extra care not to use memory we
    /// don't own. TODO: use `Arc<DeviceBuffer>` instead?
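    /// The pointers are copied to device in `generate_proving_ctx` and passed to the persistent
    /// boundary tracegen kernel.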
    pub initial_leaves: Vec<*const std::ffi::c_void>,
    pub touched_blocks: Option<DeviceBuffer<u32>>,
}

pub struct VolatileBoundary {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub as_max_bits: usize,
    pub ptr_max_bits: usize,
    pub records: Option<Vec<u32>>,
}

pub enum BoundaryFields {
    Persistent(PersistentBoundary),
    Volatile(VolatileBoundary),
}

pub struct BoundaryChipGPU {
    pub fields: BoundaryFields,
    pub num_records: Option<usize>,
    pub trace_width: Option<usize>,
}

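// Typical lifecycle (a sketch based on the constructors and methods below): construct the chip
// via `persistent` or `volatile`, finalize the records once the final memory state is known,
// then call `generate_proving_ctx` to produce the trace. For the volatile case this looks
// roughly like:
//
//     let mut chip = BoundaryChipGPU::volatile(range_checker, as_max_bits, ptr_max_bits);
//     chip.finalize_records_volatile(final_memory);
//     let ctx: AirProvingContext<GpuBackend> = chip.generate_proving_ctx(());
//
// where `range_checker`, the bit widths, and `final_memory` are supplied by the caller (see
// `test_cuda_volatile_boundary_tracegen` below for a concrete instance).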
impl BoundaryChipGPU {
    pub fn persistent(poseidon2_buffer: SharedBuffer<F>) -> Self {
        Self {
            fields: BoundaryFields::Persistent(PersistentBoundary {
                poseidon2_buffer,
                initial_leaves: Vec::new(),
                touched_blocks: None,
            }),
            num_records: None,
            trace_width: None,
        }
    }

    pub fn volatile(
        range_checker: Arc<VariableRangeCheckerChipGPU>,
        as_max_bits: usize,
        ptr_max_bits: usize,
    ) -> Self {
        Self {
            fields: BoundaryFields::Volatile(VolatileBoundary {
                range_checker,
                as_max_bits,
                ptr_max_bits,
                records: None,
            }),
            num_records: None,
            trace_width: None,
        }
    }

    // Records in the buffer are a series of u32s. A single record consists
    // of [as, ptr, timestamp, values[0], ..., values[CHUNK - 1]].
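    // For example (illustrative values, with CHUNK = 4), the entry
    //   ((addr_space = 1, ptr = 16), TimestampedValues { timestamp: 7, values: [v0, v1, v2, v3] })
    // flattens to the 3 + CHUNK u32s [1, 16, 7, v0, v1, v2, v3].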
    pub fn finalize_records_volatile<const CHUNK: usize>(
        &mut self,
        final_memory: TimestampedEquipartition<F, CHUNK>,
    ) {
        match &mut self.fields {
            BoundaryFields::Persistent(_) => panic!("call `finalize_records_persistent`"),
            BoundaryFields::Volatile(fields) => {
                self.num_records = Some(final_memory.len());
                self.trace_width = Some(VolatileBoundaryCols::<F>::width());
                let records: Vec<_> = final_memory
                    .par_iter()
                    .flat_map(|&((addr_space, ptr), ts_values)| {
                        let TimestampedValues { timestamp, values } = ts_values;
                        let mut record = vec![addr_space, ptr, timestamp];
                        record.extend_from_slice(&values.map(|x| x.as_canonical_u32()));
                        record
                    })
                    .collect();
                fields.records = Some(records);
            }
        }
    }

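    /// `touched_blocks` is a device buffer in which each touched block occupies
    /// `TIMESTAMPED_BLOCK_WIDTH` u32s, so the number of records is the buffer length divided by
    /// that width.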
    pub fn finalize_records_persistent<const CHUNK: usize>(
        &mut self,
        touched_blocks: DeviceBuffer<u32>,
    ) {
        match &mut self.fields {
            BoundaryFields::Volatile(_) => panic!("call `finalize_records_volatile`"),
            BoundaryFields::Persistent(fields) => {
                self.num_records = Some(touched_blocks.len() / TIMESTAMPED_BLOCK_WIDTH);
                self.trace_width = Some(PersistentBoundaryCols::<F, CHUNK>::width());
                fields.touched_blocks = Some(touched_blocks);
            }
        }
    }

    pub fn trace_width(&self) -> usize {
        self.trace_width.expect("Finalize records to get width")
    }
}

impl<RA> Chip<RA, GpuBackend> for BoundaryChipGPU {
    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
        let num_records = self.num_records.unwrap();
        if num_records == 0 {
            return get_empty_air_proving_ctx();
        }
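        // Persistent boundaries emit two rows per record (initial and final memory state);
        // volatile boundaries emit one row per record.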
        let unpadded_height = match &self.fields {
            BoundaryFields::Persistent(_) => 2 * num_records,
            BoundaryFields::Volatile(_) => num_records,
        };
        let trace_height = next_power_of_two_or_zero(unpadded_height);
        let trace = DeviceMatrix::<F>::with_capacity(trace_height, self.trace_width());
        match &self.fields {
            BoundaryFields::Persistent(boundary) => {
                let mem_ptrs = boundary.initial_leaves.to_device().unwrap();
                unsafe {
                    persistent_boundary_tracegen(
                        trace.buffer(),
                        trace.height(),
                        trace.width(),
                        &mem_ptrs,
                        boundary.touched_blocks.as_ref().unwrap(),
                        num_records,
                        &boundary.poseidon2_buffer.buffer,
                        &boundary.poseidon2_buffer.idx,
                    )
                    .expect("Failed to generate persistent boundary trace");
                }
            }
            BoundaryFields::Volatile(boundary) => unsafe {
                let records = boundary
                    .records
                    .as_ref()
                    .expect("Records must be finalized before generating trace");
                let records = records.to_device().unwrap();
                volatile_boundary_tracegen(
                    trace.buffer(),
                    trace.height(),
                    trace.width(),
                    &records,
                    num_records,
                    &boundary.range_checker.count,
                    boundary.as_max_bits,
                    boundary.ptr_max_bits,
                )
                .expect("Failed to generate volatile boundary trace");
            },
        }
        AirProvingContext::simple_no_pis(trace)
    }
}

#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use openvm_circuit::{
        arch::{testing::MEMORY_BUS, MemoryConfig, ADDR_SPACE_OFFSET},
        system::memory::{
            offline_checker::MemoryBus, volatile::VolatileBoundaryChip, TimestampedEquipartition,
            TimestampedValues,
        },
    };
    use openvm_circuit_primitives::var_range::VariableRangeCheckerChip;
    use openvm_cuda_backend::{
        data_transporter::assert_eq_host_and_device_matrix,
        prelude::{F, SC},
        prover_backend::GpuBackend,
    };
    use openvm_stark_backend::{
        p3_util::log2_ceil_usize,
        prover::{cpu::CpuBackend, types::AirProvingContext},
        Chip,
    };
    use openvm_stark_sdk::utils::create_seeded_rng;
    use p3_field::FieldAlgebra;
    use rand::Rng;

    use super::{BoundaryChipGPU, VariableRangeCheckerChipGPU};
    use crate::arch::testing::default_var_range_checker_bus;

    const MAX_ADDRESS_SPACE: u32 = 4;
    const LIMB_BITS: usize = 15;

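    // Cross-checks GPU volatile boundary tracegen against the CPU `VolatileBoundaryChip` on the
    // same randomly generated final memory.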
    #[test]
    fn test_cuda_volatile_boundary_tracegen() {
        const NUM_ADDRESSES: usize = 10;
        let mut rng = create_seeded_rng();

        let mut distinct_addresses = HashSet::new();
        while distinct_addresses.len() < NUM_ADDRESSES {
            let addr_space = rng.gen_range(0..MAX_ADDRESS_SPACE);
            let pointer = rng.gen_range(0..(1 << LIMB_BITS));
            distinct_addresses.insert((addr_space, pointer));
        }

        let mut final_memory = TimestampedEquipartition::<F, 1>::new();
        for (addr_space, pointer) in distinct_addresses.iter().cloned() {
            let final_data = F::from_canonical_u32(rng.gen_range(0..(1 << LIMB_BITS)));
            let final_clk = rng.gen_range(1..(1 << LIMB_BITS)) as u32;

            final_memory.push((
                (addr_space, pointer),
                TimestampedValues {
                    values: [final_data],
                    timestamp: final_clk,
                },
            ));
        }
        final_memory.sort_by_key(|(k, _)| *k);

        let mem_config = MemoryConfig::default();
        let addr_space_max_bits = log2_ceil_usize(
            (ADDR_SPACE_OFFSET + 2u32.pow(mem_config.addr_space_height as u32)) as usize,
        );
        let cpu_rc = Arc::new(VariableRangeCheckerChip::new(
            default_var_range_checker_bus(),
        ));

        let mut gpu_boundary = BoundaryChipGPU::volatile(
            Arc::new(VariableRangeCheckerChipGPU::hybrid(cpu_rc.clone())),
            addr_space_max_bits,
            mem_config.pointer_max_bits,
        );
        let mut cpu_boundary: VolatileBoundaryChip<F> = VolatileBoundaryChip::new(
            MemoryBus::new(MEMORY_BUS),
            addr_space_max_bits,
            mem_config.pointer_max_bits,
            cpu_rc,
        );
        gpu_boundary.finalize_records_volatile(final_memory.clone());
        cpu_boundary.finalize(final_memory);
        let gpu_ctx: AirProvingContext<GpuBackend> = gpu_boundary.generate_proving_ctx(());
        let cpu_ctx: AirProvingContext<CpuBackend<SC>> = cpu_boundary.generate_proving_ctx(());
        assert_eq_host_and_device_matrix(
            cpu_ctx.common_main.unwrap(),
            &gpu_ctx.common_main.unwrap(),
        );
    }
}