// openvm_circuit_primitives/range_tuple/cuda.rs

use std::sync::{atomic::Ordering, Arc};

use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
use openvm_cuda_common::{copy::MemCopyH2D as _, d_buffer::DeviceBuffer};
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};

use crate::{
    cuda_abi::range_tuple::tracegen,
    range_tuple::{RangeTupleCheckerChip, NUM_RANGE_TUPLE_COLS},
};

/// GPU implementation of the range-tuple checker. Tuple multiplicities are
/// accumulated in a device-side histogram with one entry per tuple in the
/// Cartesian product of `sizes`.
pub struct RangeTupleCheckerChipGPU<const N: usize> {
    /// Device histogram: `count[i]` is the multiplicity of the `i`-th tuple.
    pub count: Arc<DeviceBuffer<F>>,
    /// In hybrid mode, the CPU chip whose host-side counts are merged into
    /// the GPU trace at trace generation time.
    pub cpu_chip: Option<Arc<RangeTupleCheckerChip<N>>>,
    /// Maximum value (exclusive) for each coordinate of the tuple.
    pub sizes: [u32; N],
}
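
// Indexing note (an assumption based on the CPU chip's layout, not defined in
// this file): entry `i` of `count` corresponds to the tuple whose mixed-radix
// decomposition with respect to `sizes` equals `i`, so the histogram has
// `sizes.iter().product()` entries in total.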

impl<const N: usize> RangeTupleCheckerChipGPU<N> {
    /// Creates a standalone GPU chip with a zero-initialized device histogram
    /// holding `sizes.iter().product()` entries.
    pub fn new(sizes: [u32; N]) -> Self {
        let range_max = sizes.iter().product::<u32>() as usize;
        let count = Arc::new(DeviceBuffer::<F>::with_capacity(range_max));
        count.fill_zero().unwrap();
        Self {
            count,
            cpu_chip: None,
            sizes,
        }
    }

    /// Creates a hybrid chip backed by an existing CPU chip. Counts recorded
    /// on the host are drained and added to the GPU histogram when the trace
    /// is generated.
    pub fn hybrid(cpu_chip: Arc<RangeTupleCheckerChip<N>>) -> Self {
        let count = Arc::new(DeviceBuffer::<F>::with_capacity(cpu_chip.count.len()));
        count.fill_zero().unwrap();
        let sizes = *cpu_chip.sizes();
        Self {
            count,
            cpu_chip: Some(cpu_chip),
            sizes,
        }
    }
}
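
// A minimal construction sketch, assuming the CPU-side `range_tuple` module's
// `RangeTupleCheckerBus` constructor; the bus index and sizes below are
// arbitrary placeholders, not values this crate prescribes:
//
//     // Pure-GPU chip checking pairs (a, b) with a < 2^8 and b < 2^5.
//     let gpu_only = RangeTupleCheckerChipGPU::new([1 << 8, 1 << 5]);
//
//     // Hybrid chip: host code records tuples on `cpu_chip`, and those
//     // counts are folded into the GPU trace at proving time.
//     let bus = RangeTupleCheckerBus::new(0, [1 << 8, 1 << 5]);
//     let cpu_chip = Arc::new(RangeTupleCheckerChip::new(bus));
//     let hybrid = RangeTupleCheckerChipGPU::hybrid(cpu_chip);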

impl<RA, const N: usize> Chip<RA, GpuBackend> for RangeTupleCheckerChipGPU<N> {
    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
        // In hybrid mode, atomically drain the CPU chip's histogram (swapping
        // each counter to zero) and copy the snapshot to device memory.
        let cpu_count = self.cpu_chip.as_ref().map(|cpu_chip| {
            cpu_chip
                .count
                .iter()
                .map(|c| c.swap(0, Ordering::Relaxed))
                .collect::<Vec<_>>()
                .to_device()
                .unwrap()
        });
        let trace = DeviceMatrix::<F>::with_capacity(self.count.len(), NUM_RANGE_TUPLE_COLS);
        unsafe {
            tracegen(&self.count, &cpu_count, trace.buffer()).unwrap();
        }
        // Reset the device histogram so the chip can be reused for the next proof.
        self.count.fill_zero().unwrap();
        AirProvingContext::simple_no_pis(trace)
    }
}
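
// Usage sketch (hedged): `RA` is the record-arena type of the surrounding
// prover and is ignored by this chip, so any value can be passed:
//
//     let ctx: AirProvingContext<GpuBackend> = chip.generate_proving_ctx(records);
//
// Both the CPU and GPU histograms are zeroed as a side effect, so each call
// yields the multiplicities accumulated since the previous one.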