openvm_cuda_backend/
base.rs

1use std::{fmt::Debug, marker::PhantomData, sync::Arc};
2
3use openvm_cuda_common::{copy::MemCopyD2H, d_buffer::DeviceBuffer, error::MemCopyError};
4use openvm_stark_backend::prover::hal::MatrixDimensions;
5
/// A `height` x `width` matrix whose elements live in a reference-counted
/// [`DeviceBuffer`].
///
/// The buffer is held behind an [`Arc`], so several `DeviceMatrix` values can
/// refer to the same underlying buffer.
pub struct DeviceMatrix<T> {
    /// Shared handle to the buffer holding the matrix elements.
    buffer: Arc<DeviceBuffer<T>>,
    /// Value reported by `MatrixDimensions::height` (presumably the row count).
    height: usize,
    /// Value reported by `MatrixDimensions::width` (presumably the column count).
    width: usize,
}
11
// NOTE(review): these impls are unconditional in `T` and assert that a
// `DeviceMatrix<T>` may be moved to and shared between threads for *every* `T`.
// The struct only holds an `Arc<DeviceBuffer<T>>` plus two `usize`s, so this
// presumably relies on `DeviceBuffer<T>` being safe to access from any thread
// regardless of `T` — TODO confirm and record the invariant as a `SAFETY:`
// comment, or bound the impls on `T`.
unsafe impl<T> Send for DeviceMatrix<T> {}
unsafe impl<T> Sync for DeviceMatrix<T> {}
14
15impl<T> Clone for DeviceMatrix<T> {
16    fn clone(&self) -> Self {
17        Self {
18            buffer: Arc::clone(&self.buffer),
19            height: self.height,
20            width: self.width,
21        }
22    }
23}
24
25impl<T> Drop for DeviceMatrix<T> {
26    fn drop(&mut self) {
27        tracing::debug!(
28            "Dropping DeviceMatrix of size {} with Arc strong count={}",
29            self.buffer.len(),
30            self.strong_count()
31        );
32    }
33}
34
35impl<T> DeviceMatrix<T> {
36    pub fn new(buffer: Arc<DeviceBuffer<T>>, height: usize, width: usize) -> Self {
37        assert_ne!(
38            height * width,
39            0,
40            "Zero dimensions h {} w {} are wrong",
41            height,
42            width
43        );
44        assert_eq!(
45            buffer.len(),
46            height * width,
47            "Buffer size must match dimensions"
48        );
49        Self {
50            buffer,
51            height,
52            width,
53        }
54    }
55
56    pub fn with_capacity(height: usize, width: usize) -> Self {
57        Self {
58            buffer: Arc::new(DeviceBuffer::with_capacity(height * width)),
59            height,
60            width,
61        }
62    }
63
64    pub fn dummy() -> Self {
65        Self {
66            buffer: Arc::new(DeviceBuffer::new()),
67            height: 0,
68            width: 0,
69        }
70    }
71
72    pub fn buffer(&self) -> &DeviceBuffer<T> {
73        &self.buffer
74    }
75
76    pub fn strong_count(&self) -> usize {
77        Arc::strong_count(&self.buffer)
78    }
79}
80
81impl<T> MatrixDimensions for DeviceMatrix<T> {
82    #[inline]
83    fn height(&self) -> usize {
84        self.height
85    }
86
87    #[inline]
88    fn width(&self) -> usize {
89        self.width
90    }
91}
92
93impl<T> MemCopyD2H<T> for DeviceMatrix<T> {
94    fn to_host(&self) -> Result<Vec<T>, MemCopyError> {
95        self.buffer.to_host()
96    }
97}
98
99impl<T: Debug> Debug for DeviceMatrix<T> {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        write!(
102            f,
103            "DeviceMatrix (height = {}, width = {}): {:?}",
104            self.height(),
105            self.width(),
106            self.buffer()
107        )
108    }
109}
110
/// Marker trait for the basis over which a polynomial is described.
///
/// This trait and the marker types implementing it are borrowed from
/// [halo2](https://github.com/zcash/halo2).
pub trait Basis
where
    Self: Copy + Debug + Send + Sync,
{
}

/// Basis marker: the polynomial is defined as coefficients.
#[derive(Debug, Clone, Copy)]
pub struct Coeff;

impl Basis for Coeff {}

/// Basis marker: the polynomial is defined as coefficients of Lagrange basis
/// polynomials.
#[derive(Debug, Clone, Copy)]
pub struct LagrangeCoeff;

impl Basis for LagrangeCoeff {}

/// Basis marker: the polynomial is defined as coefficients of Lagrange basis
/// polynomials in an extended size domain which supports multiplication.
#[derive(Debug, Clone, Copy)]
pub struct ExtendedLagrangeCoeff;

impl Basis for ExtendedLagrangeCoeff {}
130
/// A polynomial stored in a [`DeviceBuffer`], tagged at the type level with a
/// marker `B`.
///
/// `B` is not bounded here; presumably it is one of the `Basis` marker types —
/// TODO confirm against callers.
pub struct DevicePoly<T, B> {
    /// Flag recording whether `coeff` is held in bit-reversed order.
    pub is_bit_reversed: bool,
    /// Buffer holding the polynomial's values.
    pub coeff: DeviceBuffer<T>,
    /// Zero-sized marker that carries `B` without storing a value of it.
    _marker: PhantomData<B>,
}
136
137impl<T, B> DevicePoly<T, B> {
138    pub fn new(is_bit_reversed: bool, coeff: DeviceBuffer<T>) -> Self {
139        Self {
140            is_bit_reversed,
141            coeff,
142            _marker: PhantomData,
143        }
144    }
145
146    #[inline]
147    pub fn len(&self) -> usize {
148        self.coeff.len()
149    }
150
151    #[inline]
152    pub fn is_empty(&self) -> bool {
153        self.len() == 0
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// A matrix wrapped around a correctly sized buffer reports the dimensions
    /// it was constructed with.
    #[test]
    fn test_device_matrix() {
        const ROWS: usize = 3;
        const COLS: usize = 4;

        let backing = DeviceBuffer::<i32>::with_capacity(ROWS * COLS);
        let matrix: DeviceMatrix<i32> = DeviceMatrix::new(Arc::new(backing), ROWS, COLS);

        assert_eq!(matrix.height(), ROWS);
        assert_eq!(matrix.width(), COLS);
    }
}