openvm_cuda_common/
copy.rs1use std::{cell::RefCell, ffi::c_void};
2
3use crate::{
4 d_buffer::DeviceBuffer,
5 error::{check, MemCopyError},
6 stream::{cudaStreamPerThread, cudaStream_t, CudaEvent},
7};
8
9thread_local! {
10 static COPY_EVENT: RefCell<Option<CudaEvent>> = RefCell::new(Some(CudaEvent::new().unwrap()));
11}
12
/// Direction selector handed to `cudaMemcpyAsync`.
///
/// The type is `#[repr(i32)]` and crosses the FFI boundary by value, so the
/// explicit discriminants below are part of the ABI and must not change.
#[allow(non_camel_case_types)] // Variant and type names mirror the CUDA runtime's C identifiers.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum cudaMemcpyKind {
    /// Host memory to host memory.
    cudaMemcpyHostToHost = 0,
    /// Host memory to device memory.
    cudaMemcpyHostToDevice = 1,
    /// Device memory to host memory.
    cudaMemcpyDeviceToHost = 2,
    /// Device memory to device memory.
    cudaMemcpyDeviceToDevice = 3,
    /// Let the runtime pick the direction (not used by the helpers in this file).
    cudaMemcpyDefault = 4,
}
24
// Raw binding into the CUDA runtime library (`cudart`).
#[link(name = "cudart")]
extern "C" {
    /// Enqueues a copy of `count` bytes from `src` to `dst` on `stream`, with
    /// the direction given by `kind`.
    ///
    /// Returns the runtime's raw integer status code; every caller in this
    /// file passes it straight to `check` to turn it into a `Result`.
    fn cudaMemcpyAsync(
        dst: *mut c_void,
        src: *const c_void,
        count: usize,
        kind: cudaMemcpyKind,
        stream: cudaStream_t,
    ) -> i32;
}
35
36pub unsafe fn cuda_memcpy<const SRC_DEVICE: bool, const DST_DEVICE: bool>(
41 dst: *mut c_void,
42 src: *const c_void,
43 size_bytes: usize,
44) -> Result<(), MemCopyError> {
45 check(unsafe {
46 cudaMemcpyAsync(
47 dst,
48 src,
49 size_bytes,
50 std::mem::transmute::<i32, cudaMemcpyKind>(
51 if DST_DEVICE { 1 } else { 0 } + if SRC_DEVICE { 2 } else { 0 },
52 ),
53 cudaStreamPerThread,
54 )
55 })
56 .map_err(MemCopyError::from)
57}
58
/// Host-to-device copies for host-side data holding elements of type `T`.
pub trait MemCopyH2D<T> {
    /// Copies the contents of `self` into the existing device buffer `dst`.
    fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
    /// Returns a new device buffer holding the contents of `self`.
    fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
}
64
65impl<T> MemCopyH2D<T> for [T] {
66 fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
67 if self.len() > dst.len() {
68 return Err(MemCopyError::SizeMismatch {
69 operation: "copy_to_device",
70 host_len: self.len(),
71 device_len: dst.len(),
72 });
73 }
74 let size_bytes = std::mem::size_of_val(self);
75 check(unsafe {
76 cudaMemcpyAsync(
77 dst.as_mut_raw_ptr(),
78 self.as_ptr() as *const c_void,
79 size_bytes,
80 cudaMemcpyKind::cudaMemcpyHostToDevice,
81 cudaStreamPerThread,
82 )
83 })
84 .map_err(MemCopyError::from)
85 }
86
87 fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
88 let mut dst = DeviceBuffer::with_capacity(self.len());
89 self.copy_to(&mut dst)?;
90 Ok(dst)
91 }
92}
93
/// Device-to-host copies for device-side data holding elements of type `T`.
pub trait MemCopyD2H<T> {
    /// Returns a new host `Vec` holding the contents of `self`.
    fn to_host(&self) -> Result<Vec<T>, MemCopyError>;
}
98
99impl<T> MemCopyD2H<T> for DeviceBuffer<T> {
100 fn to_host(&self) -> Result<Vec<T>, MemCopyError> {
101 let mut host_vec = Vec::with_capacity(self.len());
102 let size_bytes = std::mem::size_of::<T>() * self.len();
103
104 check(unsafe {
105 cudaMemcpyAsync(
106 host_vec.as_mut_ptr() as *mut c_void,
107 self.as_raw_ptr(),
108 size_bytes,
109 cudaMemcpyKind::cudaMemcpyDeviceToHost,
110 cudaStreamPerThread,
111 )
112 })?;
113 unsafe {
114 COPY_EVENT
115 .with_borrow(|ce| ce.as_ref().unwrap().record_and_wait(cudaStreamPerThread))?;
116
117 host_vec.set_len(self.len());
118 }
119
120 Ok(host_vec)
121 }
122}
123
/// Device-to-device copies for device-side data holding elements of type `T`.
pub trait MemCopyD2D<T> {
    /// Returns a new device buffer holding a copy of `self`.
    fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
    /// Copies the contents of `self` into the existing device buffer `dst`.
    fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
}
128
129impl<T> MemCopyD2D<T> for DeviceBuffer<T> {
130 fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
131 let mut dst = DeviceBuffer::<T>::with_capacity(self.len());
132 self.device_copy_to(&mut dst)?;
133 Ok(dst)
134 }
135
136 fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
137 let size_bytes = std::mem::size_of::<T>() * self.len();
138
139 check(unsafe {
140 cudaMemcpyAsync(
141 dst.as_mut_raw_ptr(),
142 self.as_raw_ptr(),
143 size_bytes,
144 cudaMemcpyKind::cudaMemcpyDeviceToDevice,
145 cudaStreamPerThread,
146 )
147 })
148 .map_err(MemCopyError::from)
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::d_buffer::DeviceBuffer;

    /// Round-trips a small host vector through two device buffers: one
    /// produced by `to_device`, one pre-allocated and filled with `copy_to`.
    #[test]
    fn test_mem_copy() {
        let host: Vec<i32> = vec![1, 2, 3, 4, 5];

        // Path 1: allocate and upload in one step.
        let uploaded = host.to_device().unwrap();

        // Path 2: upload into a buffer allocated up front.
        let mut preallocated = DeviceBuffer::<i32>::with_capacity(host.len());
        host.copy_to(&mut preallocated).unwrap();

        // Download both buffers and compare against the original data.
        let from_uploaded = uploaded.to_host().unwrap();
        let from_preallocated = preallocated.to_host().unwrap();

        assert_eq!(host, from_uploaded, "First device buffer mismatch");
        assert_eq!(host, from_preallocated, "Second device buffer mismatch");
    }
}