openvm_cuda_common/
stream.rs

1use std::{borrow::Cow, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7    fn cudaDeviceSynchronize() -> i32;
8    fn cudaStreamGetId(stream: cudaStream_t, id: *mut CudaStreamId) -> i32;
9    fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
10    fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
11    fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
12    fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
13    fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
14    fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
15    fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
16    fn cudaEventQuery(event: cudaEvent_t) -> i32;
17    fn cudaEventDestroy(event: cudaEvent_t) -> i32;
18    fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
19}
20
21pub fn device_synchronize() -> Result<(), CudaError> {
22    check(unsafe { cudaDeviceSynchronize() })
23}
24
25#[allow(non_camel_case_types)]
26pub type cudaStream_t = *mut c_void;
27
28pub struct CudaStream {
29    stream: cudaStream_t,
30}
31
32unsafe impl Send for CudaStream {}
33unsafe impl Sync for CudaStream {}
34
35impl CudaStream {
36    /// Creates a new non-blocking CUDA stream.
37    pub fn new() -> Result<Self, CudaError> {
38        let mut stream: cudaStream_t = std::ptr::null_mut();
39        check(unsafe { cudaStreamCreate(&mut stream) })?;
40        Ok(Self { stream })
41    }
42
43    /// Get the raw CUDA stream handle.
44    #[inline]
45    pub fn as_raw(&self) -> cudaStream_t {
46        self.stream
47    }
48
49    /// Synchronize this stream.
50    pub fn synchronize(&self) -> Result<(), CudaError> {
51        check(unsafe { cudaStreamSynchronize(self.stream) })
52    }
53
54    /// Wait for the given event.
55    pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
56        check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
57    }
58}
59
60impl Drop for CudaStream {
61    fn drop(&mut self) {
62        if !self.stream.is_null() {
63            self.synchronize().unwrap();
64            let _ = unsafe { cudaStreamDestroy(self.stream) };
65            self.stream = std::ptr::null_mut();
66        }
67    }
68}
69
/// Stream handle with the fixed value `0x02`; this file passes it wherever it targets the
/// per-thread stream (see `current_stream_id`, `current_stream_sync`).
#[allow(non_upper_case_globals)]
pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;

/// Integer stream identifier, as written by `cudaStreamGetId`.
pub type CudaStreamId = u64;
74
75pub fn current_stream_id() -> Result<CudaStreamId, CudaError> {
76    let mut id = 0;
77    check(unsafe { cudaStreamGetId(cudaStreamPerThread, &mut id) })?;
78    Ok(id)
79}
80
81pub fn current_stream_sync() -> Result<(), CudaError> {
82    check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
83}
84
/// Raw CUDA event handle as it crosses the FFI boundary.
#[allow(non_camel_case_types)]
pub type cudaEvent_t = *mut c_void;
87
88#[derive(Debug)]
89pub enum CudaEventStatus {
90    Completed,
91    NotReady,
92    Error(CudaError),
93}
94
95impl PartialEq for CudaEventStatus {
96    fn eq(&self, other: &Self) -> bool {
97        use CudaEventStatus::*;
98        matches!((self, other), (Completed, Completed) | (NotReady, NotReady))
99    }
100}
101
102impl Eq for CudaEventStatus {}
103
104impl PartialOrd for CudaEventStatus {
105    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
106        Some(self.cmp(other))
107    }
108}
109
110// Completed < NotReady < Error
111impl Ord for CudaEventStatus {
112    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
113        use std::cmp::Ordering;
114
115        use CudaEventStatus::*;
116
117        match (self, other) {
118            (Completed, Completed) => Ordering::Equal,
119            (Completed, _) => Ordering::Less,
120            (_, Completed) => Ordering::Greater,
121            (NotReady, NotReady) => Ordering::Equal,
122            (NotReady, Error(_)) => Ordering::Less,
123            (Error(_), NotReady) => Ordering::Greater,
124            (Error(_), Error(_)) => Ordering::Equal,
125        }
126    }
127}
128
129#[derive(Debug, Clone)]
130pub struct CudaEvent {
131    event: cudaEvent_t,
132}
133
134pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
135    check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
136}
137
138unsafe impl Send for CudaEvent {}
139unsafe impl Sync for CudaEvent {}
140
141impl CudaEvent {
142    pub fn new() -> Result<Self, CudaError> {
143        let mut event: cudaEvent_t = std::ptr::null_mut();
144        check(unsafe { cudaEventCreate(&mut event) })?;
145        Ok(Self { event })
146    }
147
148    /// # Safety
149    /// The caller must ensure that `stream` is a valid stream.
150    pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
151        check(cudaEventRecord(self.event, stream))
152    }
153
154    pub fn record_on_this(&self) -> Result<(), CudaError> {
155        check(unsafe { cudaEventRecord(self.event, cudaStreamPerThread) })
156    }
157
158    pub fn synchronize(&self) -> Result<(), CudaError> {
159        check(unsafe { cudaEventSynchronize(self.event) })
160    }
161
162    /// # Safety
163    /// The caller must ensure that `stream` is a valid stream.
164    pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
165        self.record(stream)?;
166        check(cudaEventSynchronize(self.event))
167    }
168
169    pub fn status(&self) -> CudaEventStatus {
170        let status = unsafe { cudaEventQuery(self.event) };
171        match status {
172            0 => CudaEventStatus::Completed,  // CUDA_SUCCESS
173            600 => CudaEventStatus::NotReady, // CUDA_ERROR_NOT_READY
174            _ => CudaEventStatus::Error(CudaError::new(status)),
175        }
176    }
177
178    pub fn completed(&self) -> bool {
179        self.status() == CudaEventStatus::Completed
180    }
181}
182
183impl Drop for CudaEvent {
184    fn drop(&mut self) {
185        unsafe { cudaEventDestroy(self.event) };
186    }
187}
188
189/// A GPU-aware span that collects a gauge metric using CUDA events.
190pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
191    name: impl Into<Cow<'static, str>>,
192    f: F,
193) -> Result<R, CudaError> {
194    let start = CudaEvent::new()?;
195    let stop = CudaEvent::new()?;
196    unsafe {
197        check(cudaEventRecord(start.event, cudaStreamPerThread))?;
198    }
199    let res = f();
200    unsafe { stop.record_and_wait(cudaStreamPerThread)? };
201
202    let mut elapsed_ms = 0f32;
203    unsafe {
204        check(cudaEventElapsedTime(
205            &mut elapsed_ms,
206            start.event,
207            stop.event,
208        ))?
209    };
210
211    metrics::gauge!(name.into()).set(elapsed_ms as f64);
212    Ok(res)
213}