// openvm_cuda_common/d_buffer.rs

1use std::{ffi::c_void, fmt::Debug, ptr};
2
3use crate::{
4    copy::MemCopyD2H,
5    error::{check, CudaError},
6    memory_manager::{d_free, d_malloc},
7    stream::{cudaStreamPerThread, cudaStream_t},
8};
9
10#[link(name = "cudart")]
11extern "C" {
12    fn cudaMemsetAsync(dst: *mut c_void, value: i32, count: usize, stream: cudaStream_t) -> i32;
13}
14
/// Owning handle to a device allocation holding `len` elements of type `T`.
///
/// Memory is obtained from `d_malloc` in [`DeviceBuffer::with_capacity`] and handed back
/// to `d_free` when the buffer is dropped. A null `ptr` (as produced by
/// [`DeviceBuffer::new`]) marks a buffer that owns nothing.
pub struct DeviceBuffer<T> {
    /// Typed device pointer; null when the buffer owns no allocation.
    ptr: *mut T,
    /// Number of `T` elements (not bytes) covered by `ptr`.
    len: usize,
}
19
/// Non-owning `(pointer, byte size)` pair describing a region of device memory.
///
/// It is `#[repr(C)]` so it can be passed straight to CUDA. The pointer is stored as an
/// untyped `*const c_void`, the simplest type for the CUDA side to read.
///
/// Immutability cannot really be enforced across that boundary, so the pointer is always
/// `*const`. Robustness comes instead from having two separate constructors on the usage
/// side.
///
/// Conceptually this is a [DeviceBuffer] stripped of ownership: it never frees memory.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct DeviceBufferView {
    /// Start address of the device memory region.
    pub ptr: *const c_void,
    /// Length of the region in bytes.
    pub size: usize,
}
31
32unsafe impl<T> Send for DeviceBuffer<T> {}
33unsafe impl<T> Sync for DeviceBuffer<T> {}
34
35impl<T> DeviceBuffer<T> {
36    /// Creates an "empty" DeviceBuffer with a null pointer and zero length.
37    #[allow(clippy::new_without_default)]
38    pub fn new() -> Self {
39        DeviceBuffer {
40            ptr: ptr::null_mut(),
41            len: 0,
42        }
43    }
44
45    /// Allocate device memory for `len` elements of type `T`.
46    pub fn with_capacity(len: usize) -> Self {
47        tracing::debug!(
48            "Creating device buffer of size {} (sizeof type = {})",
49            len,
50            size_of::<T>()
51        );
52        assert_ne!(len, 0, "Zero capacity request is wrong");
53        let size_bytes = std::mem::size_of::<T>() * len;
54        let raw_ptr = d_malloc(size_bytes).expect("GPU allocation failed");
55        #[cfg(feature = "touchemall")]
56        {
57            // 0xffffffff is `Fp::invalid()` and shouldn't occur in a trace
58            unsafe {
59                cudaMemsetAsync(raw_ptr, 0xff, size_bytes, cudaStreamPerThread);
60            }
61        }
62        let typed_ptr = raw_ptr as *mut T;
63
64        DeviceBuffer {
65            ptr: typed_ptr,
66            len,
67        }
68    }
69
70    /// Fills the buffer with zeros.
71    pub fn fill_zero(&self) -> Result<(), CudaError> {
72        assert_ne!(self.len, 0, "Empty buffer");
73        let size_bytes = std::mem::size_of::<T>() * self.len;
74        check(unsafe { cudaMemsetAsync(self.as_mut_raw_ptr(), 0, size_bytes, cudaStreamPerThread) })
75    }
76
77    /// Fills a suffix of the buffer with zeros.
78    /// The `start_idx` is the index in the buffer, in `T` elements.
79    pub fn fill_zero_suffix(&self, start_idx: usize) -> Result<(), CudaError> {
80        assert!(
81            start_idx < self.len,
82            "start index has to be smaller than length"
83        );
84        let size_bytes = std::mem::size_of::<T>() * (self.len - start_idx);
85        check(unsafe {
86            cudaMemsetAsync(
87                self.as_mut_ptr().add(start_idx) as *mut c_void,
88                0,
89                size_bytes,
90                cudaStreamPerThread,
91            )
92        })
93    }
94
95    /// Returns the number of elements in this buffer.
96    pub fn len(&self) -> usize {
97        self.len
98    }
99
100    /// Returns whether the buffer is empty (null pointer or zero length).
101    pub fn is_empty(&self) -> bool {
102        self.len == 0 || self.ptr.is_null()
103    }
104
105    /// Returns a raw mutable pointer to the device data (typed).
106    pub fn as_mut_ptr(&self) -> *mut T {
107        self.ptr
108    }
109
110    /// Returns a raw const pointer to the device data (typed).
111    pub fn as_ptr(&self) -> *const T {
112        self.ptr as *const T
113    }
114
115    /// Returns a `*mut c_void` (untyped) pointer.
116    pub fn as_mut_raw_ptr(&self) -> *mut c_void {
117        self.ptr as *mut c_void
118    }
119
120    /// Returns a `*const c_void` (untyped) pointer.
121    pub fn as_raw_ptr(&self) -> *const c_void {
122        self.ptr as *const c_void
123    }
124
125    /// Converts the buffer to a buffer of different type.
126    /// `T` must be composable of `U`s.
127    pub fn as_buffer<U>(mut self) -> DeviceBuffer<U> {
128        assert_eq!(
129            size_of::<T>() % size_of::<U>(),
130            0,
131            "the underlying type size must divide the former one"
132        );
133        assert_eq!(
134            align_of::<T>() % align_of::<U>(),
135            0,
136            "the underlying type alignment must divide the former one"
137        );
138        let res = DeviceBuffer {
139            ptr: self.ptr as *mut U,
140            len: self.len * (size_of::<T>() / size_of::<U>()),
141        };
142        self.ptr = ptr::null_mut(); // for safe drop
143        self.len = 0;
144        res
145    }
146
147    pub fn view(&self) -> DeviceBufferView {
148        DeviceBufferView {
149            ptr: self.ptr as *const c_void,
150            size: self.len * size_of::<T>(),
151        }
152    }
153}
154
155impl<T> Drop for DeviceBuffer<T> {
156    fn drop(&mut self) {
157        if !self.ptr.is_null() {
158            tracing::debug!(
159                "Freeing device buffer of size {} (sizeof type = {})",
160                self.len,
161                size_of::<T>()
162            );
163            unsafe {
164                d_free(self.ptr as *mut c_void).expect("GPU free failed");
165            }
166            self.ptr = ptr::null_mut();
167            self.len = 0;
168        }
169    }
170}
171
172impl<T: Debug> Debug for DeviceBuffer<T> {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        let host_vec = self.to_host().unwrap();
175        write!(f, "DeviceBuffer (len = {}): {:?}", self.len(), host_vec)
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::copy::MemCopyH2D;

    #[test]
    fn test_device_buffer_float() {
        const N: usize = 10;
        // Room for N f32 values; released automatically when `buffer` goes out of scope.
        let buffer = DeviceBuffer::<f32>::with_capacity(N);

        assert!(!buffer.is_empty());
        assert!(!buffer.as_ptr().is_null());
        assert_eq!(buffer.len(), N);
    }

    #[test]
    fn test_device_buffer_fill_zero() {
        let host: Vec<u64> = (0u64..10).collect();
        let device = host.to_device().unwrap();

        device.fill_zero().unwrap();

        let expected = vec![0u64; host.len()];
        assert_eq!(device.to_host().unwrap(), expected);
    }
}