openvm_cuda_common/
copy.rs

1use std::{ffi::c_void, sync::Mutex};
2
3use lazy_static::lazy_static;
4
5use crate::{
6    d_buffer::DeviceBuffer,
7    error::{check, MemCopyError},
8    stream::{cudaStreamPerThread, cudaStream_t, CudaEvent},
9};
10
11lazy_static! {
12    static ref COPY_EVENT: Mutex<CudaEvent> = Mutex::new(CudaEvent::new().unwrap());
13}
14
/// Direction of a CUDA memory copy, mirroring the CUDA runtime's `cudaMemcpyKind` enumeration.
///
/// Values of this type are handed to `cudaMemcpyAsync` across FFI unchanged. The discriminants
/// therefore have to stay equal to the C enumeration's values (0 through 4).
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
#[repr(i32)]
pub enum cudaMemcpyKind {
    /// Host memory to host memory.
    cudaMemcpyHostToHost = 0,
    /// Host memory to device memory.
    cudaMemcpyHostToDevice = 1,
    /// Device memory to host memory.
    cudaMemcpyDeviceToHost = 2,
    /// Device memory to device memory.
    cudaMemcpyDeviceToDevice = 3,
    /// Direction inferred by the CUDA runtime from the pointer values
    /// (per the CUDA docs this requires unified virtual addressing).
    cudaMemcpyDefault = 4,
}
26
27#[link(name = "cudart")]
28extern "C" {
29    fn cudaMemcpyAsync(
30        dst: *mut c_void,
31        src: *const c_void,
32        count: usize,
33        kind: cudaMemcpyKind,
34        stream: cudaStream_t,
35    ) -> i32;
36}
37
38/// FFI binding for the `cudaMemcpyAsync` function on the default cuda stream.
39///
40/// # Safety
41/// Must follow the rules of the `cudaMemcpyAsync` function from the CUDA runtime API.
42pub unsafe fn cuda_memcpy<const SRC_DEVICE: bool, const DST_DEVICE: bool>(
43    dst: *mut c_void,
44    src: *const c_void,
45    size_bytes: usize,
46) -> Result<(), MemCopyError> {
47    check(unsafe {
48        cudaMemcpyAsync(
49            dst,
50            src,
51            size_bytes,
52            std::mem::transmute::<i32, cudaMemcpyKind>(
53                if DST_DEVICE { 1 } else { 0 } + if SRC_DEVICE { 2 } else { 0 },
54            ),
55            cudaStreamPerThread,
56        )
57    })
58    .map_err(MemCopyError::from)
59}
60
61// Host -> Device
62pub trait MemCopyH2D<T> {
63    fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
64    fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
65}
66
67impl<T> MemCopyH2D<T> for [T] {
68    fn copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
69        if self.len() > dst.len() {
70            return Err(MemCopyError::SizeMismatch {
71                operation: "copy_to_device",
72                host_len: self.len(),
73                device_len: dst.len(),
74            });
75        }
76        let size_bytes = std::mem::size_of_val(self);
77        check(unsafe {
78            cudaMemcpyAsync(
79                dst.as_mut_raw_ptr(),
80                self.as_ptr() as *const c_void,
81                size_bytes,
82                cudaMemcpyKind::cudaMemcpyHostToDevice,
83                cudaStreamPerThread,
84            )
85        })
86        .map_err(MemCopyError::from)
87    }
88
89    fn to_device(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
90        let mut dst = DeviceBuffer::with_capacity(self.len());
91        self.copy_to(&mut dst)?;
92        Ok(dst)
93    }
94}
95
96// Device -> Host
97pub trait MemCopyD2H<T> {
98    fn to_host(&self) -> Result<Vec<T>, MemCopyError>;
99}
100
101impl<T> MemCopyD2H<T> for DeviceBuffer<T> {
102    fn to_host(&self) -> Result<Vec<T>, MemCopyError> {
103        let mut host_vec = Vec::with_capacity(self.len());
104        let size_bytes = std::mem::size_of::<T>() * self.len();
105
106        check(unsafe {
107            cudaMemcpyAsync(
108                host_vec.as_mut_ptr() as *mut c_void,
109                self.as_raw_ptr(),
110                size_bytes,
111                cudaMemcpyKind::cudaMemcpyDeviceToHost,
112                cudaStreamPerThread,
113            )
114        })?;
115        unsafe {
116            COPY_EVENT
117                .lock()
118                .unwrap()
119                .record_and_wait(cudaStreamPerThread)?;
120
121            host_vec.set_len(self.len());
122        }
123
124        Ok(host_vec)
125    }
126}
127
128pub trait MemCopyD2D<T> {
129    fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError>;
130    fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError>;
131}
132
133impl<T> MemCopyD2D<T> for DeviceBuffer<T> {
134    fn device_copy(&self) -> Result<DeviceBuffer<T>, MemCopyError> {
135        let mut dst = DeviceBuffer::<T>::with_capacity(self.len());
136        self.device_copy_to(&mut dst)?;
137        Ok(dst)
138    }
139
140    fn device_copy_to(&self, dst: &mut DeviceBuffer<T>) -> Result<(), MemCopyError> {
141        let size_bytes = std::mem::size_of::<T>() * self.len();
142
143        check(unsafe {
144            cudaMemcpyAsync(
145                dst.as_mut_raw_ptr(),
146                self.as_raw_ptr(),
147                size_bytes,
148                cudaMemcpyKind::cudaMemcpyDeviceToDevice,
149                cudaStreamPerThread,
150            )
151        })
152        .map_err(MemCopyError::from)
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::d_buffer::DeviceBuffer;

    /// Round-trips host data through freshly allocated and pre-allocated device buffers.
    #[test]
    fn test_mem_copy() {
        // Our source data on the host
        let h = vec![1, 2, 3, 4, 5];

        // 1) Copy to a newly allocated device buffer
        let d1 = h.to_device().unwrap();

        // 2) Create another device buffer of the same size
        let mut d2 = DeviceBuffer::<i32>::with_capacity(h.len());

        // 3) Copy into that existing buffer
        h.copy_to(&mut d2).unwrap();

        // 4) Copy both buffers back to host
        let h1 = d1.to_host().unwrap();
        let h2 = d2.to_host().unwrap();

        assert_eq!(h, h1, "First device buffer mismatch");
        assert_eq!(h, h2, "Second device buffer mismatch");
    }

    /// Exercises both device-to-device entry points. All copies run on the same per-thread
    /// stream, so each `to_host` is ordered after the copy it checks.
    #[test]
    fn test_device_copy_round_trip() {
        let h: Vec<i32> = vec![7, -1, 0, 42, 9, 3];
        let d = h.to_device().unwrap();

        // Device -> device into a freshly allocated buffer.
        let fresh = d.device_copy().unwrap();
        assert_eq!(fresh.to_host().unwrap(), h, "device_copy mismatch");

        // Device -> device into an existing buffer of the same length.
        let mut existing = DeviceBuffer::<i32>::with_capacity(h.len());
        d.device_copy_to(&mut existing).unwrap();
        assert_eq!(existing.to_host().unwrap(), h, "device_copy_to mismatch");
    }

    /// A host -> device copy into a shorter buffer must be rejected, not overrun.
    #[test]
    fn test_copy_to_rejects_short_destination() {
        let h: Vec<i32> = vec![1, 2, 3];
        let mut short = DeviceBuffer::<i32>::with_capacity(h.len() - 1);
        let err = h.copy_to(&mut short).unwrap_err();
        assert!(matches!(
            err,
            MemCopyError::SizeMismatch {
                host_len: 3,
                device_len: 2,
                ..
            }
        ));
    }
}