openvm_circuit/system/cuda/memory.rs

1use std::sync::Arc;
2
3use openvm_circuit::{
4    arch::{AddressSpaceHostLayout, DenseRecordArena, MemoryConfig, ADDR_SPACE_OFFSET},
5    system::{
6        memory::{online::LinearMemory, AddressMap, TimestampedValues},
7        TouchedMemory,
8    },
9};
10use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
11use openvm_cuda_backend::{prover_backend::GpuBackend, types::F};
12use openvm_cuda_common::{
13    copy::{cuda_memcpy, MemCopyD2D, MemCopyH2D},
14    d_buffer::DeviceBuffer,
15    memory_manager::MemTracker,
16};
17use openvm_stark_backend::{
18    p3_field::FieldAlgebra, p3_util::log2_ceil_usize, prover::types::AirProvingContext, Chip,
19};
20
21use super::{
22    access_adapters::AccessAdapterInventoryGPU,
23    boundary::{BoundaryChipGPU, BoundaryFields},
24    merkle_tree::{MemoryMerkleTree, TIMESTAMPED_BLOCK_WIDTH},
25    Poseidon2PeripheryChipGPU, DIGEST_WIDTH,
26};
27
/// GPU memory chips: the boundary chip, the access-adapter inventory and, when present,
/// the persistent (merkle tree + initial memory) state.
pub struct MemoryInventoryGPU {
    /// Boundary chip; its `fields` are either `BoundaryFields::Volatile` or
    /// `BoundaryFields::Persistent` depending on which constructor built the inventory.
    pub boundary: BoundaryChipGPU,
    /// Access-adapter chips.
    pub access_adapters: AccessAdapterInventoryGPU,
    /// `Some` only for the persistent (continuation-enabled) inventory.
    pub persistent: Option<PersistentMemoryInventoryGPU>,
    /// Merkle height computed from the touched partition before padding; only recorded
    /// when the `metrics` feature is enabled.
    #[cfg(feature = "metrics")]
    pub(super) unpadded_merkle_height: usize,
}

/// State that exists only when memory is persistent.
pub struct PersistentMemoryInventoryGPU {
    /// Merkle tree built over the initial memory and updated with the touched blocks.
    pub merkle_tree: MemoryMerkleTree,
    /// Device copy of each address space's initial memory, stored in address-space order
    /// (an empty buffer for address spaces with no memory).
    pub initial_memory: Vec<DeviceBuffer<u8>>,
}
40
41impl MemoryInventoryGPU {
42    pub fn volatile(config: MemoryConfig, range_checker: Arc<VariableRangeCheckerChipGPU>) -> Self {
43        let addr_space_max_bits = log2_ceil_usize(
44            (ADDR_SPACE_OFFSET + 2u32.pow(config.addr_space_height as u32)) as usize,
45        );
46        Self {
47            boundary: BoundaryChipGPU::volatile(
48                range_checker.clone(),
49                addr_space_max_bits,
50                config.pointer_max_bits,
51            ),
52            access_adapters: AccessAdapterInventoryGPU::new(
53                range_checker,
54                config.max_access_adapter_n,
55                config.timestamp_max_bits,
56            ),
57            persistent: None,
58            #[cfg(feature = "metrics")]
59            unpadded_merkle_height: 0,
60        }
61    }
62
63    pub fn persistent(
64        config: MemoryConfig,
65        range_checker: Arc<VariableRangeCheckerChipGPU>,
66        hasher_chip: Arc<Poseidon2PeripheryChipGPU>,
67    ) -> Self {
68        Self {
69            boundary: BoundaryChipGPU::persistent(hasher_chip.shared_buffer()),
70            access_adapters: AccessAdapterInventoryGPU::new(
71                range_checker,
72                config.max_access_adapter_n,
73                config.timestamp_max_bits,
74            ),
75            persistent: Some(PersistentMemoryInventoryGPU {
76                merkle_tree: MemoryMerkleTree::new(config.clone(), hasher_chip.clone()),
77                initial_memory: Vec::new(),
78            }),
79            #[cfg(feature = "metrics")]
80            unpadded_merkle_height: 0,
81        }
82    }
83
84    pub fn continuation_enabled(&self) -> bool {
85        self.persistent.is_some()
86    }
87
88    pub fn set_initial_memory(&mut self, initial_memory: &AddressMap) {
89        let _mem = MemTracker::start("set initial memory");
90        let persistent = self
91            .persistent
92            .as_mut()
93            .expect("`set_initial_memory` requires persistent memory");
94        for (addr_sp, raw_mem) in initial_memory
95            .get_memory()
96            .iter()
97            .map(|mem| mem.as_slice())
98            .enumerate()
99        {
100            tracing::debug!(
101                "Setting initial memory for address space {}: {} bytes",
102                addr_sp,
103                raw_mem.len()
104            );
105            persistent.initial_memory.push(if raw_mem.is_empty() {
106                DeviceBuffer::new()
107            } else {
108                raw_mem
109                    .to_device()
110                    .expect("failed to copy memory to device")
111            });
112            persistent
113                .merkle_tree
114                .build_async(&persistent.initial_memory[addr_sp], addr_sp);
115        }
116        match &mut self.boundary.fields {
117            BoundaryFields::Volatile(_) => {
118                panic!("`set_initial_memory` requires persistent memory")
119            }
120            BoundaryFields::Persistent(fields) => {
121                fields.initial_leaves = persistent
122                    .initial_memory
123                    .iter()
124                    .skip(1)
125                    .map(|per_as| per_as.as_raw_ptr())
126                    .collect();
127            }
128        }
129    }
130
131    pub fn generate_proving_ctxs(
132        &mut self,
133        access_adapter_arena: DenseRecordArena,
134        touched_memory: TouchedMemory<F>,
135    ) -> Vec<AirProvingContext<GpuBackend>> {
136        let mem = MemTracker::start("generate mem proving ctxs");
137        let merkle_proof_ctx = match touched_memory {
138            TouchedMemory::Persistent(partition) => {
139                let persistent = self
140                    .persistent
141                    .as_ref()
142                    .expect("persistent touched memory requires persistent memory interface");
143
144                let unpadded_merkle_height =
145                    persistent.merkle_tree.calculate_unpadded_height(&partition);
146                #[cfg(feature = "metrics")]
147                {
148                    self.unpadded_merkle_height = unpadded_merkle_height;
149                }
150
151                mem.tracing_info("boundary finalize");
152                let (touched_memory, empty) = if partition.is_empty() {
153                    let leftmost_values = 'left: {
154                        let mut res = [F::ZERO; DIGEST_WIDTH];
155                        if persistent.initial_memory[ADDR_SPACE_OFFSET as usize].is_empty() {
156                            break 'left res;
157                        }
158                        let layout = &persistent.merkle_tree.mem_config().addr_spaces
159                            [ADDR_SPACE_OFFSET as usize]
160                            .layout;
161                        let one_cell_size = layout.size();
162                        let values = vec![0u8; one_cell_size * DIGEST_WIDTH];
163                        unsafe {
164                            cuda_memcpy::<true, false>(
165                                values.as_ptr() as *mut std::ffi::c_void,
166                                persistent.initial_memory[ADDR_SPACE_OFFSET as usize].as_ptr()
167                                    as *const std::ffi::c_void,
168                                values.len(),
169                            )
170                            .unwrap();
171                            for i in 0..DIGEST_WIDTH {
172                                res[i] = layout.to_field::<F>(&values[i * one_cell_size..]);
173                            }
174                        }
175                        res
176                    };
177
178                    (
179                        vec![(
180                            (1, 0),
181                            TimestampedValues {
182                                timestamp: 0,
183                                values: leftmost_values,
184                            },
185                        )],
186                        true,
187                    )
188                } else {
189                    (partition, false)
190                };
191                debug_assert_eq!(
192                    size_of_val(&touched_memory[0]),
193                    TIMESTAMPED_BLOCK_WIDTH * size_of::<u32>()
194                );
195                let d_touched_memory = touched_memory.to_device().unwrap().as_buffer::<u32>();
196                if empty {
197                    self.boundary
198                        .finalize_records_persistent::<DIGEST_WIDTH>(DeviceBuffer::new());
199                } else {
200                    self.boundary.finalize_records_persistent::<DIGEST_WIDTH>(
201                        d_touched_memory.device_copy().unwrap().as_buffer::<u32>(),
202                    ); // TODO do not copy
203                }
204                mem.tracing_info("merkle update");
205                persistent.merkle_tree.finalize();
206                Some(persistent.merkle_tree.update_with_touched_blocks(
207                    unpadded_merkle_height,
208                    &d_touched_memory,
209                    empty,
210                ))
211            }
212            TouchedMemory::Volatile(partition) => {
213                assert!(self.persistent.is_none(), "TouchedMemory enum mismatch");
214                self.boundary.finalize_records_volatile(partition);
215                None
216            }
217        };
218        mem.tracing_info("boundary tracegen");
219        let mut ret = vec![self.boundary.generate_proving_ctx(())];
220        if let Some(merkle_proof_ctx) = merkle_proof_ctx {
221            ret.push(merkle_proof_ctx);
222            mem.tracing_info("dropping merkle tree");
223            let persistent = self.persistent.as_mut().unwrap();
224            persistent.merkle_tree.drop_subtrees();
225            persistent.initial_memory = Vec::new();
226        }
227        ret.extend(
228            self.access_adapters
229                .generate_air_proving_ctxs(access_adapter_arena),
230        );
231        ret
232    }
233}