openvm_cuda_common/
stream.rs

1use std::{borrow::Cow, cell::Cell, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7    fn cudaDeviceSynchronize() -> i32;
8    fn cudaStreamGetId(stream: cudaStream_t, id: *mut CudaStreamId) -> i32;
9    fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
10    fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
11    fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
12    fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
13    fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
14    fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
15    fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
16    fn cudaEventQuery(event: cudaEvent_t) -> i32;
17    fn cudaEventDestroy(event: cudaEvent_t) -> i32;
18    fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
19}
20
21pub fn device_synchronize() -> Result<(), CudaError> {
22    check(unsafe { cudaDeviceSynchronize() })
23}
24
/// Opaque raw CUDA stream handle, passed by value across the FFI boundary.
// The name mirrors the C API spelling, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudaStream_t = *mut c_void;
27
28pub struct CudaStream {
29    stream: cudaStream_t,
30}
31
32unsafe impl Send for CudaStream {}
33unsafe impl Sync for CudaStream {}
34
35impl CudaStream {
36    /// Creates a new non-blocking CUDA stream.
37    pub fn new() -> Result<Self, CudaError> {
38        let mut stream: cudaStream_t = std::ptr::null_mut();
39        check(unsafe { cudaStreamCreate(&mut stream) })?;
40        Ok(Self { stream })
41    }
42
43    /// Get the raw CUDA stream handle.
44    #[inline]
45    pub fn as_raw(&self) -> cudaStream_t {
46        self.stream
47    }
48
49    /// Synchronize this stream.
50    pub fn synchronize(&self) -> Result<(), CudaError> {
51        check(unsafe { cudaStreamSynchronize(self.stream) })
52    }
53
54    /// Wait for the given event.
55    pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
56        check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
57    }
58}
59
60impl Drop for CudaStream {
61    fn drop(&mut self) {
62        if !self.stream.is_null() {
63            self.synchronize().unwrap();
64            let _ = unsafe { cudaStreamDestroy(self.stream) };
65            self.stream = std::ptr::null_mut();
66        }
67    }
68}
69
70#[allow(non_upper_case_globals)]
71pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;
72
/// Numeric stream identifier as written by `cudaStreamGetId`.
pub type CudaStreamId = u64;
74
75pub fn current_stream_id() -> Result<CudaStreamId, CudaError> {
76    let mut id = 0;
77    check(unsafe { cudaStreamGetId(cudaStreamPerThread, &mut id) })?;
78    Ok(id)
79}
80
81pub fn current_stream_sync() -> Result<(), CudaError> {
82    check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
83}
84
/// Thread-local guard that remembers whether the current thread has been marked as a
/// CUDA user, so its `Drop` impl knows whether to drain the per-thread stream.
struct CudaThreadCleanup {
    /// `true` once `mark_used` has been called on this thread's instance.
    touched_cuda: Cell<bool>,
}
88
89impl CudaThreadCleanup {
90    const fn new() -> Self {
91        Self {
92            touched_cuda: Cell::new(false),
93        }
94    }
95
96    fn mark_used(&self) {
97        self.touched_cuda.set(true);
98    }
99}
100
101impl Drop for CudaThreadCleanup {
102    fn drop(&mut self) {
103        if self.touched_cuda.get() {
104            // Best-effort drain of this thread's default stream before CUDA TLS teardown.
105            // Avoid calling `check` here to prevent re-entering TLS tracking during drop.
106            let _ = unsafe { cudaStreamSynchronize(cudaStreamPerThread) };
107        }
108    }
109}
110
111thread_local! {
112    static CUDA_THREAD_CLEANUP: CudaThreadCleanup = const { CudaThreadCleanup::new() };
113}
114
115pub(crate) fn mark_cuda_thread_used() {
116    CUDA_THREAD_CLEANUP.with(|cleanup| cleanup.mark_used());
117}
118
/// Opaque raw CUDA event handle, passed by value across the FFI boundary.
// The name mirrors the C API spelling, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudaEvent_t = *mut c_void;
121
122#[derive(Debug)]
123pub enum CudaEventStatus {
124    Completed,
125    NotReady,
126    Error(CudaError),
127}
128
129impl PartialEq for CudaEventStatus {
130    fn eq(&self, other: &Self) -> bool {
131        use CudaEventStatus::*;
132        matches!((self, other), (Completed, Completed) | (NotReady, NotReady))
133    }
134}
135
136impl Eq for CudaEventStatus {}
137
138impl PartialOrd for CudaEventStatus {
139    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
140        Some(self.cmp(other))
141    }
142}
143
144// Completed < NotReady < Error
145impl Ord for CudaEventStatus {
146    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
147        use std::cmp::Ordering;
148
149        use CudaEventStatus::*;
150
151        match (self, other) {
152            (Completed, Completed) => Ordering::Equal,
153            (Completed, _) => Ordering::Less,
154            (_, Completed) => Ordering::Greater,
155            (NotReady, NotReady) => Ordering::Equal,
156            (NotReady, Error(_)) => Ordering::Less,
157            (Error(_), NotReady) => Ordering::Greater,
158            (Error(_), Error(_)) => Ordering::Equal,
159        }
160    }
161}
162
163#[derive(Debug, Clone)]
164pub struct CudaEvent {
165    event: cudaEvent_t,
166}
167
168pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
169    check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
170}
171
172unsafe impl Send for CudaEvent {}
173unsafe impl Sync for CudaEvent {}
174
175impl CudaEvent {
176    pub fn new() -> Result<Self, CudaError> {
177        let mut event: cudaEvent_t = std::ptr::null_mut();
178        check(unsafe { cudaEventCreate(&mut event) })?;
179        Ok(Self { event })
180    }
181
182    /// # Safety
183    /// The caller must ensure that `stream` is a valid stream.
184    pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
185        check(cudaEventRecord(self.event, stream))
186    }
187
188    pub fn record_on_this(&self) -> Result<(), CudaError> {
189        check(unsafe { cudaEventRecord(self.event, cudaStreamPerThread) })
190    }
191
192    pub fn synchronize(&self) -> Result<(), CudaError> {
193        check(unsafe { cudaEventSynchronize(self.event) })
194    }
195
196    /// # Safety
197    /// The caller must ensure that `stream` is a valid stream.
198    pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
199        self.record(stream)?;
200        check(cudaEventSynchronize(self.event))
201    }
202
203    pub fn status(&self) -> CudaEventStatus {
204        let status = unsafe { cudaEventQuery(self.event) };
205        match status {
206            0 => CudaEventStatus::Completed,  // CUDA_SUCCESS
207            600 => CudaEventStatus::NotReady, // CUDA_ERROR_NOT_READY
208            _ => CudaEventStatus::Error(CudaError::new(status)),
209        }
210    }
211
212    pub fn completed(&self) -> bool {
213        self.status() == CudaEventStatus::Completed
214    }
215}
216
217impl Drop for CudaEvent {
218    fn drop(&mut self) {
219        unsafe { cudaEventDestroy(self.event) };
220    }
221}
222
223/// A GPU-aware span that collects a gauge metric using CUDA events.
224pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
225    name: impl Into<Cow<'static, str>>,
226    f: F,
227) -> Result<R, CudaError> {
228    let start = CudaEvent::new()?;
229    let stop = CudaEvent::new()?;
230    unsafe {
231        check(cudaEventRecord(start.event, cudaStreamPerThread))?;
232    }
233    let res = f();
234    unsafe { stop.record_and_wait(cudaStreamPerThread)? };
235
236    let mut elapsed_ms = 0f32;
237    unsafe {
238        check(cudaEventElapsedTime(
239            &mut elapsed_ms,
240            start.event,
241            stop.event,
242        ))?
243    };
244
245    metrics::gauge!(name.into()).set(elapsed_ms as f64);
246    Ok(res)
247}