openvm_cuda_common/
stream.rs1use std::{borrow::Cow, ffi::c_void};
2
3use crate::error::{check, CudaError};
4
5#[link(name = "cudart")]
6extern "C" {
7 fn cudaStreamCreate(stream: *mut cudaStream_t) -> i32;
8 fn cudaStreamDestroy(stream: cudaStream_t) -> i32;
9 fn cudaStreamSynchronize(stream: cudaStream_t) -> i32;
10 fn cudaStreamWaitEvent(stream: cudaStream_t, event: cudaEvent_t, flags: u32) -> i32;
11 fn cudaEventCreate(event: *mut cudaEvent_t) -> i32;
12 fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> i32;
13 fn cudaEventSynchronize(event: cudaEvent_t) -> i32;
14 fn cudaEventDestroy(event: cudaEvent_t) -> i32;
15 fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> i32;
16}
17
/// Raw CUDA stream handle as it crosses the FFI boundary: an untyped pointer.
// The name follows the CUDA C spelling rather than Rust's CamelCase, hence the allow.
#[allow(non_camel_case_types)]
pub type cudaStream_t = *mut c_void;
20
/// Wrapper around a raw [`cudaStream_t`] handle.
///
/// The handle is created by `CudaStream::new` and destroyed when the wrapper is dropped.
pub struct CudaStream {
    /// Raw handle passed to the CUDA runtime calls.
    stream: cudaStream_t,
}

// NOTE(review): a raw pointer is neither `Send` nor `Sync` by default. These impls assert
// that the stream handle may be moved to and shared between threads; the compiler cannot
// verify this, so confirm it against the CUDA runtime's threading guarantees.
unsafe impl Send for CudaStream {}
unsafe impl Sync for CudaStream {}
27
28impl CudaStream {
29 pub fn new() -> Result<Self, CudaError> {
31 let mut stream: cudaStream_t = std::ptr::null_mut();
32 check(unsafe { cudaStreamCreate(&mut stream) })?;
33 Ok(Self { stream })
34 }
35
36 #[inline]
38 pub fn as_raw(&self) -> cudaStream_t {
39 self.stream
40 }
41
42 pub fn synchronize(&self) -> Result<(), CudaError> {
44 check(unsafe { cudaStreamSynchronize(self.stream) })
45 }
46
47 pub fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
49 check(unsafe { cudaStreamWaitEvent(self.stream, event.event, 0) })
50 }
51}
52
53impl Drop for CudaStream {
54 fn drop(&mut self) {
55 if !self.stream.is_null() {
56 self.synchronize().unwrap();
57 let _ = unsafe { cudaStreamDestroy(self.stream) };
58 self.stream = std::ptr::null_mut();
59 }
60 }
61}
62
/// Raw CUDA event handle as it crosses the FFI boundary: an untyped pointer.
// The name follows the CUDA C spelling rather than Rust's CamelCase, hence the allow.
#[allow(non_camel_case_types)]
pub type cudaEvent_t = *mut c_void;
/// Stream handle constant used by the `default_stream_*` helpers and `gpu_metrics_span`.
// NOTE(review): the value `0x02` is presumably the CUDA runtime's `cudaStreamPerThread`
// handle (the per-thread default stream) — confirm against the CUDA headers in use.
#[allow(non_upper_case_globals)]
pub const cudaStreamPerThread: cudaStream_t = 0x02 as cudaStream_t;
67
68pub fn default_stream_sync() -> Result<(), CudaError> {
69 check(unsafe { cudaStreamSynchronize(cudaStreamPerThread) })
70}
71
/// Wrapper around a raw [`cudaEvent_t`] handle.
///
/// The handle is created by `CudaEvent::new` and destroyed when the wrapper is dropped.
pub struct CudaEvent {
    /// Raw handle passed to the CUDA runtime calls.
    event: cudaEvent_t,
}
75
76pub fn default_stream_wait(event: &CudaEvent) -> Result<(), CudaError> {
77 check(unsafe { cudaStreamWaitEvent(cudaStreamPerThread, event.event, 0) })
78}
79
// NOTE(review): a raw pointer is neither `Send` nor `Sync` by default. These impls assert
// that the event handle may be moved to and shared between threads; the compiler cannot
// verify this, so confirm it against the CUDA runtime's threading guarantees.
unsafe impl Send for CudaEvent {}
unsafe impl Sync for CudaEvent {}
82
83impl CudaEvent {
84 pub fn new() -> Result<Self, CudaError> {
85 let mut event: cudaEvent_t = std::ptr::null_mut();
86 check(unsafe { cudaEventCreate(&mut event) })?;
87 Ok(Self { event })
88 }
89
90 pub unsafe fn record(&self, stream: cudaStream_t) -> Result<(), CudaError> {
93 check(cudaEventRecord(self.event, stream))
94 }
95
96 pub unsafe fn record_and_wait(&self, stream: cudaStream_t) -> Result<(), CudaError> {
99 self.record(stream)?;
100 check(cudaEventSynchronize(self.event))
101 }
102}
103
104impl Drop for CudaEvent {
105 fn drop(&mut self) {
106 unsafe { cudaEventDestroy(self.event) };
107 }
108}
109
110pub fn gpu_metrics_span<R, F: FnOnce() -> R>(
112 name: impl Into<Cow<'static, str>>,
113 f: F,
114) -> Result<R, CudaError> {
115 let start = CudaEvent::new()?;
116 let stop = CudaEvent::new()?;
117 unsafe {
118 check(cudaEventRecord(start.event, cudaStreamPerThread))?;
119 }
120 let res = f();
121 unsafe { stop.record_and_wait(cudaStreamPerThread)? };
122
123 let mut elapsed_ms = 0f32;
124 unsafe {
125 check(cudaEventElapsedTime(
126 &mut elapsed_ms,
127 start.event,
128 stop.event,
129 ))?
130 };
131
132 metrics::gauge!(name.into()).set(elapsed_ms as f64);
133 Ok(res)
134}