openvm_circuit_primitives/bitwise_op_lookup/
cuda.rs1use std::sync::{atomic::Ordering, Arc};
2
3use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F};
4use openvm_cuda_common::{copy::MemCopyH2D as _, d_buffer::DeviceBuffer};
5use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
6
7use crate::{
8 bitwise_op_lookup::{BitwiseOperationLookupChip, NUM_BITWISE_OP_LOOKUP_COLS},
9 cuda_abi::bitwise_op_lookup::tracegen,
10};
11
12pub struct BitwiseOperationLookupChipGPU<const NUM_BITS: usize> {
13 pub count: Arc<DeviceBuffer<F>>,
14 pub cpu_chip: Option<Arc<BitwiseOperationLookupChip<NUM_BITS>>>,
15}
16
17impl<const NUM_BITS: usize> BitwiseOperationLookupChipGPU<NUM_BITS> {
18 pub const fn num_rows() -> usize {
19 1 << (2 * NUM_BITS)
20 }
21
22 pub fn new() -> Self {
23 let count = Arc::new(DeviceBuffer::<F>::with_capacity(
25 NUM_BITWISE_OP_LOOKUP_COLS * Self::num_rows(),
26 ));
27 count.fill_zero().unwrap();
28 Self {
29 count,
30 cpu_chip: None,
31 }
32 }
33
34 pub fn hybrid(cpu_chip: Arc<BitwiseOperationLookupChip<NUM_BITS>>) -> Self {
35 assert_eq!(cpu_chip.count_range.len(), Self::num_rows());
36 assert_eq!(cpu_chip.count_xor.len(), Self::num_rows());
37 let count = Arc::new(DeviceBuffer::<F>::with_capacity(
38 NUM_BITWISE_OP_LOOKUP_COLS * Self::num_rows(),
39 ));
40 count.fill_zero().unwrap();
41 Self {
42 count,
43 cpu_chip: Some(cpu_chip),
44 }
45 }
46}
47
48impl<const NUM_BITS: usize> Default for BitwiseOperationLookupChipGPU<NUM_BITS> {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl<RA, const NUM_BITS: usize> Chip<RA, GpuBackend> for BitwiseOperationLookupChipGPU<NUM_BITS> {
55 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
56 debug_assert_eq!(
57 Self::num_rows() * NUM_BITWISE_OP_LOOKUP_COLS,
58 self.count.len()
59 );
60 let cpu_count = self.cpu_chip.as_ref().map(|cpu_chip| {
61 cpu_chip
62 .count_range
63 .iter()
64 .chain(cpu_chip.count_xor.iter())
65 .map(|c| c.swap(0, Ordering::Relaxed))
66 .collect::<Vec<_>>()
67 .to_device()
68 .unwrap()
69 });
70 let trace = DeviceMatrix::<F>::with_capacity(Self::num_rows(), NUM_BITWISE_OP_LOOKUP_COLS);
73 unsafe {
74 tracegen(&self.count, &cpu_count, trace.buffer(), NUM_BITS as u32).unwrap();
75 }
76 self.count.fill_zero().unwrap();
78 AirProvingContext::simple_no_pis(trace)
79 }
80}