openvm_cuda_common/
copy.rs1use std::{ffi::c_void, sync::Mutex};
2
3use lazy_static::lazy_static;
4
5use crate::{
6 d_buffer::DeviceBuffer,
7 error::{check, MemCopyError},
8 stream::{cudaStreamPerThread, cudaStream_t, CudaEvent},
9};
10
11lazy_static! {
12 static ref COPY_EVENT: Mutex<CudaEvent> = Mutex::new(CudaEvent::new().unwrap());
13}
14
/// Copy direction passed to `cudaMemcpyAsync`.
///
/// The type is `#[repr(i32)]` and each discriminant is spelled out explicitly
/// because the value crosses the FFI boundary as a plain integer; the names
/// follow the CUDA runtime's `cudaMemcpyKind` enumeration.
#[allow(non_camel_case_types)] // names intentionally match the CUDA C API
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum cudaMemcpyKind {
    /// Host memory to host memory.
    cudaMemcpyHostToHost = 0,
    /// Host memory to device memory.
    cudaMemcpyHostToDevice = 1,
    /// Device memory to host memory.
    cudaMemcpyDeviceToHost = 2,
    /// Device memory to device memory.
    cudaMemcpyDeviceToDevice = 3,
    /// Let the CUDA runtime pick the direction (see the CUDA documentation for
    /// the exact requirements of this mode).
    cudaMemcpyDefault = 4,
}
26
// Links against the CUDA runtime library (`cudart`).
#[link(name = "cudart")]
extern "C" {
    /// FFI declaration of the CUDA runtime's `cudaMemcpyAsync`.
    ///
    /// * `dst` / `src` - destination and source pointers of the copy.
    /// * `count` - size of the copy; every caller in this file passes a byte count.
    /// * `kind` - copy direction.
    /// * `stream` - stream the copy is issued on; callers in this file pass
    ///   `cudaStreamPerThread`.
    ///
    /// Returns an `i32` status code, which callers in this file hand to `check`.
    fn cudaMemcpyAsync(
        dst: *mut c_void,
        src: *const c_void,
        count: usize,
        kind: cudaMemcpyKind,
        stream: cudaStream_t,
    ) -> i32;
}
37
38pub unsafe fn cuda_memcpy<const SRC_DEVICE: bool, const DST_DEVICE: bool>(
43 dst: *mut c_void,
44 src: *const c_void,
45 size_bytes: usize,
46) -> Result<(), MemCopyError> {
47 check(unsafe {
48 cudaMemcpyAsync(
49 dst,
50 src,
51 size_bytes,
52 std::mem::transmute::<i32, cudaMemcpyKind>(
53 if DST_DEVICE { 1 } else { 0 } + if SRC_DEVICE { 2 } else { 0 },
54 ),
55 cudaStreamPerThread,
56 )
57 })
58 .map_err(MemCopyError::from)
59}
60
/// Host-to-device copy operations.
///
/// Implemented in this file for host slices (`[T]`).
pub trait MemCopyH2D<T> {
    /// Copies the host data into the existing device buffer `dst`.
    fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
    /// Returns a newly created device buffer holding a copy of the host data.
    fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
}
66
67impl<T> MemCopyH2D<T> for [T] {
68 fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
69 if self.len() > dst.len() {
70 return Err(MemCopyError::SizeMismatch {
71 operation: "copy_to_device",
72 host_len: self.len(),
73 device_len: dst.len(),
74 });
75 }
76 let size_bytes = std::mem::size_of_val(self);
77 check(unsafe {
78 cudaMemcpyAsync(
79 dst.as_mut_raw_ptr(),
80 self.as_ptr() as *const c_void,
81 size_bytes,
82 cudaMemcpyKind::cudaMemcpyHostToDevice,
83 cudaStreamPerThread,
84 )
85 })
86 .map_err(MemCopyError::from)
87 }
88
89 fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
90 let mut dst = DeviceBuffer::with_capacity(self.len());
91 self.copy_to(&mut dst)?;
92 Ok(dst)
93 }
94}
95
/// Device-to-host copy operations.
///
/// Implemented in this file for `DeviceBuffer<T>`.
pub trait MemCopyD2H<T> {
    /// Returns a host `Vec<T>` holding a copy of the device data.
    fn to_host(&self) -> Result<Vec<T>, MemCopyError>;
}
100
101impl<T> MemCopyD2H<T> for DeviceBuffer<T> {
102 fn to_host(&self) -> Result<Vec<T>, MemCopyError> {
103 let mut host_vec = Vec::with_capacity(self.len());
104 let size_bytes = std::mem::size_of::<T>() * self.len();
105
106 check(unsafe {
107 cudaMemcpyAsync(
108 host_vec.as_mut_ptr() as *mut c_void,
109 self.as_raw_ptr(),
110 size_bytes,
111 cudaMemcpyKind::cudaMemcpyDeviceToHost,
112 cudaStreamPerThread,
113 )
114 })?;
115 unsafe {
116 COPY_EVENT
117 .lock()
118 .unwrap()
119 .record_and_wait(cudaStreamPerThread)?;
120
121 host_vec.set_len(self.len());
122 }
123
124 Ok(host_vec)
125 }
126}
127
/// Device-to-device copy operations.
///
/// Implemented in this file for `DeviceBuffer<T>`.
pub trait MemCopyD2D<T> {
    /// Returns a newly created device buffer holding a copy of this buffer.
    fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
    /// Copies this buffer into the existing device buffer `dst`.
    fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
}
132
133impl<T> MemCopyD2D<T> for DeviceBuffer<T> {
134 fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
135 let mut dst = DeviceBuffer::<T>::with_capacity(self.len());
136 self.device_copy_to(&mut dst)?;
137 Ok(dst)
138 }
139
140 fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
141 let size_bytes = std::mem::size_of::<T>() * self.len();
142
143 check(unsafe {
144 cudaMemcpyAsync(
145 dst.as_mut_raw_ptr(),
146 self.as_raw_ptr(),
147 size_bytes,
148 cudaMemcpyKind::cudaMemcpyDeviceToDevice,
149 cudaStreamPerThread,
150 )
151 })
152 .map_err(MemCopyError::from)
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::d_buffer::DeviceBuffer;

    /// Round-trips host data through both host-to-device paths and checks that
    /// reading each device buffer back yields the original values.
    #[test]
    fn test_mem_copy() {
        let h = vec![1, 2, 3, 4, 5];

        // Path 1: allocate and fill in one step.
        let allocated = h.to_device().unwrap();

        // Path 2: fill a buffer that was allocated up front.
        let mut preallocated = DeviceBuffer::<i32>::with_capacity(h.len());
        h.copy_to(&mut preallocated).unwrap();

        // Read both buffers back to the host.
        let from_allocated = allocated.to_host().unwrap();
        let from_preallocated = preallocated.to_host().unwrap();

        assert_eq!(h, from_allocated, "First device buffer mismatch");
        assert_eq!(h, from_preallocated, "Second device buffer mismatch");
    }
}