openvm_cuda_common/
common.rs1use std::ffi::c_void;
2
3use crate::error::{check, CudaError};
4
// The identifiers below deliberately mirror the CUDA runtime C API spelling
// (`cudaMemPool_t`, `cudaMemPoolAttr*`), so the Rust naming lints are silenced for them.

/// Opaque CUDA runtime memory-pool handle (`cudaMemPool_t` in the C API).
#[allow(non_camel_case_types)]
type cudaMemPool_t = *mut c_void;

/// `cudaMemPoolAttr` selector for `cudaMemPoolReuseFollowEventDependencies`
/// (per the CUDA runtime docs, the attribute value is a C `int`).
#[allow(non_upper_case_globals)]
pub const cudaMemPoolAttrReuseFollowEventDependencies: u32 = 0x1;

/// `cudaMemPoolAttr` selector for the pool's release threshold in bytes
/// (per the CUDA runtime docs, the attribute value is a `cuuint64_t`).
#[allow(non_upper_case_globals)]
pub const cudaMemPoolAttrReleaseThreshold: u32 = 0x4;
11
12#[link(name = "cudart")]
13extern "C" {
14 fn cudaFree(dev_ptr: *mut c_void) -> i32;
15 fn cudaGetDevice(device: *mut i32) -> i32;
16 fn cudaDeviceReset() -> i32;
17 fn cudaDeviceGetDefaultMemPool(pool: *mut cudaMemPool_t, device: i32) -> i32;
18 fn cudaMemPoolSetAttribute(pool: cudaMemPool_t, attr: u32, value: *const c_void) -> i32;
19 fn cudaDeviceSetMemPool(device: i32, pool: cudaMemPool_t) -> i32;
20}
21
22pub fn get_device() -> Result<i32, CudaError> {
23 let mut device = 0;
24 unsafe {
25 check(cudaGetDevice(&mut device))?;
26 }
27 assert!(device >= 0);
28 Ok(device)
29}
30
31pub fn set_device() -> Result<(), CudaError> {
32 let device = get_device()?;
33 unsafe {
34 check(cudaFree(std::ptr::null_mut()))?;
36
37 let mut pool: cudaMemPool_t = std::ptr::null_mut();
39 check(cudaDeviceGetDefaultMemPool(&mut pool, device))?;
40
41 let reuse: i32 = 1;
42 check(cudaMemPoolSetAttribute(
43 pool,
44 cudaMemPoolAttrReuseFollowEventDependencies,
45 &reuse as *const i32 as *const c_void,
46 ))?;
47
48 let threshold: usize = 512 * 1024 * 1024;
50 check(cudaMemPoolSetAttribute(
51 pool,
52 cudaMemPoolAttrReleaseThreshold,
53 &threshold as *const usize as *const c_void,
54 ))?;
55
56 check(cudaDeviceSetMemPool(device, pool))?;
58 }
59 Ok(())
60}
61
62pub fn reset_device() -> Result<(), CudaError> {
63 check(unsafe { cudaDeviceReset() })
64}