openvm_cuda_common/
stream.rs1use std::{borrow::Cow, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7 fn cudaDeviceSynchronize() -> i32;
8 fn cudaStreamGetId(stream: cudaStream_t, id: *mut CudaStreamId) -> i32;
9 fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
10 fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
11 fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
12 fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
13 fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
14 fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
15 fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
16 fn cudaEventQuery(event: cudaEvent_t) -> i32;
17 fn cudaEventDestroy(event: cudaEvent_t) -> i32;
18 fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
19}
20
21pub fn device_synchronize() -> Result<(), CudaError> {
22 check(unsafe { cudaDeviceSynchronize() })
23}
24
/// Opaque CUDA runtime stream handle, exactly as exchanged with `cudart`.
// Named after the C typedef so the FFI signatures read like the CUDA headers.
#[allow(non_camel_case_types)]
pub type cudaStream_t = *mut std::ffi::c_void;
27
28pub struct CudaStream {
29 stream: cudaStream_t,
30}
31
32unsafe impl Send for CudaStream {}
33unsafe impl Sync for CudaStream {}
34
35impl CudaStream {
36 pub fn new() -> Result<Self, CudaError> {
38 let mut stream: cudaStream_t = std::ptr::null_mut();
39 check(unsafe { cudaStreamCreate(&mut stream) })?;
40 Ok(Self { stream })
41 }
42
43 #[inline]
45 pub fn as_raw(&self) -> cudaStream_t {
46 self.stream
47 }
48
49 pub fn synchronize(&self) -> Result<(), CudaError> {
51 check(unsafe { cudaStreamSynchronize(self.stream) })
52 }
53
54 pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
56 check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
57 }
58}
59
60impl Drop for CudaStream {
61 fn drop(&mut self) {
62 if !self.stream.is_null() {
63 self.synchronize().unwrap();
64 let _ = unsafe { cudaStreamDestroy(self.stream) };
65 self.stream = std::ptr::null_mut();
66 }
67 }
68}
69
70#[allow(non_upper_case_globals)]
71pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;
72
73pub type CudaStreamId = u64;
74
75pub fn current_stream_id() -> Result<CudaStreamId, CudaError> {
76 let mut id = 0;
77 check(unsafe { cudaStreamGetId(cudaStreamPerThread, &mut id) })?;
78 Ok(id)
79}
80
81pub fn current_stream_sync() -> Result<(), CudaError> {
82 check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
83}
84
/// Opaque CUDA runtime event handle, exactly as exchanged with `cudart`.
// Named after the C typedef so the FFI signatures read like the CUDA headers.
#[allow(non_camel_case_types)]
pub type cudaEvent_t = *mut std::ffi::c_void;
87
88#[derive(Debug)]
89pub enum CudaEventStatus {
90 Completed,
91 NotReady,
92 Error(CudaError),
93}
94
95impl PartialEq for CudaEventStatus {
96 fn eq(&self, other: &Self) -> bool {
97 use CudaEventStatus::*;
98 matches!((self, other), (Completed, Completed) | (NotReady, NotReady))
99 }
100}
101
102impl Eq for CudaEventStatus {}
103
104impl PartialOrd for CudaEventStatus {
105 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
106 Some(self.cmp(other))
107 }
108}
109
110impl Ord for CudaEventStatus {
112 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
113 use std::cmp::Ordering;
114
115 use CudaEventStatus::*;
116
117 match (self, other) {
118 (Completed, Completed) => Ordering::Equal,
119 (Completed, _) => Ordering::Less,
120 (_, Completed) => Ordering::Greater,
121 (NotReady, NotReady) => Ordering::Equal,
122 (NotReady, Error(_)) => Ordering::Less,
123 (Error(_), NotReady) => Ordering::Greater,
124 (Error(_), Error(_)) => Ordering::Equal,
125 }
126 }
127}
128
129#[derive(Debug, Clone)]
130pub struct CudaEvent {
131 event: cudaEvent_t,
132}
133
134pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
135 check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
136}
137
138unsafe impl Send for CudaEvent {}
139unsafe impl Sync for CudaEvent {}
140
141impl CudaEvent {
142 pub fn new() -> Result<Self, CudaError> {
143 let mut event: cudaEvent_t = std::ptr::null_mut();
144 check(unsafe { cudaEventCreate(&mut event) })?;
145 Ok(Self { event })
146 }
147
148 pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
151 check(cudaEventRecord(self.event, stream))
152 }
153
154 pub fn record_on_this(&self) -> Result<(), CudaError> {
155 check(unsafe { cudaEventRecord(self.event, cudaStreamPerThread) })
156 }
157
158 pub fn synchronize(&self) -> Result<(), CudaError> {
159 check(unsafe { cudaEventSynchronize(self.event) })
160 }
161
162 pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
165 self.record(stream)?;
166 check(cudaEventSynchronize(self.event))
167 }
168
169 pub fn status(&self) -> CudaEventStatus {
170 let status = unsafe { cudaEventQuery(self.event) };
171 match status {
172 0 => CudaEventStatus::Completed, 600 => CudaEventStatus::NotReady, _ => CudaEventStatus::Error(CudaError::new(status)),
175 }
176 }
177
178 pub fn completed(&self) -> bool {
179 self.status() == CudaEventStatus::Completed
180 }
181}
182
183impl Drop for CudaEvent {
184 fn drop(&mut self) {
185 unsafe { cudaEventDestroy(self.event) };
186 }
187}
188
189pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
191 name: impl Into<Cow<'static, str>>,
192 f: F,
193) -> Result<R, CudaError> {
194 let start = CudaEvent::new()?;
195 let stop = CudaEvent::new()?;
196 unsafe {
197 check(cudaEventRecord(start.event, cudaStreamPerThread))?;
198 }
199 let res = f();
200 unsafe { stop.record_and_wait(cudaStreamPerThread)? };
201
202 let mut elapsed_ms = 0f32;
203 unsafe {
204 check(cudaEventElapsedTime(
205 &mut elapsed_ms,
206 start.event,
207 stop.event,
208 ))?
209 };
210
211 metrics::gauge!(name.into()).set(elapsed_ms as f64);
212 Ok(res)
213}