openvm_cuda_common/
common.rs

1use std::ffi::c_void;
2
3use crate::error::{check, CudaError};
4
// Raw FFI bindings to the CUDA runtime library (`cudart`).
//
// Every function returns the runtime's status code as a plain `i32`; the
// wrappers in this file feed that value to `check` so that non-success codes
// become a `CudaError`.
#[link(name = "cudart")]
extern "C" {
    /// Resets the current device (wrapped by `reset_device`).
    fn cudaDeviceReset() -> i32;
    /// Makes `device` the current device.
    fn cudaSetDevice(device: i32) -> i32;
    /// Writes the current device ordinal into `*device`.
    fn cudaGetDevice(device: *mut i32) -> i32;
    /// Frees device memory; this file only calls it with a null pointer to
    /// force runtime/context initialization.
    fn cudaFree(dev_ptr: *mut c_void) -> i32;
}
12
13pub fn get_device() -> Result<i32, CudaError> {
14    let mut device = 0;
15    unsafe {
16        check(cudaGetDevice(&mut device))?;
17    }
18    assert!(device >= 0);
19    Ok(device)
20}
21
22pub fn set_device() -> Result<i32, CudaError> {
23    let mut device = 0;
24    unsafe {
25        // 1. Create a context
26        check(cudaFree(std::ptr::null_mut()))?;
27        // 2. Set the device
28        check(cudaGetDevice(&mut device))?;
29        check(cudaSetDevice(device))?;
30    }
31    Ok(device)
32}
33
34pub fn reset_device() -> Result<(), CudaError> {
35    check(unsafe { cudaDeviceReset() })
36}