openvm_cuda_common/memory_manager/
mod.rs1use std::{
2 collections::HashMap,
3 ffi::c_void,
4 ptr::NonNull,
5 sync::{Mutex, OnceLock},
6};
7
8use bytesize::ByteSize;
9
10use crate::{
11 error::{check, MemoryError},
12 stream::{cudaStreamPerThread, cudaStream_t, current_stream_id},
13};
14
15mod cuda;
16mod vm_pool;
17use vm_pool::VirtualMemoryPool;
18
19#[link(name = "cudart")]
20extern "C" {
21 fn cudaMallocAsync(dev_ptr: *mut *mut c_void, size: usize, stream: cudaStream_t) -> i32;
22 fn cudaFreeAsync(dev_ptr: *mut c_void, stream: cudaStream_t) -> i32;
23}
24
25static MEMORY_MANAGER: OnceLock<Mutex<MemoryManager>> = OnceLock::new();
26
27#[ctor::ctor]
28fn init() {
29 let _ = MEMORY_MANAGER.set(Mutex::new(MemoryManager::new()));
30 tracing::info!("Memory manager initialized at program start");
31}
32
33pub struct MemoryManager {
34 pool: VirtualMemoryPool,
35 allocated_ptrs: HashMap<NonNull<c_void>, usize>,
36 current_size: usize,
37 max_used_size: usize,
38}
39
40unsafe impl Send for MemoryManager {}
41unsafe impl Sync for MemoryManager {}
42
43impl MemoryManager {
44 pub fn new() -> Self {
45 let pool = VirtualMemoryPool::default();
47
48 Self {
49 pool,
50 allocated_ptrs: HashMap::new(),
51 current_size: 0,
52 max_used_size: 0,
53 }
54 }
55
56 fn d_malloc(&mut self, size: usize) -> Result<*mut c_void, MemoryError> {
57 assert!(size != 0, "Requested size must be non-zero");
58
59 let mut tracked_size = size;
60 let ptr = if size < self.pool.page_size {
61 let mut ptr: *mut c_void = std::ptr::null_mut();
62 check(unsafe { cudaMallocAsync(&mut ptr, size, cudaStreamPerThread) })?;
63 self.allocated_ptrs
64 .insert(NonNull::new(ptr).expect("cudaMalloc returned null"), size);
65 Ok(ptr)
66 } else {
67 tracked_size = size.next_multiple_of(self.pool.page_size);
68 self.pool
69 .malloc_internal(tracked_size, current_stream_id()?)
70 };
71
72 self.current_size += tracked_size;
73 if self.current_size > self.max_used_size {
74 self.max_used_size = self.current_size;
75 }
76 ptr
77 }
78
79 unsafe fn d_free(&mut self, ptr: *mut c_void) -> Result<(), MemoryError> {
83 let nn = NonNull::new(ptr).ok_or(MemoryError::NullPointer)?;
84
85 if let Some(size) = self.allocated_ptrs.remove(&nn) {
86 self.current_size -= size;
87 check(unsafe { cudaFreeAsync(ptr, cudaStreamPerThread) })?;
88 } else {
89 self.current_size -= self.pool.free_internal(ptr, current_stream_id()?)?;
90 }
91
92 Ok(())
93 }
94}
95
96impl Drop for MemoryManager {
97 fn drop(&mut self) {
98 let ptrs: Vec<*mut c_void> = self.allocated_ptrs.keys().map(|nn| nn.as_ptr()).collect();
99 for &ptr in &ptrs {
100 unsafe { self.d_free(ptr).unwrap() };
101 }
102 if !self.allocated_ptrs.is_empty() {
103 println!(
104 "Error: {} allocations were automatically freed on MemoryManager drop",
105 self.allocated_ptrs.len()
106 );
107 }
108 }
109}
110
111impl Default for MemoryManager {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117pub fn d_malloc(size: usize) -> Result<*mut c_void, MemoryError> {
118 let manager = MEMORY_MANAGER.get().unwrap();
119 let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
120 manager.d_malloc(size)
121}
122
123pub unsafe fn d_free(ptr: *mut c_void) -> Result<(), MemoryError> {
127 let manager = MEMORY_MANAGER.get().unwrap();
128 let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
129 manager.d_free(ptr)
130}
131
/// Helper that logs GPU memory usage relative to a recorded baseline.
#[derive(Debug, Clone)]
pub struct MemTracker {
    /// The manager's `current_size` captured when the tracker was started
    /// (0 if the manager could not be read at that point).
    current: usize,
    /// Label included in every log line emitted by this tracker.
    label: &'static str,
}
137
138impl MemTracker {
139 pub fn start(label: &'static str) -> Self {
140 let current = MEMORY_MANAGER
141 .get()
142 .and_then(|m| m.lock().ok())
143 .map(|m| m.current_size)
144 .unwrap_or(0);
145
146 Self { current, label }
147 }
148
149 #[inline]
150 pub fn tracing_info(&self, msg: impl Into<Option<&'static str>>) {
151 let Some(manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) else {
152 tracing::error!("Memory manager not available");
153 return;
154 };
155 let current = manager.current_size;
156 let peak = manager.max_used_size;
157 let used = current as isize - self.current as isize;
158 let sign = if used >= 0 { "+" } else { "-" };
159 let pool_usage = manager.pool.memory_usage();
160 tracing::info!(
161 "GPU mem: used={}{}, current={}, peak={}, in pool={} ({})",
162 sign,
163 ByteSize::b(used.unsigned_abs() as u64),
164 ByteSize::b(current as u64),
165 ByteSize::b(peak as u64),
166 ByteSize::b(pool_usage as u64),
167 msg.into()
168 .map_or(self.label.to_string(), |m| format!("{}:{}", self.label, m))
169 );
170 }
171
172 pub fn reset_peak(&mut self) {
173 if let Some(mut manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) {
174 manager.max_used_size = manager.current_size;
175 }
176 }
177}
178
179impl Drop for MemTracker {
180 fn drop(&mut self) {
181 self.tracing_info(None);
182 }
183}