openvm_circuit/system/cuda/
poseidon2.rs

1#[cfg(feature = "metrics")]
2use std::sync::atomic::AtomicUsize;
3use std::sync::Arc;
4
5use openvm_circuit::{
6    system::poseidon2::columns::Poseidon2PeripheryCols, utils::next_power_of_two_or_zero,
7};
8use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
9use openvm_cuda_common::{copy::MemCopyD2H, d_buffer::DeviceBuffer};
10use openvm_stark_backend::{
11    prover::{hal::MatrixDimensions, types::AirProvingContext},
12    Chip,
13};
14
15use crate::cuda_abi::poseidon2;
16
17#[derive(Clone)]
18pub struct SharedBuffer<T> {
19    pub buffer: Arc<DeviceBuffer<T>>,
20    pub idx: Arc<DeviceBuffer<u32>>,
21    #[cfg(feature = "metrics")]
22    pub(crate) current_trace_height: Arc<AtomicUsize>,
23}
24
25pub struct Poseidon2ChipGPU<const SBOX_REGISTERS: usize> {
26    pub records: Arc<DeviceBuffer<F>>,
27    pub idx: Arc<DeviceBuffer<u32>>,
28    #[cfg(feature = "metrics")]
29    pub(crate) current_trace_height: Arc<AtomicUsize>,
30}
31
32impl<const SBOX_REGISTERS: usize> Poseidon2ChipGPU<SBOX_REGISTERS> {
33    pub fn new(max_buffer_size: usize) -> Self {
34        let idx = Arc::new(DeviceBuffer::<u32>::with_capacity(1));
35        idx.fill_zero().unwrap();
36        Self {
37            records: Arc::new(DeviceBuffer::<F>::with_capacity(max_buffer_size)),
38            idx,
39            #[cfg(feature = "metrics")]
40            current_trace_height: Arc::new(AtomicUsize::new(0)),
41        }
42    }
43
44    pub fn shared_buffer(&self) -> SharedBuffer<F> {
45        SharedBuffer {
46            buffer: self.records.clone(),
47            idx: self.idx.clone(),
48            #[cfg(feature = "metrics")]
49            current_trace_height: self.current_trace_height.clone(),
50        }
51    }
52
53    pub fn trace_width() -> usize {
54        Poseidon2PeripheryCols::<F, SBOX_REGISTERS>::width()
55    }
56}
57
58impl<RA, const SBOX_REGISTERS: usize> Chip<RA, GpuBackend> for Poseidon2ChipGPU<SBOX_REGISTERS> {
59    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
60        let mut num_records = self.idx.to_host().unwrap()[0] as usize;
61        let counts = DeviceBuffer::<u32>::with_capacity(num_records);
62        unsafe {
63            poseidon2::deduplicate_records(&self.records, &counts, &mut num_records)
64                .expect("Failed to deduplicate records");
65        }
66        #[cfg(feature = "metrics")]
67        self.current_trace_height
68            .store(num_records, std::sync::atomic::Ordering::Relaxed);
69        let trace_height = next_power_of_two_or_zero(num_records);
70        let trace = DeviceMatrix::<F>::with_capacity(trace_height, Self::trace_width());
71        unsafe {
72            poseidon2::tracegen(
73                trace.buffer(),
74                trace.height(),
75                trace.width(),
76                &self.records,
77                &counts,
78                num_records,
79                SBOX_REGISTERS,
80            )
81            .expect("Failed to generate trace");
82        }
83        // Reset state of this chip.
84        self.idx.fill_zero().unwrap();
85        AirProvingContext::simple_no_pis(trace)
86    }
87}
88
89pub enum Poseidon2PeripheryChipGPU {
90    Register0(Poseidon2ChipGPU<0>),
91    Register1(Poseidon2ChipGPU<1>),
92}
93
94impl Poseidon2PeripheryChipGPU {
95    pub fn new(max_buffer_size: usize, sbox_registers: usize) -> Self {
96        match sbox_registers {
97            0 => Self::Register0(Poseidon2ChipGPU::new(max_buffer_size)),
98            1 => Self::Register1(Poseidon2ChipGPU::new(max_buffer_size)),
99            _ => panic!("Invalid number of sbox registers: {}", sbox_registers),
100        }
101    }
102
103    pub fn shared_buffer(&self) -> SharedBuffer<F> {
104        match self {
105            Self::Register0(chip) => chip.shared_buffer(),
106            Self::Register1(chip) => chip.shared_buffer(),
107        }
108    }
109}
110
111impl<RA> Chip<RA, GpuBackend> for Poseidon2PeripheryChipGPU {
112    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
113        match self {
114            Self::Register0(chip) => chip.generate_proving_ctx(()),
115            Self::Register1(chip) => chip.generate_proving_ctx(()),
116        }
117    }
118}