openvm_circuit/system/cuda/
poseidon2.rs1#[cfg(feature = "metrics")]
2use std::sync::atomic::AtomicUsize;
3use std::sync::Arc;
4
5use openvm_circuit::{
6 system::poseidon2::columns::Poseidon2PeripheryCols, utils::next_power_of_two_or_zero,
7};
8use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
9use openvm_cuda_common::{copy::MemCopyD2H, d_buffer::DeviceBuffer};
10use openvm_stark_backend::{
11 prover::{hal::MatrixDimensions, types::AirProvingContext},
12 Chip,
13};
14
15use crate::cuda_abi::poseidon2;
16
17#[derive(Clone)]
18pub struct SharedBuffer<T> {
19 pub buffer: Arc<DeviceBuffer<T>>,
20 pub idx: Arc<DeviceBuffer<u32>>,
21 #[cfg(feature = "metrics")]
22 pub(crate) current_trace_height: Arc<AtomicUsize>,
23}
24
25pub struct Poseidon2ChipGPU<const SBOX_REGISTERS: usize> {
26 pub records: Arc<DeviceBuffer<F>>,
27 pub idx: Arc<DeviceBuffer<u32>>,
28 #[cfg(feature = "metrics")]
29 pub(crate) current_trace_height: Arc<AtomicUsize>,
30}
31
32impl<const SBOX_REGISTERS: usize> Poseidon2ChipGPU<SBOX_REGISTERS> {
33 pub fn new(max_buffer_size: usize) -> Self {
34 let idx = Arc::new(DeviceBuffer::<u32>::with_capacity(1));
35 idx.fill_zero().unwrap();
36 Self {
37 records: Arc::new(DeviceBuffer::<F>::with_capacity(max_buffer_size)),
38 idx,
39 #[cfg(feature = "metrics")]
40 current_trace_height: Arc::new(AtomicUsize::new(0)),
41 }
42 }
43
44 pub fn shared_buffer(&self) -> SharedBuffer<F> {
45 SharedBuffer {
46 buffer: self.records.clone(),
47 idx: self.idx.clone(),
48 #[cfg(feature = "metrics")]
49 current_trace_height: self.current_trace_height.clone(),
50 }
51 }
52
53 pub fn trace_width() -> usize {
54 Poseidon2PeripheryCols::<F, SBOX_REGISTERS>::width()
55 }
56}
57
58impl<RA, const SBOX_REGISTERS: usize> Chip<RA, GpuBackend> for Poseidon2ChipGPU<SBOX_REGISTERS> {
59 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
60 let mut num_records = self.idx.to_host().unwrap()[0] as usize;
61 let counts = DeviceBuffer::<u32>::with_capacity(num_records);
62 unsafe {
63 poseidon2::deduplicate_records(&self.records, &counts, &mut num_records)
64 .expect("Failed to deduplicate records");
65 }
66 #[cfg(feature = "metrics")]
67 self.current_trace_height
68 .store(num_records, std::sync::atomic::Ordering::Relaxed);
69 let trace_height = next_power_of_two_or_zero(num_records);
70 let trace = DeviceMatrix::<F>::with_capacity(trace_height, Self::trace_width());
71 unsafe {
72 poseidon2::tracegen(
73 trace.buffer(),
74 trace.height(),
75 trace.width(),
76 &self.records,
77 &counts,
78 num_records,
79 SBOX_REGISTERS,
80 )
81 .expect("Failed to generate trace");
82 }
83 self.idx.fill_zero().unwrap();
85 AirProvingContext::simple_no_pis(trace)
86 }
87}
88
89pub enum Poseidon2PeripheryChipGPU {
90 Register0(Poseidon2ChipGPU<0>),
91 Register1(Poseidon2ChipGPU<1>),
92}
93
94impl Poseidon2PeripheryChipGPU {
95 pub fn new(max_buffer_size: usize, sbox_registers: usize) -> Self {
96 match sbox_registers {
97 0 => Self::Register0(Poseidon2ChipGPU::new(max_buffer_size)),
98 1 => Self::Register1(Poseidon2ChipGPU::new(max_buffer_size)),
99 _ => panic!("Invalid number of sbox registers: {}", sbox_registers),
100 }
101 }
102
103 pub fn shared_buffer(&self) -> SharedBuffer<F> {
104 match self {
105 Self::Register0(chip) => chip.shared_buffer(),
106 Self::Register1(chip) => chip.shared_buffer(),
107 }
108 }
109}
110
111impl<RA> Chip<RA, GpuBackend> for Poseidon2PeripheryChipGPU {
112 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
113 match self {
114 Self::Register0(chip) => chip.generate_proving_ctx(()),
115 Self::Register1(chip) => chip.generate_proving_ctx(()),
116 }
117 }
118}