// openvm_circuit_primitives/var_range/cuda.rs

1use std::sync::{atomic::Ordering, Arc};
2
3use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
4use openvm_cuda_common::{copy::MemCopyH2D as _, d_buffer::DeviceBuffer};
5use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
6
7use crate::{
8    cuda_abi::var_range::tracegen,
9    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip, NUM_VARIABLE_RANGE_COLS},
10};
11
12pub struct VariableRangeCheckerChipGPU {
13    pub count: Arc<DeviceBuffer<F>>,
14    pub cpu_chip: Option<Arc<VariableRangeCheckerChip>>,
15}
16
17/// [value, bits] are in preprocessed trace
18/// generate_trace returns [count]
19impl VariableRangeCheckerChipGPU {
20    pub fn new(bus: VariableRangeCheckerBus) -> Self {
21        let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
22        let count = Arc::new(DeviceBuffer::<F>::with_capacity(num_rows));
23        count.fill_zero().unwrap();
24        Self {
25            count,
26            cpu_chip: None,
27        }
28    }
29
30    pub fn hybrid(cpu_chip: Arc<VariableRangeCheckerChip>) -> Self {
31        let count = Arc::new(DeviceBuffer::<F>::with_capacity(cpu_chip.count.len()));
32        count.fill_zero().unwrap();
33        Self {
34            count,
35            cpu_chip: Some(cpu_chip),
36        }
37    }
38}
39
40impl<RA> Chip<RA, GpuBackend> for VariableRangeCheckerChipGPU {
41    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
42        assert_eq!(size_of::<F>(), size_of::<u32>());
43        let cpu_count = self.cpu_chip.as_ref().map(|cpu_chip| {
44            cpu_chip
45                .count
46                .iter()
47                .map(|c| c.swap(0, Ordering::Relaxed))
48                .collect::<Vec<_>>()
49                .to_device()
50                .unwrap()
51        });
52        // ATTENTION: we create a new buffer to copy `count` into because this chip is stateful and
53        // `count` will be reused.
54        let trace = DeviceMatrix::<F>::with_capacity(self.count.len(), NUM_VARIABLE_RANGE_COLS);
55        unsafe {
56            tracegen(&self.count, &cpu_count, trace.buffer()).unwrap();
57        }
58        // Zero the internal count buffer because this chip is stateful and may be used again.
59        self.count.fill_zero().unwrap();
60        AirProvingContext::simple_no_pis(trace)
61    }
62}