openvm_cuda_common/
stream.rs

1use std::{borrow::Cow, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7    fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
8    fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
9    fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
10    fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
11    fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
12    fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
13    fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
14    fn cudaEventDestroy(event: cudaEvent_t) -> i32;
15    fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
16}
17
/// Raw, untyped CUDA stream handle (`cudaStream_t` in the CUDA runtime API).
// The CUDA C spelling is kept so the FFI signatures read like the C headers.
#[allow(non_camel_case_types)]
pub type cudaStream_t = *mut c_void;
20
21pub struct CudaStream {
22    stream: cudaStream_t,
23}
24
25unsafe impl Send for CudaStream {}
26unsafe impl Sync for CudaStream {}
27
28impl CudaStream {
29    /// Creates a new non-blocking CUDA stream.
30    pub fn new() -> Result<Self, CudaError> {
31        let mut stream: cudaStream_t = std::ptr::null_mut();
32        check(unsafe { cudaStreamCreate(&mut stream) })?;
33        Ok(Self { stream })
34    }
35
36    /// Get the raw CUDA stream handle.
37    #[inline]
38    pub fn as_raw(&self) -> cudaStream_t {
39        self.stream
40    }
41
42    /// Synchronize this stream.
43    pub fn synchronize(&self) -> Result<(), CudaError> {
44        check(unsafe { cudaStreamSynchronize(self.stream) })
45    }
46
47    /// Wait for the given event.
48    pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
49        check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
50    }
51}
52
53impl Drop for CudaStream {
54    fn drop(&mut self) {
55        if !self.stream.is_null() {
56            self.synchronize().unwrap();
57            let _ = unsafe { cudaStreamDestroy(self.stream) };
58            self.stream = std::ptr::null_mut();
59        }
60    }
61}
62
63#[allow(non_camel_case_types)]
64pub type cudaEvent_t = *mut c_void;
65#[allow(non_upper_case_globals)]
66pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;
67
68pub fn default_stream_sync() -> Result<(), CudaError> {
69    check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
70}
71
72pub struct CudaEvent {
73    event: cudaEvent_t,
74}
75
76pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
77    check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
78}
79
80unsafe impl Send for CudaEvent {}
81unsafe impl Sync for CudaEvent {}
82
83impl CudaEvent {
84    pub fn new() -> Result<Self, CudaError> {
85        let mut event: cudaEvent_t = std::ptr::null_mut();
86        check(unsafe { cudaEventCreate(&mut event) })?;
87        Ok(Self { event })
88    }
89
90    /// # Safety
91    /// The caller must ensure that `stream` is a valid stream.
92    pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
93        check(cudaEventRecord(self.event, stream))
94    }
95
96    /// # Safety
97    /// The caller must ensure that `stream` is a valid stream.
98    pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
99        self.record(stream)?;
100        check(cudaEventSynchronize(self.event))
101    }
102}
103
104impl Drop for CudaEvent {
105    fn drop(&mut self) {
106        unsafe { cudaEventDestroy(self.event) };
107    }
108}
109
110/// A GPU-aware span that collects a gauge metric using CUDA events.
111pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
112    name: impl Into<Cow<'static, str>>,
113    f: F,
114) -> Result<R, CudaError> {
115    let start = CudaEvent::new()?;
116    let stop = CudaEvent::new()?;
117    unsafe {
118        check(cudaEventRecord(start.event, cudaStreamPerThread))?;
119    }
120    let res = f();
121    unsafe { stop.record_and_wait(cudaStreamPerThread)? };
122
123    let mut elapsed_ms = 0f32;
124    unsafe {
125        check(cudaEventElapsedTime(
126            &mut elapsed_ms,
127            start.event,
128            stop.event,
129        ))?
130    };
131
132    metrics::gauge!(name.into()).set(elapsed_ms as f64);
133    Ok(res)
134}