openvm_cuda_common/
stream.rs1use std::{borrow::Cow, cell::Cell, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7 fn cudaDeviceSynchronize() -> i32;
8 fn cudaStreamGetId(stream: cudaStream_t, id: *mut CudaStreamId) -> i32;
9 fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
10 fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
11 fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
12 fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
13 fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
14 fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
15 fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
16 fn cudaEventQuery(event: cudaEvent_t) -> i32;
17 fn cudaEventDestroy(event: cudaEvent_t) -> i32;
18 fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
19}
20
21pub fn device_synchronize() -> Result<(), CudaError> {
22 check(unsafe { cudaDeviceSynchronize() })
23}
24
/// Raw stream handle exchanged with the CUDA runtime bindings in this file.
// The lint is silenced so the alias can keep the spelling used by the C API.
#[allow(non_camel_case_types)]
pub type cudaStream_t = *mut c_void;
27
28pub struct CudaStream {
29 stream: cudaStream_t,
30}
31
32unsafe impl Send for CudaStream {}
33unsafe impl Sync for CudaStream {}
34
35impl CudaStream {
36 pub fn new() -> Result<Self, CudaError> {
38 let mut stream: cudaStream_t = std::ptr::null_mut();
39 check(unsafe { cudaStreamCreate(&mut stream) })?;
40 Ok(Self { stream })
41 }
42
43 #[inline]
45 pub fn as_raw(&self) -> cudaStream_t {
46 self.stream
47 }
48
49 pub fn synchronize(&self) -> Result<(), CudaError> {
51 check(unsafe { cudaStreamSynchronize(self.stream) })
52 }
53
54 pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
56 check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
57 }
58}
59
60impl Drop for CudaStream {
61 fn drop(&mut self) {
62 if !self.stream.is_null() {
63 self.synchronize().unwrap();
64 let _ = unsafe { cudaStreamDestroy(self.stream) };
65 self.stream = std::ptr::null_mut();
66 }
67 }
68}
69
70#[allow(non_upper_case_globals)]
71pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;
72
/// Identifier written by `cudaStreamGetId` for a stream.
pub type CudaStreamId = u64;
74
75pub fn current_stream_id() -> Result<CudaStreamId, CudaError> {
76 let mut id = 0;
77 check(unsafe { cudaStreamGetId(cudaStreamPerThread, &mut id) })?;
78 Ok(id)
79}
80
81pub fn current_stream_sync() -> Result<(), CudaError> {
82 check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
83}
84
/// Records whether a thread has been marked as having used CUDA.
struct CudaThreadCleanup {
    /// Starts `false` and is switched to `true` by [`CudaThreadCleanup::mark_used`].
    touched_cuda: Cell<bool>,
}

impl CudaThreadCleanup {
    /// Builds a tracker whose flag is cleared; `const` so it can be used in constant initializers.
    const fn new() -> Self {
        let touched_cuda = Cell::new(false);
        Self { touched_cuda }
    }

    /// Sets the flag. Calling this more than once has no further effect.
    fn mark_used(&self) {
        self.touched_cuda.set(true);
    }
}
100
101impl Drop for CudaThreadCleanup {
102 fn drop(&mut self) {
103 if self.touched_cuda.get() {
104 let _ = unsafe { cudaStreamSynchronize(cudaStreamPerThread) };
107 }
108 }
109}
110
111thread_local! {
112 static CUDA_THREAD_CLEANUP: CudaThreadCleanup = const { CudaThreadCleanup::new() };
113}
114
115pub(crate) fn mark_cuda_thread_used() {
116 CUDA_THREAD_CLEANUP.with(|cleanup| cleanup.mark_used());
117}
118
/// Raw event handle exchanged with the CUDA runtime bindings in this file.
// The lint is silenced so the alias can keep the spelling used by the C API.
#[allow(non_camel_case_types)]
pub type cudaEvent_t = *mut c_void;
121
122#[derive(Debug)]
123pub enum CudaEventStatus {
124 Completed,
125 NotReady,
126 Error(CudaError),
127}
128
129impl PartialEq for CudaEventStatus {
130 fn eq(&self, other: &Self) -> bool {
131 use CudaEventStatus::*;
132 matches!((self, other), (Completed, Completed) | (NotReady, NotReady))
133 }
134}
135
136impl Eq for CudaEventStatus {}
137
138impl PartialOrd for CudaEventStatus {
139 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
140 Some(self.cmp(other))
141 }
142}
143
144impl Ord for CudaEventStatus {
146 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
147 use std::cmp::Ordering;
148
149 use CudaEventStatus::*;
150
151 match (self, other) {
152 (Completed, Completed) => Ordering::Equal,
153 (Completed, _) => Ordering::Less,
154 (_, Completed) => Ordering::Greater,
155 (NotReady, NotReady) => Ordering::Equal,
156 (NotReady, Error(_)) => Ordering::Less,
157 (Error(_), NotReady) => Ordering::Greater,
158 (Error(_), Error(_)) => Ordering::Equal,
159 }
160 }
161}
162
163#[derive(Debug, Clone)]
164pub struct CudaEvent {
165 event: cudaEvent_t,
166}
167
168pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
169 check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
170}
171
172unsafe impl Send for CudaEvent {}
173unsafe impl Sync for CudaEvent {}
174
175impl CudaEvent {
176 pub fn new() -> Result<Self, CudaError> {
177 let mut event: cudaEvent_t = std::ptr::null_mut();
178 check(unsafe { cudaEventCreate(&mut event) })?;
179 Ok(Self { event })
180 }
181
182 pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
185 check(cudaEventRecord(self.event, stream))
186 }
187
188 pub fn record_on_this(&self) -> Result<(), CudaError> {
189 check(unsafe { cudaEventRecord(self.event, cudaStreamPerThread) })
190 }
191
192 pub fn synchronize(&self) -> Result<(), CudaError> {
193 check(unsafe { cudaEventSynchronize(self.event) })
194 }
195
196 pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
199 self.record(stream)?;
200 check(cudaEventSynchronize(self.event))
201 }
202
203 pub fn status(&self) -> CudaEventStatus {
204 let status = unsafe { cudaEventQuery(self.event) };
205 match status {
206 0 => CudaEventStatus::Completed, 600 => CudaEventStatus::NotReady, _ => CudaEventStatus::Error(CudaError::new(status)),
209 }
210 }
211
212 pub fn completed(&self) -> bool {
213 self.status() == CudaEventStatus::Completed
214 }
215}
216
217impl Drop for CudaEvent {
218 fn drop(&mut self) {
219 unsafe { cudaEventDestroy(self.event) };
220 }
221}
222
223pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
225 name: impl Into<Cow<'static, str>>,
226 f: F,
227) -> Result<R, CudaError> {
228 let start = CudaEvent::new()?;
229 let stop = CudaEvent::new()?;
230 unsafe {
231 check(cudaEventRecord(start.event, cudaStreamPerThread))?;
232 }
233 let res = f();
234 unsafe { stop.record_and_wait(cudaStreamPerThread)? };
235
236 let mut elapsed_ms = 0f32;
237 unsafe {
238 check(cudaEventElapsedTime(
239 &mut elapsed_ms,
240 start.event,
241 stop.event,
242 ))?
243 };
244
245 metrics::gauge!(name.into()).set(elapsed_ms as f64);
246 Ok(res)
247}