openvm_cuda_common/
stream.rs

1use std::{borrow::Cow, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7    fn cudaStreamGetId(stream: cudaStream_t, id: *mut CudaStreamId) -> i32;
8    fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
9    fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
10    fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
11    fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
12    fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
13    fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
14    fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
15    fn cudaEventQuery(event: cudaEvent_t) -> i32;
16    fn cudaEventDestroy(event: cudaEvent_t) -> i32;
17    fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
18}
19
20#[allow(non_camel_case_types)]
21pub type cudaStream_t = *mut c_void;
22
23pub struct CudaStream {
24    stream: cudaStream_t,
25}
26
27unsafe impl Send for CudaStream {}
28unsafe impl Sync for CudaStream {}
29
30impl CudaStream {
31    /// Creates a new non-blocking CUDA stream.
32    pub fn new() -> Result<Self, CudaError> {
33        let mut stream: cudaStream_t = std::ptr::null_mut();
34        check(unsafe { cudaStreamCreate(&mut stream) })?;
35        Ok(Self { stream })
36    }
37
38    /// Get the raw CUDA stream handle.
39    #[inline]
40    pub fn as_raw(&self) -> cudaStream_t {
41        self.stream
42    }
43
44    /// Synchronize this stream.
45    pub fn synchronize(&self) -> Result<(), CudaError> {
46        check(unsafe { cudaStreamSynchronize(self.stream) })
47    }
48
49    /// Wait for the given event.
50    pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
51        check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
52    }
53}
54
55impl Drop for CudaStream {
56    fn drop(&mut self) {
57        if !self.stream.is_null() {
58            self.synchronize().unwrap();
59            let _ = unsafe { cudaStreamDestroy(self.stream) };
60            self.stream = std::ptr::null_mut();
61        }
62    }
63}
64
65#[allow(non_upper_case_globals)]
66pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;
67
68pub type CudaStreamId = u64;
69
70pub fn current_stream_id() -> Result<CudaStreamId, CudaError> {
71    let mut id = 0;
72    check(unsafe { cudaStreamGetId(cudaStreamPerThread, &mut id) })?;
73    Ok(id)
74}
75
76pub fn current_stream_sync() -> Result<(), CudaError> {
77    check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
78}
79
/// Raw CUDA runtime event handle (`cudaEvent_t` in C).
// Keeps the CUDA C spelling so FFI signatures read like the C API.
#[allow(non_camel_case_types)]
pub type cudaEvent_t = *mut std::ffi::c_void;
82
83#[derive(Debug)]
84pub enum CudaEventStatus {
85    Completed,
86    NotReady,
87    Error(CudaError),
88}
89
90impl PartialEq for CudaEventStatus {
91    fn eq(&self, other: &Self) -> bool {
92        use CudaEventStatus::*;
93        matches!((self, other), (Completed, Completed) | (NotReady, NotReady))
94    }
95}
96
97impl Eq for CudaEventStatus {}
98
99impl PartialOrd for CudaEventStatus {
100    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
101        Some(self.cmp(other))
102    }
103}
104
105// Completed < NotReady < Error
106impl Ord for CudaEventStatus {
107    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
108        use std::cmp::Ordering;
109
110        use CudaEventStatus::*;
111
112        match (self, other) {
113            (Completed, Completed) => Ordering::Equal,
114            (Completed, _) => Ordering::Less,
115            (_, Completed) => Ordering::Greater,
116            (NotReady, NotReady) => Ordering::Equal,
117            (NotReady, Error(_)) => Ordering::Less,
118            (Error(_), NotReady) => Ordering::Greater,
119            (Error(_), Error(_)) => Ordering::Equal,
120        }
121    }
122}
123
124#[derive(Debug, Clone)]
125pub struct CudaEvent {
126    event: cudaEvent_t,
127}
128
129pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
130    check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
131}
132
133unsafe impl Send for CudaEvent {}
134unsafe impl Sync for CudaEvent {}
135
136impl CudaEvent {
137    pub fn new() -> Result<Self, CudaError> {
138        let mut event: cudaEvent_t = std::ptr::null_mut();
139        check(unsafe { cudaEventCreate(&mut event) })?;
140        Ok(Self { event })
141    }
142
143    /// # Safety
144    /// The caller must ensure that `stream` is a valid stream.
145    pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
146        check(cudaEventRecord(self.event, stream))
147    }
148
149    pub fn record_on_this(&self) -> Result<(), CudaError> {
150        check(unsafe { cudaEventRecord(self.event, cudaStreamPerThread) })
151    }
152
153    /// # Safety
154    /// The caller must ensure that `stream` is a valid stream.
155    pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
156        self.record(stream)?;
157        check(cudaEventSynchronize(self.event))
158    }
159
160    pub fn status(&self) -> CudaEventStatus {
161        let status = unsafe { cudaEventQuery(self.event) };
162        match status {
163            0 => CudaEventStatus::Completed,  // CUDA_SUCCESS
164            600 => CudaEventStatus::NotReady, // CUDA_ERROR_NOT_READY
165            _ => CudaEventStatus::Error(CudaError::new(status)),
166        }
167    }
168
169    pub fn completed(&self) -> bool {
170        self.status() == CudaEventStatus::Completed
171    }
172}
173
174impl Drop for CudaEvent {
175    fn drop(&mut self) {
176        unsafe { cudaEventDestroy(self.event) };
177    }
178}
179
180/// A GPU-aware span that collects a gauge metric using CUDA events.
181pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
182    name: impl Into<Cow<'static, str>>,
183    f: F,
184) -> Result<R, CudaError> {
185    let start = CudaEvent::new()?;
186    let stop = CudaEvent::new()?;
187    unsafe {
188        check(cudaEventRecord(start.event, cudaStreamPerThread))?;
189    }
190    let res = f();
191    unsafe { stop.record_and_wait(cudaStreamPerThread)? };
192
193    let mut elapsed_ms = 0f32;
194    unsafe {
195        check(cudaEventElapsedTime(
196            &mut elapsed_ms,
197            start.event,
198            stop.event,
199        ))?
200    };
201
202    metrics::gauge!(name.into()).set(elapsed_ms as f64);
203    Ok(res)
204}