openvm_cuda_backend/
base.rs1use std::{fmt::Debug, marker::PhantomData, sync::Arc};
2
3use openvm_cuda_common::{copy::MemCopyD2H, d_buffer::DeviceBuffer, error::MemCopyError};
4use openvm_stark_backend::prover::hal::MatrixDimensions;
5
6pub struct DeviceMatrix<T> {
7 buffer: Arc<DeviceBuffer<T>>,
8 height: usize,
9 width: usize,
10}
11
12unsafe impl<T> Send for DeviceMatrix<T> {}
13unsafe impl<T> Sync for DeviceMatrix<T> {}
14
15impl<T> Clone for DeviceMatrix<T> {
16 fn clone(&self) -> Self {
17 Self {
18 buffer: Arc::clone(&self.buffer),
19 height: self.height,
20 width: self.width,
21 }
22 }
23}
24
25impl<T> Drop for DeviceMatrix<T> {
26 fn drop(&mut self) {
27 tracing::debug!(
28 "Dropping DeviceMatrix of size {} with Arc strong count={}",
29 self.buffer.len(),
30 self.strong_count()
31 );
32 }
33}
34
35impl<T> DeviceMatrix<T> {
36 pub fn new(buffer: Arc<DeviceBuffer<T>>, height: usize, width: usize) -> Self {
37 assert_ne!(
38 height * width,
39 0,
40 "Zero dimensions h {} w {} are wrong",
41 height,
42 width
43 );
44 assert_eq!(
45 buffer.len(),
46 height * width,
47 "Buffer size must match dimensions"
48 );
49 Self {
50 buffer,
51 height,
52 width,
53 }
54 }
55
56 pub fn with_capacity(height: usize, width: usize) -> Self {
57 Self {
58 buffer: Arc::new(DeviceBuffer::with_capacity(height * width)),
59 height,
60 width,
61 }
62 }
63
64 pub fn dummy() -> Self {
65 Self {
66 buffer: Arc::new(DeviceBuffer::new()),
67 height: 0,
68 width: 0,
69 }
70 }
71
72 pub fn buffer(&self) -> &DeviceBuffer<T> {
73 &self.buffer
74 }
75
76 pub fn strong_count(&self) -> usize {
77 Arc::strong_count(&self.buffer)
78 }
79}
80
81impl<T> MatrixDimensions for DeviceMatrix<T> {
82 #[inline]
83 fn height(&self) -> usize {
84 self.height
85 }
86
87 #[inline]
88 fn width(&self) -> usize {
89 self.width
90 }
91}
92
93impl<T> MemCopyD2H<T> for DeviceMatrix<T> {
94 fn to_host(&self) -> Result<Vec<T>, MemCopyError> {
95 self.buffer.to_host()
96 }
97}
98
99impl<T: Debug> Debug for DeviceMatrix<T> {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(
102 f,
103 "DeviceMatrix (height = {}, width = {}): {:?}",
104 self.height(),
105 self.width(),
106 self.buffer()
107 )
108 }
109}
110
/// Marker trait for the zero-sized tags used as the basis parameter `B` of
/// `DevicePoly<T, B>`.
///
/// It declares no methods; implementors only need to be `Copy`, `Debug`,
/// `Send` and `Sync`.
pub trait Basis: Copy + Debug + Send + Sync {}
114
115#[derive(Clone, Copy, Debug)]
117pub struct Coeff;
118impl Basis for Coeff {}
119
120#[derive(Clone, Copy, Debug)]
122pub struct LagrangeCoeff;
123impl Basis for LagrangeCoeff {}
124
125#[derive(Clone, Copy, Debug)]
128pub struct ExtendedLagrangeCoeff;
129impl Basis for ExtendedLagrangeCoeff {}
130
131pub struct DevicePoly<T, B> {
132 pub is_bit_reversed: bool,
133 pub coeff: DeviceBuffer<T>,
134 _marker: PhantomData<B>,
135}
136
137impl<T, B> DevicePoly<T, B> {
138 pub fn new(is_bit_reversed: bool, coeff: DeviceBuffer<T>) -> Self {
139 Self {
140 is_bit_reversed,
141 coeff,
142 _marker: PhantomData,
143 }
144 }
145
146 #[inline]
147 pub fn len(&self) -> usize {
148 self.coeff.len()
149 }
150
151 #[inline]
152 pub fn is_empty(&self) -> bool {
153 self.len() == 0
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// A matrix built over a 12-element buffer reports the 3 x 4 dimensions
    /// it was constructed with.
    #[test]
    fn test_device_matrix() {
        let (rows, cols) = (3, 4);
        let storage = Arc::new(DeviceBuffer::<i32>::with_capacity(rows * cols));
        let matrix = DeviceMatrix::<i32>::new(storage, rows, cols);
        assert_eq!(matrix.height(), rows);
        assert_eq!(matrix.width(), cols);
    }
}