openvm_circuit/system/cuda/
poseidon2.rs

1#[cfg(feature = "metrics")]
2use std::sync::atomic::AtomicUsize;
3use std::sync::Arc;
4
5use openvm_circuit::{
6    system::poseidon2::columns::Poseidon2PeripheryCols, utils::next_power_of_two_or_zero,
7};
8use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
9use openvm_cuda_common::{
10    copy::{MemCopyD2H, MemCopyH2D},
11    d_buffer::DeviceBuffer,
12};
13use openvm_stark_backend::{
14    prover::{hal::MatrixDimensions, types::AirProvingContext},
15    Chip,
16};
17
18use crate::cuda_abi::poseidon2;
19
20#[derive(Clone)]
21pub struct SharedBuffer<T> {
22    pub buffer: Arc<DeviceBuffer<T>>,
23    pub idx: Arc<DeviceBuffer<u32>>,
24    #[cfg(feature = "metrics")]
25    pub(crate) current_trace_height: Arc<AtomicUsize>,
26}
27
28pub struct Poseidon2ChipGPU<const SBOX_REGISTERS: usize> {
29    pub records: Arc<DeviceBuffer<F>>,
30    pub idx: Arc<DeviceBuffer<u32>>,
31    #[cfg(feature = "metrics")]
32    pub(crate) current_trace_height: Arc<AtomicUsize>,
33}
34
35impl<const SBOX_REGISTERS: usize> Poseidon2ChipGPU<SBOX_REGISTERS> {
36    pub fn new(max_buffer_size: usize) -> Self {
37        let idx = Arc::new(DeviceBuffer::<u32>::with_capacity(1));
38        idx.fill_zero().unwrap();
39        Self {
40            records: Arc::new(DeviceBuffer::<F>::with_capacity(max_buffer_size)),
41            idx,
42            #[cfg(feature = "metrics")]
43            current_trace_height: Arc::new(AtomicUsize::new(0)),
44        }
45    }
46
47    pub fn shared_buffer(&self) -> SharedBuffer<F> {
48        SharedBuffer {
49            buffer: self.records.clone(),
50            idx: self.idx.clone(),
51            #[cfg(feature = "metrics")]
52            current_trace_height: self.current_trace_height.clone(),
53        }
54    }
55
56    pub fn trace_width() -> usize {
57        Poseidon2PeripheryCols::<F, SBOX_REGISTERS>::width()
58    }
59}
60
61impl<RA, const SBOX_REGISTERS: usize> Chip<RA, GpuBackend> for Poseidon2ChipGPU<SBOX_REGISTERS> {
62    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
63        let mut num_records = self.idx.to_host().unwrap()[0] as usize;
64        let counts = DeviceBuffer::<u32>::with_capacity(num_records);
65        unsafe {
66            let d_num_records = [num_records].to_device().unwrap();
67            let mut temp_bytes = 0;
68            poseidon2::deduplicate_records_get_temp_bytes(
69                &self.records,
70                &counts,
71                num_records,
72                &d_num_records,
73                &mut temp_bytes,
74            )
75            .expect("Failed to get temp bytes");
76            let d_temp_storage = DeviceBuffer::<u8>::with_capacity(temp_bytes);
77            poseidon2::deduplicate_records(
78                &self.records,
79                &counts,
80                num_records,
81                &d_num_records,
82                &d_temp_storage,
83                temp_bytes,
84            )
85            .expect("Failed to deduplicate records");
86            num_records = *d_num_records.to_host().unwrap().first().unwrap();
87        }
88        #[cfg(feature = "metrics")]
89        self.current_trace_height
90            .store(num_records, std::sync::atomic::Ordering::Relaxed);
91        let trace_height = next_power_of_two_or_zero(num_records);
92        let trace = DeviceMatrix::<F>::with_capacity(trace_height, Self::trace_width());
93        unsafe {
94            poseidon2::tracegen(
95                trace.buffer(),
96                trace.height(),
97                trace.width(),
98                &self.records,
99                &counts,
100                num_records,
101                SBOX_REGISTERS,
102            )
103            .expect("Failed to generate trace");
104        }
105        // Reset state of this chip.
106        self.idx.fill_zero().unwrap();
107        AirProvingContext::simple_no_pis(trace)
108    }
109}
110
111pub enum Poseidon2PeripheryChipGPU {
112    Register0(Poseidon2ChipGPU<0>),
113    Register1(Poseidon2ChipGPU<1>),
114}
115
116impl Poseidon2PeripheryChipGPU {
117    pub fn new(max_buffer_size: usize, sbox_registers: usize) -> Self {
118        match sbox_registers {
119            0 => Self::Register0(Poseidon2ChipGPU::new(max_buffer_size)),
120            1 => Self::Register1(Poseidon2ChipGPU::new(max_buffer_size)),
121            _ => panic!("Invalid number of sbox registers: {}", sbox_registers),
122        }
123    }
124
125    pub fn shared_buffer(&self) -> SharedBuffer<F> {
126        match self {
127            Self::Register0(chip) => chip.shared_buffer(),
128            Self::Register1(chip) => chip.shared_buffer(),
129        }
130    }
131}
132
133impl<RA> Chip<RA, GpuBackend> for Poseidon2PeripheryChipGPU {
134    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
135        match self {
136            Self::Register0(chip) => chip.generate_proving_ctx(()),
137            Self::Register1(chip) => chip.generate_proving_ctx(()),
138        }
139    }
140}