// openvm_circuit_primitives/range_tuple/cuda.rs

1use std::sync::{atomic::Ordering, Arc};
2
3use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
4use openvm_cuda_common::{copy::MemCopyH2D as _, d_buffer::DeviceBuffer};
5use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
6
7use crate::{
8    cuda_abi::range_tuple::tracegen,
9    range_tuple::{RangeTupleCheckerChip, NUM_RANGE_TUPLE_COLS},
10};
11
/// GPU counterpart of [`RangeTupleCheckerChip`]: keeps a device-side buffer of
/// lookup counts and turns it into a proving trace on demand.
pub struct RangeTupleCheckerChipGPU<const N: usize> {
    // Device-side histogram of range-tuple lookup counts; its length is the
    // product of `sizes` (one counter per possible tuple).
    pub count: Arc<DeviceBuffer<F>>,
    // Present only in hybrid mode: a CPU chip whose counters are drained and
    // merged into the GPU counts during trace generation.
    pub cpu_chip: Option<Arc<RangeTupleCheckerChip<N>>>,
    // Per-dimension size (exclusive upper bound) of each tuple component.
    pub sizes: [u32; N],
}
17
18impl<const N: usize> RangeTupleCheckerChipGPU<N> {
19    pub fn new(sizes: [u32; N]) -> Self {
20        let range_max = sizes.iter().product::<u32>() as usize;
21        let count = Arc::new(DeviceBuffer::<F>::with_capacity(range_max));
22        count.fill_zero().unwrap();
23        Self {
24            count,
25            cpu_chip: None,
26            sizes,
27        }
28    }
29
30    pub fn hybrid(cpu_chip: Arc<RangeTupleCheckerChip<N>>) -> Self {
31        let count = Arc::new(DeviceBuffer::<F>::with_capacity(cpu_chip.count.len()));
32        count.fill_zero().unwrap();
33        let sizes = *cpu_chip.sizes();
34        Self {
35            count,
36            cpu_chip: Some(cpu_chip),
37            sizes,
38        }
39    }
40}
41
42impl<RA, const N: usize> Chip<RA, GpuBackend> for RangeTupleCheckerChipGPU<N> {
43    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
44        let cpu_count = self.cpu_chip.as_ref().map(|cpu_chip| {
45            cpu_chip
46                .count
47                .iter()
48                .map(|c| c.swap(0, Ordering::Relaxed))
49                .collect::<Vec<_>>()
50                .to_device()
51                .unwrap()
52        });
53        // ATTENTION: we create a new buffer to copy `count` into because this chip is stateful and
54        // `count` will be reused.
55        let trace = DeviceMatrix::<F>::with_capacity(self.count.len(), NUM_RANGE_TUPLE_COLS);
56        unsafe {
57            tracegen(&self.count, &cpu_count, trace.buffer()).unwrap();
58        }
59        // Zero the internal count buffer because this chip is stateful and may be used again.
60        self.count.fill_zero().unwrap();
61        AirProvingContext::simple_no_pis(trace)
62    }
63}