openvm_cuda_common/memory_manager/
mod.rs1use std::{
2 collections::HashMap,
3 ffi::c_void,
4 ptr::NonNull,
5 sync::{Mutex, OnceLock},
6};
7
8use bytesize::ByteSize;
9
10use crate::{
11 error::{check, MemoryError},
12 stream::{cudaStreamPerThread, cudaStream_t, current_stream_id, device_synchronize},
13};
14
15mod cuda;
16mod vm_pool;
17use vm_pool::VirtualMemoryPool;
18
19#[cfg(test)]
20mod tests;
21
22#[link(name = "cudart")]
23extern "C" {
24 fn cudaMallocAsync(dev_ptr: *mut *mut c_void, size: usize, stream: cudaStream_t) -> i32;
25 fn cudaFreeAsync(dev_ptr: *mut c_void, stream: cudaStream_t) -> i32;
26}
27
28static MEMORY_MANAGER: OnceLock<Mutex<MemoryManager>> = OnceLock::new();
29
30#[ctor::ctor]
31fn init() {
32 let _ = MEMORY_MANAGER.set(Mutex::new(MemoryManager::new()));
33 tracing::info!("Memory manager initialized at program start");
34}
35
36pub struct MemoryManager {
37 pool: VirtualMemoryPool,
38 allocated_ptrs: HashMap<NonNull<c_void>, usize>,
39 current_size: usize,
40 max_used_size: usize,
41}
42
43unsafe impl Send for MemoryManager {}
47unsafe impl Sync for MemoryManager {}
48
49impl MemoryManager {
50 pub fn new() -> Self {
51 let pool = VirtualMemoryPool::default();
53
54 Self {
55 pool,
56 allocated_ptrs: HashMap::new(),
57 current_size: 0,
58 max_used_size: 0,
59 }
60 }
61
62 fn d_malloc(&mut self, size: usize) -> Result<*mut c_void, MemoryError> {
63 assert!(size != 0, "Requested size must be non-zero");
64
65 let mut tracked_size = size;
66 let ptr = if size < self.pool.page_size {
67 let mut ptr: *mut c_void = std::ptr::null_mut();
68 check(unsafe { cudaMallocAsync(&mut ptr, size, cudaStreamPerThread) }).map_err(
69 |e| {
70 tracing::error!("cudaMallocAsync failed: size={}: {:?}", size, e);
71 MemoryError::from(e)
72 },
73 )?;
74 self.allocated_ptrs.insert(
75 NonNull::new(ptr).expect("BUG: cudaMallocAsync returned null"),
76 size,
77 );
78 ptr
79 } else {
80 tracked_size = size.next_multiple_of(self.pool.page_size);
81 let stream_id = current_stream_id()?;
82 self.pool.malloc_internal(tracked_size, stream_id)?
83 };
84
85 self.current_size += tracked_size;
86 if self.current_size > self.max_used_size {
87 self.max_used_size = self.current_size;
88 }
89 Ok(ptr)
90 }
91
92 unsafe fn d_free(&mut self, ptr: *mut c_void) -> Result<(), MemoryError> {
96 let nn = NonNull::new(ptr).ok_or(MemoryError::NullPointer)?;
97
98 if let Some(size) = self.allocated_ptrs.remove(&nn) {
99 self.current_size -= size;
100 check(unsafe { cudaFreeAsync(ptr, cudaStreamPerThread) }).map_err(|e| {
101 tracing::error!("cudaFreeAsync failed: ptr={:p}: {:?}", ptr, e);
102 MemoryError::from(e)
103 })?;
104 } else {
105 let stream_id = current_stream_id()?;
106 let freed_size = self.pool.free_internal(ptr, stream_id)?;
107 self.current_size -= freed_size;
108 }
109
110 Ok(())
111 }
112}
113
114impl Drop for MemoryManager {
115 fn drop(&mut self) {
116 device_synchronize().unwrap();
117 let ptrs: Vec<*mut c_void> = self.allocated_ptrs.keys().map(|nn| nn.as_ptr()).collect();
118 for &ptr in &ptrs {
119 if let Err(e) = unsafe { self.d_free(ptr) } {
120 tracing::error!("MemoryManager drop: failed to free {:p}: {:?}", ptr, e);
121 }
122 }
123 }
124}
125
126impl Default for MemoryManager {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132pub fn d_malloc(size: usize) -> Result<*mut c_void, MemoryError> {
133 let manager = MEMORY_MANAGER.get().unwrap();
134 let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
135 manager.d_malloc(size)
136}
137
138pub unsafe fn d_free(ptr: *mut c_void) -> Result<(), MemoryError> {
142 let manager = MEMORY_MANAGER.get().unwrap();
143 let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
144 manager.d_free(ptr)
145}
146
/// Scoped reporter for device memory usage.
///
/// Remembers the manager's `current_size` at creation time so that later log lines can show
/// how much usage has changed since then. A report is also logged when the tracker is dropped.
#[derive(Debug, Clone)]
pub struct MemTracker {
    /// Value of the manager's `current_size` when the tracker was started.
    current: usize,
    /// Label included in every log line emitted by this tracker.
    label: &'static str,
}
152
153impl MemTracker {
154 pub fn start(label: &'static str) -> Self {
155 let current = MEMORY_MANAGER
156 .get()
157 .and_then(|m| m.lock().ok())
158 .map(|m| m.current_size)
159 .unwrap_or(0);
160
161 Self { current, label }
162 }
163
164 #[inline]
165 pub fn tracing_info(&self, msg: impl Into<Option<&'static str>>) {
166 let Some(manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) else {
167 tracing::error!("Memory manager not available");
168 return;
169 };
170 let current = manager.current_size;
171 let peak = manager.max_used_size;
172 let used = current as isize - self.current as isize;
173 let sign = if used >= 0 { "+" } else { "-" };
174 let pool_usage = manager.pool.memory_usage();
175 tracing::info!(
176 "GPU mem: used={}{}, current={}, peak={}, in pool={} ({})",
177 sign,
178 ByteSize::b(used.unsigned_abs() as u64),
179 ByteSize::b(current as u64),
180 ByteSize::b(peak as u64),
181 ByteSize::b(pool_usage as u64),
182 msg.into()
183 .map_or(self.label.to_string(), |m| format!("{}:{}", self.label, m))
184 );
185 }
186
187 pub fn reset_peak(&mut self) {
188 if let Some(mut manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) {
189 manager.max_used_size = manager.current_size;
190 }
191 }
192}
193
194impl Drop for MemTracker {
195 fn drop(&mut self) {
196 self.tracing_info(None);
197 }
198}