openvm_cuda_common/
memory_manager.rs1use std::{collections::HashMap, ffi::c_void, ptr::NonNull, sync::Mutex};
2
3use bytesize::ByteSize;
4use lazy_static::lazy_static;
5
6use crate::{
7 common::set_device,
8 error::{check, MemoryError},
9 stream::{cudaStreamPerThread, cudaStream_t, default_stream_sync},
10};
11
12#[link(name = "cudart")]
13extern "C" {
14 fn cudaMallocAsync(dev_ptr: *mut *mut c_void, size: usize, stream: cudaStream_t) -> i32;
15 fn cudaFreeAsync(dev_ptr: *mut c_void, stream: cudaStream_t) -> i32;
16 fn cudaMemGetInfo(free_bytes: *mut usize, total_bytes: *mut usize) -> i32;
17}
18
19lazy_static! {
20 static ref MEMORY_MANAGER: Mutex<MemoryManager> = Mutex::new(MemoryManager::default());
21}
22
/// Bookkeeping for device allocations made through this module.
///
/// Records every live pointer handed out by `d_malloc` together with its size, the running
/// number of bytes counted as in use, and the high-water mark of that number.
pub struct MemoryManager {
    /// Bytes currently counted as in use (seeded from the device's used memory in `new`).
    current_size: usize,
    /// Highest value `current_size` has reached, or the value last written by a peak reset.
    max_used_size: usize,
    /// Live allocations: device pointer -> size in bytes.
    allocated_ptrs: HashMap<NonNull<c_void>, usize>,
}

// SAFETY: NOTE(review) — the only field that is not automatically `Send`/`Sync` is the map keyed
// by `NonNull<c_void>`. In this module those pointers are used only as opaque device addresses
// (map keys and arguments to `cudaFreeAsync`) and are never dereferenced on the host. `Send` is
// what lets `Mutex<MemoryManager>` live in a static; `Sync` additionally allows sharing
// `&MemoryManager`. Confirm no other code dereferences these pointers through the manager.
unsafe impl Send for MemoryManager {}
unsafe impl Sync for MemoryManager {}
31
32impl MemoryManager {
33 pub fn new() -> Self {
34 set_device().unwrap();
35 let mut free: usize = 0;
36 let mut total: usize = 0;
37 check(unsafe { cudaMemGetInfo(&mut free, &mut total) }).unwrap();
38 let initial_used = total - free;
39 tracing::info!(
40 "GPU mem initial usage: current={}",
41 ByteSize::b(initial_used as u64)
42 );
43 Self {
44 allocated_ptrs: HashMap::new(),
45 current_size: initial_used,
46 max_used_size: initial_used,
47 }
48 }
49
50 pub fn d_malloc(&mut self, size: usize) -> Result<*mut c_void, MemoryError> {
51 let mut ptr: *mut c_void = std::ptr::null_mut();
52 check(unsafe { cudaMallocAsync(&mut ptr, size, cudaStreamPerThread) })?;
53
54 self.allocated_ptrs
55 .insert(NonNull::new(ptr).expect("cudaMalloc returned null"), size);
56 self.current_size += size;
57 if self.current_size > self.max_used_size {
58 self.max_used_size = self.current_size;
59 }
60 Ok(ptr)
61 }
62
63 pub unsafe fn d_free(&mut self, ptr: *mut c_void) -> Result<(), MemoryError> {
67 let nn = NonNull::new(ptr).ok_or(MemoryError::NullPointer)?;
68
69 if let Some(size) = self.allocated_ptrs.remove(&nn) {
70 self.current_size -= size;
71 } else {
72 return Err(MemoryError::UntrackedPointer);
73 }
74
75 check(unsafe { cudaFreeAsync(ptr, cudaStreamPerThread) })?;
76
77 Ok(())
78 }
79}
80
81impl Drop for MemoryManager {
82 fn drop(&mut self) {
83 for &nn in self.allocated_ptrs.keys() {
84 unsafe { d_free(nn.as_ptr()).unwrap() };
85 }
86 default_stream_sync().unwrap();
87 if !self.allocated_ptrs.is_empty() {
88 println!(
89 "Warning: {} allocations were automatically freed on MemoryManager drop",
90 self.allocated_ptrs.len()
91 );
92 }
93 }
94}
95
96impl Default for MemoryManager {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102pub fn d_malloc(size: usize) -> Result<*mut c_void, MemoryError> {
103 let mut manager = MEMORY_MANAGER.lock().map_err(|_| MemoryError::LockError)?;
104 manager.d_malloc(size)
105}
106
107pub unsafe fn d_free(ptr: *mut c_void) -> Result<(), MemoryError> {
111 let mut manager = MEMORY_MANAGER.lock().map_err(|_| MemoryError::LockError)?;
112 manager.d_free(ptr)
113}
114
115fn peak_memory_usage() -> usize {
116 let manager = MEMORY_MANAGER.lock().unwrap();
117 manager.max_used_size
118}
119
120fn current_memory_usage() -> usize {
121 let manager = MEMORY_MANAGER.lock().unwrap();
122 manager.current_size
123}
124
125fn reset_peak_memory(new_value: usize) {
126 let mut manager = MEMORY_MANAGER.lock().unwrap();
127 manager.max_used_size = new_value;
128}
129
/// Labelled reporter of GPU memory usage relative to a snapshot taken at creation.
///
/// Created by `MemTracker::start`. It logs through `tracing_info`, and logs once more (with no
/// extra message) when dropped. Cloning produces another tracker with the same snapshot and
/// label, which also logs when dropped.
#[derive(Debug, Clone)]
pub struct MemTracker {
    // Value of `current_memory_usage()` captured in `start` (bytes); the baseline for `used=`.
    current: usize,
    // Tag included in every log line emitted by this tracker.
    label: &'static str,
}
135
136impl MemTracker {
137 pub fn start(label: &'static str) -> Self {
138 Self {
139 current: current_memory_usage(),
140 label,
141 }
142 }
143
144 #[inline]
145 pub fn tracing_info(&self, msg: impl Into<Option<&'static str>>) {
146 let current = current_memory_usage();
147 let peak = peak_memory_usage();
148 let used = current as isize - self.current as isize;
149 let sign = if used >= 0 { "+" } else { "-" };
150 tracing::info!(
151 "GPU mem usage: used={}{}, current={}, peak={} ({})",
152 sign,
153 ByteSize::b(used.unsigned_abs() as u64),
154 ByteSize::b(current as u64),
155 ByteSize::b(peak as u64),
156 msg.into()
157 .map_or(self.label.to_string(), |m| format!("{}:{}", self.label, m))
158 );
159 }
160
161 pub fn reset_peak(&mut self) {
162 reset_peak_memory(self.current);
163 }
164}
165
166impl Drop for MemTracker {
167 fn drop(&mut self) {
168 self.tracing_info(None);
169 }
170}