openvm_circuit/system/cuda/
poseidon2.rs1#[cfg(feature = "metrics")]
2use std::sync::atomic::AtomicUsize;
3use std::sync::Arc;
4
5use openvm_circuit::{
6 system::poseidon2::columns::Poseidon2PeripheryCols, utils::next_power_of_two_or_zero,
7};
8use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
9use openvm_cuda_common::{
10 copy::{MemCopyD2H, MemCopyH2D},
11 d_buffer::DeviceBuffer,
12};
13use openvm_stark_backend::{
14 prover::{hal::MatrixDimensions, types::AirProvingContext},
15 Chip,
16};
17
18use crate::cuda_abi::poseidon2;
19
20#[derive(Clone)]
21pub struct SharedBuffer<T> {
22 pub buffer: Arc<DeviceBuffer<T>>,
23 pub idx: Arc<DeviceBuffer<u32>>,
24 #[cfg(feature = "metrics")]
25 pub(crate) current_trace_height: Arc<AtomicUsize>,
26}
27
28pub struct Poseidon2ChipGPU<const SBOX_REGISTERS: usize> {
29 pub records: Arc<DeviceBuffer<F>>,
30 pub idx: Arc<DeviceBuffer<u32>>,
31 #[cfg(feature = "metrics")]
32 pub(crate) current_trace_height: Arc<AtomicUsize>,
33}
34
35impl<const SBOX_REGISTERS: usize> Poseidon2ChipGPU<SBOX_REGISTERS> {
36 pub fn new(max_buffer_size: usize) -> Self {
37 let idx = Arc::new(DeviceBuffer::<u32>::with_capacity(1));
38 idx.fill_zero().unwrap();
39 Self {
40 records: Arc::new(DeviceBuffer::<F>::with_capacity(max_buffer_size)),
41 idx,
42 #[cfg(feature = "metrics")]
43 current_trace_height: Arc::new(AtomicUsize::new(0)),
44 }
45 }
46
47 pub fn shared_buffer(&self) -> SharedBuffer<F> {
48 SharedBuffer {
49 buffer: self.records.clone(),
50 idx: self.idx.clone(),
51 #[cfg(feature = "metrics")]
52 current_trace_height: self.current_trace_height.clone(),
53 }
54 }
55
56 pub fn trace_width() -> usize {
57 Poseidon2PeripheryCols::<F, SBOX_REGISTERS>::width()
58 }
59}
60
61impl<RA, const SBOX_REGISTERS: usize> Chip<RA, GpuBackend> for Poseidon2ChipGPU<SBOX_REGISTERS> {
62 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
63 let mut num_records = self.idx.to_host().unwrap()[0] as usize;
64 let counts = DeviceBuffer::<u32>::with_capacity(num_records);
65 unsafe {
66 let d_num_records = [num_records].to_device().unwrap();
67 let mut temp_bytes = 0;
68 poseidon2::deduplicate_records_get_temp_bytes(
69 &self.records,
70 &counts,
71 num_records,
72 &d_num_records,
73 &mut temp_bytes,
74 )
75 .expect("Failed to get temp bytes");
76 let d_temp_storage = DeviceBuffer::<u8>::with_capacity(temp_bytes);
77 poseidon2::deduplicate_records(
78 &self.records,
79 &counts,
80 num_records,
81 &d_num_records,
82 &d_temp_storage,
83 temp_bytes,
84 )
85 .expect("Failed to deduplicate records");
86 num_records = *d_num_records.to_host().unwrap().first().unwrap();
87 }
88 #[cfg(feature = "metrics")]
89 self.current_trace_height
90 .store(num_records, std::sync::atomic::Ordering::Relaxed);
91 let trace_height = next_power_of_two_or_zero(num_records);
92 let trace = DeviceMatrix::<F>::with_capacity(trace_height, Self::trace_width());
93 unsafe {
94 poseidon2::tracegen(
95 trace.buffer(),
96 trace.height(),
97 trace.width(),
98 &self.records,
99 &counts,
100 num_records,
101 SBOX_REGISTERS,
102 )
103 .expect("Failed to generate trace");
104 }
105 self.idx.fill_zero().unwrap();
107 AirProvingContext::simple_no_pis(trace)
108 }
109}
110
111pub enum Poseidon2PeripheryChipGPU {
112 Register0(Poseidon2ChipGPU<0>),
113 Register1(Poseidon2ChipGPU<1>),
114}
115
116impl Poseidon2PeripheryChipGPU {
117 pub fn new(max_buffer_size: usize, sbox_registers: usize) -> Self {
118 match sbox_registers {
119 0 => Self::Register0(Poseidon2ChipGPU::new(max_buffer_size)),
120 1 => Self::Register1(Poseidon2ChipGPU::new(max_buffer_size)),
121 _ => panic!("Invalid number of sbox registers: {}", sbox_registers),
122 }
123 }
124
125 pub fn shared_buffer(&self) -> SharedBuffer<F> {
126 match self {
127 Self::Register0(chip) => chip.shared_buffer(),
128 Self::Register1(chip) => chip.shared_buffer(),
129 }
130 }
131}
132
133impl<RA> Chip<RA, GpuBackend> for Poseidon2PeripheryChipGPU {
134 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
135 match self {
136 Self::Register0(chip) => chip.generate_proving_ctx(()),
137 Self::Register1(chip) => chip.generate_proving_ctx(()),
138 }
139 }
140}