// openvm_cuda_common/memory_manager/mod.rs

1use std::{
2    collections::HashMap,
3    ffi::c_void,
4    ptr::NonNull,
5    sync::{Mutex, OnceLock},
6};
7
8use bytesize::ByteSize;
9
10use crate::{
11    error::{check, MemoryError},
12    stream::{cudaStreamPerThread, cudaStream_t, current_stream_id},
13};
14
15mod cuda;
16mod vm_pool;
17use vm_pool::VirtualMemoryPool;
18
// Direct FFI bindings to the CUDA runtime's stream-ordered allocator.
#[link(name = "cudart")]
extern "C" {
    /// Allocates `size` bytes of device memory, ordered on `stream`;
    /// writes the device pointer through `dev_ptr`. The returned i32 is
    /// a CUDA status code, validated by `check` at the call sites.
    fn cudaMallocAsync(dev_ptr: *mut *mut c_void, size: usize, stream: cudaStream_t) -> i32;
    /// Frees device memory previously obtained from `cudaMallocAsync`,
    /// ordered on `stream`. Returns a CUDA status code.
    fn cudaFreeAsync(dev_ptr: *mut c_void, stream: cudaStream_t) -> i32;
}
24
25static MEMORY_MANAGER: OnceLock<Mutex<MemoryManager>> = OnceLock::new();
26
27#[ctor::ctor]
28fn init() {
29    let _ = MEMORY_MANAGER.set(Mutex::new(MemoryManager::new()));
30    tracing::info!("Memory manager initialized at program start");
31}
32
/// Tracks every live device allocation, routing requests either to
/// `cudaMallocAsync` (sub-page sizes) or to the virtual-memory pool
/// (page-sized and larger requests).
pub struct MemoryManager {
    // Pool serving allocations of at least one page.
    pool: VirtualMemoryPool,
    // Allocations served directly by cudaMallocAsync, keyed by device
    // pointer, with the originally requested byte size as the value.
    allocated_ptrs: HashMap<NonNull<c_void>, usize>,
    // Bytes currently tracked as allocated (pool requests are rounded
    // up to a page multiple before being counted).
    current_size: usize,
    // High-water mark of `current_size`.
    max_used_size: usize,
}

// SAFETY: `NonNull<c_void>` keys block the auto-derived Send/Sync. The
// manager is only reachable through the global `Mutex` in
// `MEMORY_MANAGER`, so its state is never accessed concurrently, and
// the raw pointers refer to device memory rather than host data.
// NOTE(review): soundness also assumes `VirtualMemoryPool` is safe to
// move across threads — confirm in vm_pool.
unsafe impl Send for MemoryManager {}
unsafe impl Sync for MemoryManager {}
42
43impl MemoryManager {
44    pub fn new() -> Self {
45        // Create virtual memory pool
46        let pool = VirtualMemoryPool::default();
47
48        Self {
49            pool,
50            allocated_ptrs: HashMap::new(),
51            current_size: 0,
52            max_used_size: 0,
53        }
54    }
55
56    fn d_malloc(&mut self, size: usize) -> Result<*mut c_void, MemoryError> {
57        assert!(size != 0, "Requested size must be non-zero");
58
59        let mut tracked_size = size;
60        let ptr = if size < self.pool.page_size {
61            let mut ptr: *mut c_void = std::ptr::null_mut();
62            check(unsafe { cudaMallocAsync(&mut ptr, size, cudaStreamPerThread) })?;
63            self.allocated_ptrs
64                .insert(NonNull::new(ptr).expect("cudaMalloc returned null"), size);
65            Ok(ptr)
66        } else {
67            tracked_size = size.next_multiple_of(self.pool.page_size);
68            self.pool
69                .malloc_internal(tracked_size, current_stream_id()?)
70        };
71
72        self.current_size += tracked_size;
73        if self.current_size > self.max_used_size {
74            self.max_used_size = self.current_size;
75        }
76        ptr
77    }
78
79    /// # Safety
80    /// The pointer `ptr` must be a valid, previously allocated device pointer.
81    /// The caller must ensure that `ptr` is not used after this function is called.
82    unsafe fn d_free(&mut self, ptr: *mut c_void) -> Result<(), MemoryError> {
83        let nn = NonNull::new(ptr).ok_or(MemoryError::NullPointer)?;
84
85        if let Some(size) = self.allocated_ptrs.remove(&nn) {
86            self.current_size -= size;
87            check(unsafe { cudaFreeAsync(ptr, cudaStreamPerThread) })?;
88        } else {
89            self.current_size -= self.pool.free_internal(ptr, current_stream_id()?)?;
90        }
91
92        Ok(())
93    }
94}
95
96impl Drop for MemoryManager {
97    fn drop(&mut self) {
98        let ptrs: Vec<*mut c_void> = self.allocated_ptrs.keys().map(|nn| nn.as_ptr()).collect();
99        for &ptr in &ptrs {
100            unsafe { self.d_free(ptr).unwrap() };
101        }
102        if !self.allocated_ptrs.is_empty() {
103            println!(
104                "Error: {} allocations were automatically freed on MemoryManager drop",
105                self.allocated_ptrs.len()
106            );
107        }
108    }
109}
110
111impl Default for MemoryManager {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117pub fn d_malloc(size: usize) -> Result<*mut c_void, MemoryError> {
118    let manager = MEMORY_MANAGER.get().unwrap();
119    let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
120    manager.d_malloc(size)
121}
122
123/// # Safety
124/// The pointer `ptr` must be a valid, previously allocated device pointer.
125/// The caller must ensure that `ptr` is not used after this function is called.
126pub unsafe fn d_free(ptr: *mut c_void) -> Result<(), MemoryError> {
127    let manager = MEMORY_MANAGER.get().unwrap();
128    let mut manager = manager.lock().map_err(|_| MemoryError::LockError)?;
129    manager.d_free(ptr)
130}
131
/// Snapshot-based helper for logging GPU memory usage deltas between a
/// `start` call and later `tracing_info` calls (or drop).
#[derive(Debug, Clone)]
pub struct MemTracker {
    // Manager's `current_size` captured when `start` ran (0 if the
    // manager was unavailable at that point).
    current: usize,
    // Static label prefixed to every log line this tracker emits.
    label: &'static str,
}
137
138impl MemTracker {
139    pub fn start(label: &'static str) -> Self {
140        let current = MEMORY_MANAGER
141            .get()
142            .and_then(|m| m.lock().ok())
143            .map(|m| m.current_size)
144            .unwrap_or(0);
145
146        Self { current, label }
147    }
148
149    #[inline]
150    pub fn tracing_info(&self, msg: impl Into<Option<&'static str>>) {
151        let Some(manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) else {
152            tracing::error!("Memory manager not available");
153            return;
154        };
155        let current = manager.current_size;
156        let peak = manager.max_used_size;
157        let used = current as isize - self.current as isize;
158        let sign = if used >= 0 { "+" } else { "-" };
159        let pool_usage = manager.pool.memory_usage();
160        tracing::info!(
161            "GPU mem: used={}{}, current={}, peak={}, in pool={} ({})",
162            sign,
163            ByteSize::b(used.unsigned_abs() as u64),
164            ByteSize::b(current as u64),
165            ByteSize::b(peak as u64),
166            ByteSize::b(pool_usage as u64),
167            msg.into()
168                .map_or(self.label.to_string(), |m| format!("{}:{}", self.label, m))
169        );
170    }
171
172    pub fn reset_peak(&mut self) {
173        if let Some(mut manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) {
174            manager.max_used_size = manager.current_size;
175        }
176    }
177}
178
179impl Drop for MemTracker {
180    fn drop(&mut self) {
181        self.tracing_info(None);
182    }
183}