openvm_cuda_common/memory_manager/
mod.rs

1use std::{
2    collections::HashMap,
3    ffi::c_void,
4    ptr::NonNull,
5    sync::{Mutex, OnceLock},
6};
7
8use bytesize::ByteSize;
9
10use crate::{
11    error::{check, MemoryError},
12    stream::{cudaStreamPerThread, cudaStream_t, current_stream_id, device_synchronize},
13};
14
15mod cuda;
16mod vm_pool;
17use vm_pool::VirtualMemoryPool;
18
19#[cfg(test)]
20mod tests;
21
22#[link(name = "cudart")]
23extern "C" {
24    fn cudaMallocAsync(dev_ptr: *mut *mut c_void, size: usize, stream: cudaStream_t) -> i32;
25    fn cudaFreeAsync(dev_ptr: *mut c_void, stream: cudaStream_t) -> i32;
26}
27
28static MEMORY_MANAGER: OnceLock<Mutex<MemoryManager>> = OnceLock::new();
29
30#[ctor::ctor]
31fn init() {
32    let _ = MEMORY_MANAGER.set(Mutex::new(MemoryManager::new()));
33    tracing::info!("Memory manager initialized at program start");
34}
35
36pub struct MemoryManager {
37    pool: VirtualMemoryPool,
38    allocated_ptrs: HashMap<NonNull<c_void>, usize>,
39    current_size: usize,
40    max_used_size: usize,
41}
42
43/// # Safety
44/// `MemoryManager` is not internally synchronized. These impls are safe because
45/// the singleton instance is wrapped in `Mutex` via `MEMORY_MANAGER`.
46unsafe impl Send for MemoryManager {}
47unsafe impl Sync for MemoryManager {}
48
49impl MemoryManager {
50    pub fn new() -> Self {
51        // Create virtual memory pool
52        let pool = VirtualMemoryPool::default();
53
54        Self {
55            pool,
56            allocated_ptrs: HashMap::new(),
57            current_size: 0,
58            max_used_size: 0,
59        }
60    }
61
62    fn d_malloc(&mut self, size: usize) -> Result<*mut c_void, MemoryError> {
63        assert!(size != 0, "Requested size must be non-zero");
64
65        let mut tracked_size = size;
66        let ptr = if size < self.pool.page_size {
67            let mut ptr: *mut c_void = std::ptr::null_mut();
68            check(unsafe { cudaMallocAsync(&mut ptr, size, cudaStreamPerThread) }).map_err(
69                |e| {
70                    tracing::error!("cudaMallocAsync failed: size={}: {:?}", size, e);
71                    MemoryError::from(e)
72                },
73            )?;
74            self.allocated_ptrs.insert(
75                NonNull::new(ptr).expect("BUG: cudaMallocAsync returned null"),
76                size,
77            );
78            ptr
79        } else {
80            tracked_size = size.next_multiple_of(self.pool.page_size);
81            let stream_id = current_stream_id()?;
82            self.pool.malloc_internal(tracked_size, stream_id)?
83        };
84
85        self.current_size += tracked_size;
86        if self.current_size > self.max_used_size {
87            self.max_used_size = self.current_size;
88        }
89        Ok(ptr)
90    }
91
92    /// # Safety
93    /// The pointer `ptr` must be a valid, previously allocated device pointer.
94    /// The caller must ensure that `ptr` is not used after this function is called.
95    unsafe fn d_free(&mut self, ptr: *mut c_void) -> Result<(), MemoryError> {
96        let nn = NonNull::new(ptr).ok_or(MemoryError::NullPointer)?;
97
98        if let Some(size) = self.allocated_ptrs.remove(&nn) {
99            self.current_size -= size;
100            check(unsafe { cudaFreeAsync(ptr, cudaStreamPerThread) }).map_err(|e| {
101                tracing::error!("cudaFreeAsync failed: ptr={:p}: {:?}", ptr, e);
102                MemoryError::from(e)
103            })?;
104        } else {
105            let stream_id = current_stream_id()?;
106            let freed_size = self.pool.free_internal(ptr, stream_id)?;
107            self.current_size -= freed_size;
108        }
109
110        Ok(())
111    }
112}
113
114impl Drop for MemoryManager {
115    fn drop(&mut self) {
116        device_synchronize().unwrap();
117        let ptrs: Vec<*mut c_void> = self.allocated_ptrs.keys().map(|nn| nn.as_ptr()).collect();
118        for &ptr in &ptrs {
119            if let Err(e) = unsafe { self.d_free(ptr) } {
120                tracing::error!("MemoryManager drop: failed to free {:p}: {:?}", ptr, e);
121            }
122        }
123    }
124}
125
126impl Default for MemoryManager {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132pub fn d_malloc(size: usize) -> Result<*mut c_void, MemoryError> {
133    let manager = MEMORY_MANAGER.get().unwrap();
134    let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
135    manager.d_malloc(size)
136}
137
138/// # Safety
139/// The pointer `ptr` must be a valid, previously allocated device pointer.
140/// The caller must ensure that `ptr` is not used after this function is called.
141pub unsafe fn d_free(ptr: *mut c_void) -> Result<(), MemoryError> {
142    let manager = MEMORY_MANAGER.get().unwrap();
143    let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
144    manager.d_free(ptr)
145}
146
/// Snapshot of the global manager's usage, used to report how much device memory has
/// been allocated or released since the snapshot was taken.
#[derive(Debug, Clone)]
pub struct MemTracker {
    /// The manager's `current_size` at the time the tracker was started (0 when the
    /// manager could not be accessed).
    current: usize,
    /// Tag appended to every log line this tracker emits.
    label: &'static str,
}
152
153impl MemTracker {
154    pub fn start(label: &'static str) -> Self {
155        let current = MEMORY_MANAGER
156            .get()
157            .and_then(|m| m.lock().ok())
158            .map(|m| m.current_size)
159            .unwrap_or(0);
160
161        Self { current, label }
162    }
163
164    #[inline]
165    pub fn tracing_info(&self, msg: impl Into<Option<&'static str>>) {
166        let Some(manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) else {
167            tracing::error!("Memory manager not available");
168            return;
169        };
170        let current = manager.current_size;
171        let peak = manager.max_used_size;
172        let used = current as isize - self.current as isize;
173        let sign = if used >= 0 { "+" } else { "-" };
174        let pool_usage = manager.pool.memory_usage();
175        tracing::info!(
176            "GPU mem: used={}{}, current={}, peak={}, in pool={} ({})",
177            sign,
178            ByteSize::b(used.unsigned_abs() as u64),
179            ByteSize::b(current as u64),
180            ByteSize::b(peak as u64),
181            ByteSize::b(pool_usage as u64),
182            msg.into()
183                .map_or(self.label.to_string(), |m| format!("{}:{}", self.label, m))
184        );
185    }
186
187    pub fn reset_peak(&mut self) {
188        if let Some(mut manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) {
189            manager.max_used_size = manager.current_size;
190        }
191    }
192}
193
194impl Drop for MemTracker {
195    fn drop(&mut self) {
196        self.tracing_info(None);
197    }
198}