openvm_cuda_common/
copy.rs

1use std::{cell::RefCell, ffi::c_void};
2
3use crate::{
4    d_buffer::DeviceBuffer,
5    error::{check, MemCopyError},
6    stream::{cudaStreamPerThread, cudaStream_t, CudaEvent},
7};
8
9thread_local! {
10    static COPY_EVENT: RefCell<Option<CudaEvent>> = RefCell::new(Some(CudaEvent::new().unwrap()));
11}
12
/// Copy direction passed to `cudaMemcpyAsync`, mirroring the CUDA runtime's `cudaMemcpyKind`.
///
/// The explicit discriminants are the raw `i32` values handed across the FFI boundary.
// The type and variant names deliberately keep the CUDA C API spelling.
#[allow(non_camel_case_types)]
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum cudaMemcpyKind {
    /// Host memory to host memory.
    cudaMemcpyHostToHost = 0,
    /// Host memory to device memory.
    cudaMemcpyHostToDevice = 1,
    /// Device memory to host memory.
    cudaMemcpyDeviceToHost = 2,
    /// Device memory to device memory.
    cudaMemcpyDeviceToDevice = 3,
    /// Direction left to the CUDA runtime to determine (see the CUDA runtime API docs).
    cudaMemcpyDefault = 4,
}
24
25#[link(name = "cudart")]
26extern "C" {
27    fn cudaMemcpyAsync(
28        dst: *mut c_void,
29        src: *const c_void,
30        count: usize,
31        kind: cudaMemcpyKind,
32        stream: cudaStream_t,
33    ) -> i32;
34}
35
36/// FFI binding for the `cudaMemcpyAsync` function on the default cuda stream.
37///
38/// # Safety
39/// Must follow the rules of the `cudaMemcpyAsync` function from the CUDA runtime API.
40pub unsafe fn cuda_memcpy<const SRC_DEVICE: bool, const DST_DEVICE: bool>(
41    dst: *mut c_void,
42    src: *const c_void,
43    size_bytes: usize,
44) -> Result<(), MemCopyError> {
45    check(unsafe {
46        cudaMemcpyAsync(
47            dst,
48            src,
49            size_bytes,
50            std::mem::transmute::<i32, cudaMemcpyKind>(
51                if DST_DEVICE { 1 } else { 0 } + if SRC_DEVICE { 2 } else { 0 },
52            ),
53            cudaStreamPerThread,
54        )
55    })
56    .map_err(MemCopyError::from)
57}
58
59// Host -> Device
60pub trait MemCopyH2D<T> {
61    fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
62    fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
63}
64
65impl<T> MemCopyH2D<T> for [T] {
66    fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
67        if self.len() > dst.len() {
68            return Err(MemCopyError::SizeMismatch {
69                operation: "copy_to_device",
70                host_len: self.len(),
71                device_len: dst.len(),
72            });
73        }
74        let size_bytes = std::mem::size_of_val(self);
75        check(unsafe {
76            cudaMemcpyAsync(
77                dst.as_mut_raw_ptr(),
78                self.as_ptr() as *const c_void,
79                size_bytes,
80                cudaMemcpyKind::cudaMemcpyHostToDevice,
81                cudaStreamPerThread,
82            )
83        })
84        .map_err(MemCopyError::from)
85    }
86
87    fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
88        let mut dst = DeviceBuffer::with_capacity(self.len());
89        self.copy_to(&mut dst)?;
90        Ok(dst)
91    }
92}
93
94// Device -> Host
95pub trait MemCopyD2H<T> {
96    fn to_host(&self) -> Result<Vec<T>, MemCopyError>;
97}
98
99impl<T> MemCopyD2H<T> for DeviceBuffer<T> {
100    fn to_host(&self) -> Result<Vec<T>, MemCopyError> {
101        let mut host_vec = Vec::with_capacity(self.len());
102        let size_bytes = std::mem::size_of::<T>() * self.len();
103
104        check(unsafe {
105            cudaMemcpyAsync(
106                host_vec.as_mut_ptr() as *mut c_void,
107                self.as_raw_ptr(),
108                size_bytes,
109                cudaMemcpyKind::cudaMemcpyDeviceToHost,
110                cudaStreamPerThread,
111            )
112        })?;
113        unsafe {
114            COPY_EVENT
115                .with_borrow(|ce| ce.as_ref().unwrap().record_and_wait(cudaStreamPerThread))?;
116
117            host_vec.set_len(self.len());
118        }
119
120        Ok(host_vec)
121    }
122}
123
124pub trait MemCopyD2D<T> {
125    fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
126    fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
127}
128
129impl<T> MemCopyD2D<T> for DeviceBuffer<T> {
130    fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
131        let mut dst = DeviceBuffer::<T>::with_capacity(self.len());
132        self.device_copy_to(&mut dst)?;
133        Ok(dst)
134    }
135
136    fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
137        let size_bytes = std::mem::size_of::<T>() * self.len();
138
139        check(unsafe {
140            cudaMemcpyAsync(
141                dst.as_mut_raw_ptr(),
142                self.as_raw_ptr(),
143                size_bytes,
144                cudaMemcpyKind::cudaMemcpyDeviceToDevice,
145                cudaStreamPerThread,
146            )
147        })
148        .map_err(MemCopyError::from)
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::d_buffer::DeviceBuffer;

    /// Round-trips a small host vector through two device buffers (one created by
    /// `to_device`, one pre-allocated and filled with `copy_to`) and checks that both
    /// come back unchanged.
    #[test]
    fn test_mem_copy() {
        let host_data = vec![1, 2, 3, 4, 5];

        // Upload into a buffer allocated by the copy helper itself.
        let fresh = host_data.to_device().unwrap();

        // Upload into a buffer that was allocated up front with the same size.
        let mut preallocated = DeviceBuffer::<i32>::with_capacity(host_data.len());
        host_data.copy_to(&mut preallocated).unwrap();

        // Download both buffers again.
        let from_fresh = fresh.to_host().unwrap();
        let from_preallocated = preallocated.to_host().unwrap();

        assert_eq!(host_data, from_fresh, "First device buffer mismatch");
        assert_eq!(host_data, from_preallocated, "Second device buffer mismatch");
    }
}