openvm_cuda_common/
d_buffer.rs1use std::{ffi::c_void, fmt::Debug, ptr};
2
3use crate::{
4 copy::MemCopyD2H,
5 error::{check, CudaError},
6 memory_manager::{d_free, d_malloc},
7 stream::{cudaStreamPerThread, cudaStream_t},
8};
9
10#[link(name = "cudart")]
11extern "C" {
12 fn cudaMemsetAsync(dst: *mut c_void, value: i32, count: usize, stream: cudaStream_t) -> i32;
13}
14
/// Owning handle to a GPU allocation holding `len` elements of `T`.
///
/// A buffer whose `ptr` is null owns no memory; `Drop` skips the free in that case.
pub struct DeviceBuffer<T> {
    /// Number of `T` elements in the allocation (not bytes).
    len: usize,
    /// Start of the GPU allocation, or null when nothing is owned.
    ptr: *mut T,
}
19
/// Non-owning, C-layout description of a device memory region: a base pointer plus
/// a length in bytes. Produced by `DeviceBuffer::view`.
///
/// The `#[repr(C)]` layout pins the field order (`ptr`, then `size`); presumably this
/// is so the view can be handed across FFI — NOTE(review): confirm with consumers.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct DeviceBufferView {
    /// Base address of the region; may be null for a buffer that owns no memory.
    pub ptr: *const c_void,
    /// Length of the region in bytes (not elements).
    pub size: usize,
}
31
32unsafe impl<T> Send for DeviceBuffer<T> {}
33unsafe impl<T> Sync for DeviceBuffer<T> {}
34
35impl<T> DeviceBuffer<T> {
36 #[allow(clippy::new_without_default)]
38 pub fn new() -> Self {
39 DeviceBuffer {
40 ptr: ptr::null_mut(),
41 len: 0,
42 }
43 }
44
45 pub fn with_capacity(len: usize) -> Self {
47 tracing::debug!(
48 "Creating device buffer of size {} (sizeof type = {})",
49 len,
50 size_of::<T>()
51 );
52 assert_ne!(len, 0, "Zero capacity request is wrong");
53 let size_bytes = std::mem::size_of::<T>() * len;
54 let raw_ptr = d_malloc(size_bytes).expect("GPU allocation failed");
55 #[cfg(feature = "touchemall")]
56 {
57 unsafe {
59 cudaMemsetAsync(raw_ptr, 0xff, size_bytes, cudaStreamPerThread);
60 }
61 }
62 let typed_ptr = raw_ptr as *mut T;
63
64 DeviceBuffer {
65 ptr: typed_ptr,
66 len,
67 }
68 }
69
70 pub fn fill_zero(&self) -> Result<(), CudaError> {
72 assert_ne!(self.len, 0, "Empty buffer");
73 let size_bytes = std::mem::size_of::<T>() * self.len;
74 check(unsafe { cudaMemsetAsync(self.as_mut_raw_ptr(), 0, size_bytes, cudaStreamPerThread) })
75 }
76
77 pub fn fill_zero_suffix(&self, start_idx: usize) -> Result<(), CudaError> {
80 assert!(
81 start_idx < self.len,
82 "start index has to be smaller than length"
83 );
84 let size_bytes = std::mem::size_of::<T>() * (self.len - start_idx);
85 check(unsafe {
86 cudaMemsetAsync(
87 self.as_mut_ptr().add(start_idx) as *mut c_void,
88 0,
89 size_bytes,
90 cudaStreamPerThread,
91 )
92 })
93 }
94
95 pub fn len(&self) -> usize {
97 self.len
98 }
99
100 pub fn is_empty(&self) -> bool {
102 self.len == 0 || self.ptr.is_null()
103 }
104
105 pub fn as_mut_ptr(&self) -> *mut T {
107 self.ptr
108 }
109
110 pub fn as_ptr(&self) -> *const T {
112 self.ptr as *const T
113 }
114
115 pub fn as_mut_raw_ptr(&self) -> *mut c_void {
117 self.ptr as *mut c_void
118 }
119
120 pub fn as_raw_ptr(&self) -> *const c_void {
122 self.ptr as *const c_void
123 }
124
125 pub fn as_buffer<U>(mut self) -> DeviceBuffer<U> {
128 assert_eq!(
129 size_of::<T>() % size_of::<U>(),
130 0,
131 "the underlying type size must divide the former one"
132 );
133 assert_eq!(
134 align_of::<T>() % align_of::<U>(),
135 0,
136 "the underlying type alignment must divide the former one"
137 );
138 let res = DeviceBuffer {
139 ptr: self.ptr as *mut U,
140 len: self.len * (size_of::<T>() / size_of::<U>()),
141 };
142 self.ptr = ptr::null_mut(); self.len = 0;
144 res
145 }
146
147 pub fn view(&self) -> DeviceBufferView {
148 DeviceBufferView {
149 ptr: self.ptr as *const c_void,
150 size: self.len * size_of::<T>(),
151 }
152 }
153}
154
155impl<T> Drop for DeviceBuffer<T> {
156 fn drop(&mut self) {
157 if !self.ptr.is_null() {
158 tracing::debug!(
159 "Freeing device buffer of size {} (sizeof type = {})",
160 self.len,
161 size_of::<T>()
162 );
163 unsafe {
164 d_free(self.ptr as *mut c_void).expect("GPU free failed");
165 }
166 self.ptr = ptr::null_mut();
167 self.len = 0;
168 }
169 }
170}
171
172impl<T: Debug> Debug for DeviceBuffer<T> {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 let host_vec = self.to_host().unwrap();
175 write!(f, "DeviceBuffer (len = {}): {:?}", self.len(), host_vec)
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::copy::MemCopyH2D;

    /// A freshly allocated buffer reports its length and a non-null pointer.
    #[test]
    fn test_device_buffer_float() {
        let db = DeviceBuffer::<f32>::with_capacity(10);

        assert_eq!(db.len(), 10);
        assert!(!db.as_ptr().is_null());
        assert!(!db.is_empty());
    }

    /// `new()` owns nothing: it is empty, has a null pointer and a zero-byte view.
    #[test]
    fn test_device_buffer_new_is_empty() {
        let db = DeviceBuffer::<u32>::new();

        assert!(db.is_empty());
        assert_eq!(db.len(), 0);
        assert!(db.as_ptr().is_null());
        assert_eq!(db.view().size, 0);
    }

    /// `fill_zero` clears every element.
    #[test]
    fn test_device_buffer_fill_zero() {
        let v: Vec<u64> = (0..10).collect();
        let d_array = v.to_device().unwrap();
        d_array.fill_zero().unwrap();
        assert_eq!(d_array.to_host().unwrap(), vec![0; v.len()]);
    }

    /// `fill_zero_suffix` clears only the elements from `start_idx` onward.
    #[test]
    fn test_device_buffer_fill_zero_suffix() {
        let v: Vec<u64> = (1..=10).collect();
        let d_array = v.to_device().unwrap();
        d_array.fill_zero_suffix(4).unwrap();

        let mut expected = v.clone();
        expected[4..].fill(0);
        assert_eq!(d_array.to_host().unwrap(), expected);
    }

    /// `fill_zero_suffix` rejects a start index at or past the end.
    #[test]
    #[should_panic(expected = "start index has to be smaller than length")]
    fn test_device_buffer_fill_zero_suffix_out_of_range() {
        let v: Vec<u64> = (0..4).collect();
        let d_array = v.to_device().unwrap();
        let _ = d_array.fill_zero_suffix(4);
    }

    /// Reinterpreting as a smaller element type scales the length and keeps the byte size.
    #[test]
    fn test_device_buffer_as_buffer() {
        let v: Vec<u64> = vec![1, 2, 3];
        let d_array = v.to_device().unwrap();
        let bytes_before = d_array.view().size;

        let d_words = d_array.as_buffer::<u32>();
        assert_eq!(d_words.len(), 6);
        assert_eq!(d_words.view().size, bytes_before);
    }
}