openvm_cuda_common/
common.rs1use std::ffi::c_void;
2
3use crate::error::{check, CudaError};
4
#[link(name = "cudart")]
extern "C" {
    // Raw CUDA runtime entry points. Each returns the runtime status code (`cudaError_t`)
    // as an `i32`; every caller in this module passes that code straight to `check`.

    /// `cudaGetDevice`: writes the current device ordinal through `out`.
    fn cudaGetDevice(out: *mut i32) -> i32;
    /// `cudaSetDevice`: selects `ordinal` as the current device.
    fn cudaSetDevice(ordinal: i32) -> i32;
    /// `cudaFree`: releases the device allocation at `ptr` (a null pointer is accepted).
    fn cudaFree(ptr: *mut c_void) -> i32;
    /// `cudaDeviceReset`: resets the current device.
    fn cudaDeviceReset() -> i32;
}
12
13pub fn get_device() -> Result<i32, CudaError> {
14 let mut device = 0;
15 unsafe {
16 check(cudaGetDevice(&mut device))?;
17 }
18 assert!(device >= 0);
19 Ok(device)
20}
21
22pub fn set_device() -> Result<i32, CudaError> {
23 let mut device = 0;
24 unsafe {
25 check(cudaFree(std::ptr::null_mut()))?;
27 check(cudaGetDevice(&mut device))?;
29 check(cudaSetDevice(device))?;
30 }
31 Ok(device)
32}
33
34pub fn reset_device() -> Result<(), CudaError> {
35 check(unsafe { cudaDeviceReset() })
36}