openvm_circuit_primitives/bitwise_op_lookup/cuda.rs

1use std::sync::{atomic::Ordering, Arc};
2
3use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
4use openvm_cuda_common::{copy::MemCopyH2D as _, d_buffer::DeviceBuffer};
5use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
6
7use crate::{
8    bitwise_op_lookup::{BitwiseOperationLookupChip, NUM_BITWISE_OP_LOOKUP_COLS},
9    cuda_abi::bitwise_op_lookup::tracegen,
10};
11
/// GPU counterpart of [`BitwiseOperationLookupChip`], holding the lookup counts in a
/// `DeviceBuffer`.
///
/// The constructors allocate `count` with `NUM_BITWISE_OP_LOOKUP_COLS * 2^(2 * NUM_BITS)`
/// elements and zero it. `generate_proving_ctx` zeroes it again after building the trace, so
/// the same chip instance can be reused across proving rounds.
pub struct BitwiseOperationLookupChipGPU<const NUM_BITS: usize> {
    /// Device-side lookup counts. Per the constructor comment, the first `2^(2 * NUM_BITS)`
    /// entries are for range checking and the rest are for XOR.
    /// NOTE(review): presumably incremented by other chips' GPU kernels — confirm at call sites.
    pub count: Arc<DeviceBuffer<F>>,
    /// Optional CPU chip (attached via `hybrid`). During `generate_proving_ctx`, its
    /// `count_range` and `count_xor` counters are swapped to zero and uploaded to the device
    /// alongside `count`.
    pub cpu_chip: Option<Arc<BitwiseOperationLookupChip<NUM_BITS>>>,
}
16
17impl<const NUM_BITS: usize> BitwiseOperationLookupChipGPU<NUM_BITS> {
18    pub const fn num_rows() -> usize {
19        1 << (2 * NUM_BITS)
20    }
21
22    pub fn new() -> Self {
23        // The first 2^(2 * NUM_BITS) indices are for range checking, the rest are for XOR
24        let count = Arc::new(DeviceBuffer::<F>::with_capacity(
25            NUM_BITWISE_OP_LOOKUP_COLS * Self::num_rows(),
26        ));
27        count.fill_zero().unwrap();
28        Self {
29            count,
30            cpu_chip: None,
31        }
32    }
33
34    pub fn hybrid(cpu_chip: Arc<BitwiseOperationLookupChip<NUM_BITS>>) -> Self {
35        assert_eq!(cpu_chip.count_range.len(), Self::num_rows());
36        assert_eq!(cpu_chip.count_xor.len(), Self::num_rows());
37        let count = Arc::new(DeviceBuffer::<F>::with_capacity(
38            NUM_BITWISE_OP_LOOKUP_COLS * Self::num_rows(),
39        ));
40        count.fill_zero().unwrap();
41        Self {
42            count,
43            cpu_chip: Some(cpu_chip),
44        }
45    }
46}
47
48impl<const NUM_BITS: usize> Default for BitwiseOperationLookupChipGPU<NUM_BITS> {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl<RA, const NUM_BITS: usize> Chip<RA, GpuBackend> for BitwiseOperationLookupChipGPU<NUM_BITS> {
55    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
56        debug_assert_eq!(
57            Self::num_rows() * NUM_BITWISE_OP_LOOKUP_COLS,
58            self.count.len()
59        );
60        let cpu_count = self.cpu_chip.as_ref().map(|cpu_chip| {
61            cpu_chip
62                .count_range
63                .iter()
64                .chain(cpu_chip.count_xor.iter())
65                .map(|c| c.swap(0, Ordering::Relaxed))
66                .collect::<Vec<_>>()
67                .to_device()
68                .unwrap()
69        });
70        // ATTENTION: we create a new buffer to copy `count` into because this chip is stateful and
71        // `count` will be reused.
72        let trace = DeviceMatrix::<F>::with_capacity(Self::num_rows(), NUM_BITWISE_OP_LOOKUP_COLS);
73        unsafe {
74            tracegen(&self.count, &cpu_count, trace.buffer(), NUM_BITS as u32).unwrap();
75        }
76        // Zero the internal count buffer because this chip is stateful and may be used again.
77        self.count.fill_zero().unwrap();
78        AirProvingContext::simple_no_pis(trace)
79    }
80}