openvm_cuda_common/memory_manager.rs

use std::{collections::HashMap, ffi::c_void, ptr::NonNull, sync::Mutex};

use bytesize::ByteSize;
use lazy_static::lazy_static;

use crate::{
    common::set_device,
    error::{check, MemoryError},
    stream::{cudaStreamPerThread, cudaStream_t, default_stream_sync},
};

#[link(name = "cudart")]
extern "C" {
    fn cudaMallocAsync(dev_ptr: *mut *mut c_void, size: usize, stream: cudaStream_t) -> i32;
    fn cudaFreeAsync(dev_ptr: *mut c_void, stream: cudaStream_t) -> i32;
    fn cudaMemGetInfo(free_bytes: *mut usize, total_bytes: *mut usize) -> i32;
}

lazy_static! {
    static ref MEMORY_MANAGER: Mutex<MemoryManager> = Mutex::new(MemoryManager::default());
}

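/// Tracks device allocations made through `cudaMallocAsync`, recording the size of each
/// live pointer so current and peak GPU memory usage can be reported.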
pub struct MemoryManager {
    allocated_ptrs: HashMap<NonNull<c_void>, usize>,
    current_size: usize,
    max_used_size: usize,
}

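// SAFETY: the stored `NonNull<c_void>` values are opaque CUDA device pointers that are
// never dereferenced on the host, and the manager is only accessed through the global
// `Mutex` above.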
unsafe impl Send for MemoryManager {}
unsafe impl Sync for MemoryManager {}

impl MemoryManager {
    pub fn new() -> Self {
        set_device().unwrap();
        let mut free: usize = 0;
        let mut total: usize = 0;
        check(unsafe { cudaMemGetInfo(&mut free, &mut total) }).unwrap();
        let initial_used = total - free;
        tracing::info!(
            "GPU mem initial usage: current={}",
            ByteSize::b(initial_used as u64)
        );
        Self {
            allocated_ptrs: HashMap::new(),
            current_size: initial_used,
            max_used_size: initial_used,
        }
    }

    pub fn d_malloc(&mut self, size: usize) -> Result<*mut c_void, MemoryError> {
        let mut ptr: *mut c_void = std::ptr::null_mut();
        check(unsafe { cudaMallocAsync(&mut ptr, size, cudaStreamPerThread) })?;

        // Record the allocation so current and peak usage statistics stay accurate.
        self.allocated_ptrs
            .insert(NonNull::new(ptr).expect("cudaMallocAsync returned null"), size);
        self.current_size += size;
        if self.current_size > self.max_used_size {
            self.max_used_size = self.current_size;
        }
        Ok(ptr)
    }

    /// # Safety
    /// The pointer `ptr` must be a valid, previously allocated device pointer.
    /// The caller must ensure that `ptr` is not used after this function is called.
    pub unsafe fn d_free(&mut self, ptr: *mut c_void) -> Result<(), MemoryError> {
        let nn = NonNull::new(ptr).ok_or(MemoryError::NullPointer)?;

        if let Some(size) = self.allocated_ptrs.remove(&nn) {
            self.current_size -= size;
        } else {
            return Err(MemoryError::UntrackedPointer);
        }

        check(unsafe { cudaFreeAsync(ptr, cudaStreamPerThread) })?;

        Ok(())
    }
}

impl Drop for MemoryManager {
    fn drop(&mut self) {
        if !self.allocated_ptrs.is_empty() {
            println!(
                "Warning: {} allocations were automatically freed on MemoryManager drop",
                self.allocated_ptrs.len()
            );
        }
        // Free any remaining allocations directly rather than through the module-level
        // `d_free`, which would re-lock the global manager.
        for (nn, size) in self.allocated_ptrs.drain() {
            self.current_size -= size;
            check(unsafe { cudaFreeAsync(nn.as_ptr(), cudaStreamPerThread) }).unwrap();
        }
        default_stream_sync().unwrap();
    }
}

impl Default for MemoryManager {
    fn default() -> Self {
        Self::new()
    }
}

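/// Allocates `size` bytes of device memory on the per-thread stream and records the
/// allocation in the global [`MemoryManager`].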
pub fn d_malloc(size: usize) -> Result<*mut c_void, MemoryError> {
    let mut manager = MEMORY_MANAGER.lock().map_err(|_| MemoryError::LockError)?;
    manager.d_malloc(size)
}

/// # Safety
/// The pointer `ptr` must be a valid, previously allocated device pointer.
/// The caller must ensure that `ptr` is not used after this function is called.
pub unsafe fn d_free(ptr: *mut c_void) -> Result<(), MemoryError> {
    let mut manager = MEMORY_MANAGER.lock().map_err(|_| MemoryError::LockError)?;
    manager.d_free(ptr)
}

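// Internal helpers used by `MemTracker` to read and reset the global usage counters.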
fn peak_memory_usage() -> usize {
    let manager = MEMORY_MANAGER.lock().unwrap();
    manager.max_used_size
}

fn current_memory_usage() -> usize {
    let manager = MEMORY_MANAGER.lock().unwrap();
    manager.current_size
}

fn reset_peak_memory(new_value: usize) {
    let mut manager = MEMORY_MANAGER.lock().unwrap();
    manager.max_used_size = new_value;
}

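/// Snapshots the current GPU memory usage at construction and logs the delta against that
/// snapshot, along with current and peak usage, when dropped or via [`MemTracker::tracing_info`].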
#[derive(Debug, Clone)]
pub struct MemTracker {
    current: usize,
    label: &'static str,
}

impl MemTracker {
    pub fn start(label: &'static str) -> Self {
        Self {
            current: current_memory_usage(),
            label,
        }
    }

    #[inline]
    pub fn tracing_info(&self, msg: impl Into<Option<&'static str>>) {
        let current = current_memory_usage();
        let peak = peak_memory_usage();
        let used = current as isize - self.current as isize;
        let sign = if used >= 0 { "+" } else { "-" };
        tracing::info!(
            "GPU mem usage: used={}{}, current={}, peak={} ({})",
            sign,
            ByteSize::b(used.unsigned_abs() as u64),
            ByteSize::b(current as u64),
            ByteSize::b(peak as u64),
            msg.into()
                .map_or(self.label.to_string(), |m| format!("{}:{}", self.label, m))
        );
    }

    pub fn reset_peak(&mut self) {
        reset_peak_memory(self.current);
    }
}

impl Drop for MemTracker {
    fn drop(&mut self) {
        self.tracing_info(None);
    }
}
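
// Illustrative usage sketch (not part of the original file): allocate a device buffer,
// track usage around a workload, and free the buffer when done. Kernel launches and
// error handling are elided.
//
// ```ignore
// fn example() -> Result<(), MemoryError> {
//     let tracker = MemTracker::start("example");
//     let buf = d_malloc(1 << 20)?; // 1 MiB device allocation
//     // ... launch kernels that use `buf` on cudaStreamPerThread ...
//     unsafe { d_free(buf)? }; // `buf` must not be used after this point
//     tracker.tracing_info("after free");
//     Ok(())
// } // `tracker` is dropped here and logs a final usage line
// ```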