openvm_circuit_primitives/var_range/
cuda.rs1use std::sync::{atomic::Ordering, Arc};
2
3use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
4use openvm_cuda_common::{copy::MemCopyH2D as _, d_buffer::DeviceBuffer};
5use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
6
7use crate::{
8 cuda_abi::var_range::tracegen,
9 var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip, NUM_VARIABLE_RANGE_COLS},
10};
11
12pub struct VariableRangeCheckerChipGPU {
13 pub count: Arc<DeviceBuffer<F>>,
14 pub cpu_chip: Option<Arc<VariableRangeCheckerChip>>,
15}
16
17impl VariableRangeCheckerChipGPU {
20 pub fn new(bus: VariableRangeCheckerBus) -> Self {
21 let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
22 let count = Arc::new(DeviceBuffer::<F>::with_capacity(num_rows));
23 count.fill_zero().unwrap();
24 Self {
25 count,
26 cpu_chip: None,
27 }
28 }
29
30 pub fn hybrid(cpu_chip: Arc<VariableRangeCheckerChip>) -> Self {
31 let count = Arc::new(DeviceBuffer::<F>::with_capacity(cpu_chip.count.len()));
32 count.fill_zero().unwrap();
33 Self {
34 count,
35 cpu_chip: Some(cpu_chip),
36 }
37 }
38}
39
40impl<RA> Chip<RA, GpuBackend> for VariableRangeCheckerChipGPU {
41 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
42 assert_eq!(size_of::<F>(), size_of::<u32>());
43 let cpu_count = self.cpu_chip.as_ref().map(|cpu_chip| {
44 cpu_chip
45 .count
46 .iter()
47 .map(|c| c.swap(0, Ordering::Relaxed))
48 .collect::<Vec<_>>()
49 .to_device()
50 .unwrap()
51 });
52 let trace = DeviceMatrix::<F>::with_capacity(self.count.len(), NUM_VARIABLE_RANGE_COLS);
55 unsafe {
56 tracegen(&self.count, &cpu_count, trace.buffer()).unwrap();
57 }
58 self.count.fill_zero().unwrap();
60 AirProvingContext::simple_no_pis(trace)
61 }
62}