openvm_cuda_common/
common.rs

1use std::ffi::c_void;
2
3use crate::error::{check, CudaError};
4
/// Opaque CUDA memory-pool handle, spelled like the C typedef it mirrors.
#[allow(non_camel_case_types)] // keep the CUDA C API spelling
type cudaMemPool_t = *mut c_void;

/// Attribute id passed to `cudaMemPoolSetAttribute` to toggle reuse that
/// follows event dependencies (`0x1` in the CUDA `cudaMemPoolAttr` enum).
#[allow(non_upper_case_globals)] // keep the CUDA C API spelling
pub const cudaMemPoolAttrReuseFollowEventDependencies: u32 = 0x1;

/// Attribute id passed to `cudaMemPoolSetAttribute` to set the pool's
/// release threshold (`0x4` in the CUDA `cudaMemPoolAttr` enum).
#[allow(non_upper_case_globals)] // keep the CUDA C API spelling
pub const cudaMemPoolAttrReleaseThreshold: u32 = 0x4;
11
12#[link(name = "cudart")]
13extern "C" {
14    fn cudaFree(dev_ptr: *mut c_void) -> i32;
15    fn cudaGetDevice(device: *mut i32) -> i32;
16    fn cudaDeviceReset() -> i32;
17    fn cudaDeviceGetDefaultMemPool(pool: *mut cudaMemPool_t, device: i32) -> i32;
18    fn cudaMemPoolSetAttribute(pool: cudaMemPool_t, attr: u32, value: *const c_void) -> i32;
19    fn cudaDeviceSetMemPool(device: i32, pool: cudaMemPool_t) -> i32;
20}
21
22pub fn get_device() -> Result<i32, CudaError> {
23    let mut device = 0;
24    unsafe {
25        check(cudaGetDevice(&mut device))?;
26    }
27    assert!(device >= 0);
28    Ok(device)
29}
30
31pub fn set_device() -> Result<(), CudaError> {
32    let device = get_device()?;
33    unsafe {
34        // 0. Create a context
35        check(cudaFree(std::ptr::null_mut()))?;
36
37        // 1. Get the default memory pool
38        let mut pool: cudaMemPool_t = std::ptr::null_mut();
39        check(cudaDeviceGetDefaultMemPool(&mut pool, device))?;
40
41        let reuse: i32 = 1;
42        check(cudaMemPoolSetAttribute(
43            pool,
44            cudaMemPoolAttrReuseFollowEventDependencies,
45            &reuse as *const i32 as *const c_void,
46        ))?;
47
48        // 2. Set release threshold to 512 MB
49        let threshold: usize = 512 * 1024 * 1024;
50        check(cudaMemPoolSetAttribute(
51            pool,
52            cudaMemPoolAttrReleaseThreshold,
53            &threshold as *const usize as *const c_void,
54        ))?;
55
56        // 3. Optional but safe: assign pool back to device
57        check(cudaDeviceSetMemPool(device, pool))?;
58    }
59    Ok(())
60}
61
62pub fn reset_device() -> Result<(), CudaError> {
63    check(unsafe { cudaDeviceReset() })
64}