openvm_cuda_common/
stream.rs1use std::{borrow::Cow, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7 fn cudaStreamGetId(stream: cudaStream_t, id: *mut CudaStreamId) -> i32;
8 fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
9 fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
10 fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
11 fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
12 fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
13 fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
14 fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
15 fn cudaEventQuery(event: cudaEvent_t) -> i32;
16 fn cudaEventDestroy(event: cudaEvent_t) -> i32;
17 fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
18}
19
/// Raw CUDA stream handle as it crosses the FFI boundary: an opaque pointer.
#[allow(non_camel_case_types)]
pub type cudaStream_t = *mut c_void;

/// Wrapper that holds a single raw [`cudaStream_t`] handle.
pub struct CudaStream {
    stream: cudaStream_t,
}

// SAFETY: NOTE(review): a raw pointer field is neither `Send` nor `Sync` by
// default; these impls assert that the wrapped handle may be moved to and
// shared between threads. Presumably this relies on the CUDA runtime treating
// stream handles as thread-safe — confirm.
unsafe impl Sync for CudaStream {}
unsafe impl Send for CudaStream {}
29
30impl CudaStream {
31 pub fn new() -> Result<Self, CudaError> {
33 let mut stream: cudaStream_t = std::ptr::null_mut();
34 check(unsafe { cudaStreamCreate(&mut stream) })?;
35 Ok(Self { stream })
36 }
37
38 #[inline]
40 pub fn as_raw(&self) -> cudaStream_t {
41 self.stream
42 }
43
44 pub fn synchronize(&self) -> Result<(), CudaError> {
46 check(unsafe { cudaStreamSynchronize(self.stream) })
47 }
48
49 pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
51 check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
52 }
53}
54
55impl Drop for CudaStream {
56 fn drop(&mut self) {
57 if !self.stream.is_null() {
58 self.synchronize().unwrap();
59 let _ = unsafe { cudaStreamDestroy(self.stream) };
60 self.stream = std::ptr::null_mut();
61 }
62 }
63}
64
65#[allow(non_upper_case_globals)]
66pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;
67
68pub type CudaStreamId = u64;
69
70pub fn current_stream_id() -> Result<CudaStreamId, CudaError> {
71 let mut id = 0;
72 check(unsafe { cudaStreamGetId(cudaStreamPerThread, &mut id) })?;
73 Ok(id)
74}
75
76pub fn current_stream_sync() -> Result<(), CudaError> {
77 check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
78}
79
/// Raw CUDA event handle as it crosses the FFI boundary: an opaque pointer.
#[allow(non_camel_case_types)]
pub type cudaEvent_t = *mut c_void;
82
83#[derive(Debug)]
84pub enum CudaEventStatus {
85 Completed,
86 NotReady,
87 Error(CudaError),
88}
89
90impl PartialEq for CudaEventStatus {
91 fn eq(&self, other: &Self) -> bool {
92 use CudaEventStatus::*;
93 matches!((self, other), (Completed, Completed) | (NotReady, NotReady))
94 }
95}
96
97impl Eq for CudaEventStatus {}
98
99impl PartialOrd for CudaEventStatus {
100 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
101 Some(self.cmp(other))
102 }
103}
104
105impl Ord for CudaEventStatus {
107 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
108 use std::cmp::Ordering;
109
110 use CudaEventStatus::*;
111
112 match (self, other) {
113 (Completed, Completed) => Ordering::Equal,
114 (Completed, _) => Ordering::Less,
115 (_, Completed) => Ordering::Greater,
116 (NotReady, NotReady) => Ordering::Equal,
117 (NotReady, Error(_)) => Ordering::Less,
118 (Error(_), NotReady) => Ordering::Greater,
119 (Error(_), Error(_)) => Ordering::Equal,
120 }
121 }
122}
123
124#[derive(Debug, Clone)]
125pub struct CudaEvent {
126 event: cudaEvent_t,
127}
128
129pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
130 check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
131}
132
133unsafe impl Send for CudaEvent {}
134unsafe impl Sync for CudaEvent {}
135
136impl CudaEvent {
137 pub fn new() -> Result<Self, CudaError> {
138 let mut event: cudaEvent_t = std::ptr::null_mut();
139 check(unsafe { cudaEventCreate(&mut event) })?;
140 Ok(Self { event })
141 }
142
143 pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
146 check(cudaEventRecord(self.event, stream))
147 }
148
149 pub fn record_on_this(&self) -> Result<(), CudaError> {
150 check(unsafe { cudaEventRecord(self.event, cudaStreamPerThread) })
151 }
152
153 pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
156 self.record(stream)?;
157 check(cudaEventSynchronize(self.event))
158 }
159
160 pub fn status(&self) -> CudaEventStatus {
161 let status = unsafe { cudaEventQuery(self.event) };
162 match status {
163 0 => CudaEventStatus::Completed, 600 => CudaEventStatus::NotReady, _ => CudaEventStatus::Error(CudaError::new(status)),
166 }
167 }
168
169 pub fn completed(&self) -> bool {
170 self.status() == CudaEventStatus::Completed
171 }
172}
173
174impl Drop for CudaEvent {
175 fn drop(&mut self) {
176 unsafe { cudaEventDestroy(self.event) };
177 }
178}
179
180pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
182 name: impl Into<Cow<'static, str>>,
183 f: F,
184) -> Result<R, CudaError> {
185 let start = CudaEvent::new()?;
186 let stop = CudaEvent::new()?;
187 unsafe {
188 check(cudaEventRecord(start.event, cudaStreamPerThread))?;
189 }
190 let res = f();
191 unsafe { stop.record_and_wait(cudaStreamPerThread)? };
192
193 let mut elapsed_ms = 0f32;
194 unsafe {
195 check(cudaEventElapsedTime(
196 &mut elapsed_ms,
197 start.event,
198 stop.event,
199 ))?
200 };
201
202 metrics::gauge!(name.into()).set(elapsed_ms as f64);
203 Ok(res)
204}