openvm_circuit/system/cuda/merkle_tree/
mod.rs

use std::{ffi::c_void, sync::Arc};

use openvm_circuit::{
    arch::{MemoryConfig, ADDR_SPACE_OFFSET},
    system::memory::{merkle::MemoryMerkleCols, TimestampedEquipartition},
    utils::next_power_of_two_or_zero,
};
use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
use openvm_cuda_common::{
    copy::{cuda_memcpy, MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
    stream::{
        cudaStreamPerThread, current_stream_sync, default_stream_wait, CudaEvent, CudaStream,
    },
};
use openvm_stark_backend::{
    p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator},
    p3_util::log2_ceil_usize,
    prover::types::AirProvingContext,
};
use p3_field::FieldAlgebra;

use super::{poseidon2::SharedBuffer, Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

pub mod cuda;
use cuda::merkle_tree::*;

type H = [F; DIGEST_WIDTH];
pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;
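// One touched block flattened to u32 words: (addr_space, label, timestamp, [F; DIGEST_WIDTH])
// = 1 + 1 + 1 + 8 = 11 words (see `update_with_touched_blocks` below).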

/// A Merkle subtree stored in a single flat buffer, combining a vertical path and a heap-ordered
/// binary tree.
///
/// Memory layout:
/// - The first `path_len` elements form a vertical path (one node per level), used when the actual
///   size is smaller than the max size.
/// - The remaining elements store the subtree nodes in heap-order (breadth-first), with `size`
///   leaves and `2 * size - 1` total nodes.
///
/// The buffer is filled asynchronously on a dedicated stream; `Option<CudaEvent>` is used to
/// wait for completion.
pub struct MemoryMerkleSubTree {
    pub stream: Arc<CudaStream>,
    #[allow(dead_code)] // See `drop_subtrees`
    created_buffer_event: Option<CudaEvent>,
    build_completion_event: Option<CudaEvent>,
    pub buf: DeviceBuffer<H>,
    pub height: usize,
    pub path_len: usize,
}

impl MemoryMerkleSubTree {
    /// Constructs a new Merkle subtree with a vertical path and heap-ordered tree.
    /// The buffer is sized based on the actual address space and the maximum size.
    ///
    /// `addr_space_size` is the number of leaf digest nodes necessary for this address space.
    /// `max_size` is the number of leaf digest nodes in the full balanced tree dictated by the
    /// `MemoryConfig`.
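    ///
    /// For example, `addr_space_size = 5000` rounds up to `size = 8192`, giving `height = 13`;
    /// with `max_size = 1 << 24` the vertical path has `path_len = 24 - 13 = 11` nodes, so the
    /// buffer holds `11 + (2 * 8192 - 1) = 16394` digests.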
    pub fn new(addr_space_size: usize, max_size: usize) -> Self {
        assert!(
            max_size.is_power_of_two(),
            "Max address space size must be a power of two"
        );
        let size = next_power_of_two_or_zero(addr_space_size);
        if addr_space_size == 0 {
            let mut res = MemoryMerkleSubTree::dummy();
            res.height = log2_ceil_usize(max_size);
            return res;
        }
        let height = log2_ceil_usize(size);
        let path_len = log2_ceil_usize(max_size).checked_sub(height).unwrap();
        tracing::debug!(
            "Creating a subtree buffer, size is {} (addr space size is {})",
            path_len + (2 * size - 1),
            addr_space_size
        );
        let buf = DeviceBuffer::<H>::with_capacity(path_len + (2 * size - 1));

        let created_buffer_event = CudaEvent::new().unwrap();
        unsafe {
            created_buffer_event.record(cudaStreamPerThread).unwrap();
        }

        let stream = Arc::new(CudaStream::new().unwrap());
        stream.wait(&created_buffer_event).unwrap();
        Self {
            stream,
            created_buffer_event: Some(created_buffer_event),
            build_completion_event: None,
            height,
            buf,
            path_len,
        }
    }

    pub fn dummy() -> Self {
        Self {
            stream: Arc::new(CudaStream::new().unwrap()),
            created_buffer_event: None,
            build_completion_event: None,
            height: 0,
            buf: DeviceBuffer::new(),
            path_len: 0,
        }
    }

    /// Asynchronously builds the Merkle subtree on its dedicated CUDA stream.
    /// Also reconstructs the vertical path if `path_len > 0`, and records a completion event.
    ///
    /// Here `addr_space_idx` is the address space _shifted down_ by `ADDR_SPACE_OFFSET` (= 1).
    pub fn build_async(
        &mut self,
        d_data: &DeviceBuffer<u8>,
        addr_space_idx: usize,
        zero_hash: &DeviceBuffer<H>,
    ) {
        let event = CudaEvent::new().unwrap();
        if self.buf.is_empty() {
            // Empty address space: the subtree root is just the zero hash at this height.
            // TODO: this branch is not really asynchronous; the single-digest copy and event
            // are enqueued on the default per-thread stream instead of `self.stream`.
            self.buf = DeviceBuffer::with_capacity(1);
            unsafe {
                cuda_memcpy::<true, true>(
                    self.buf.as_mut_raw_ptr(),
                    zero_hash.as_ptr().add(self.height) as *mut c_void,
                    size_of::<H>(),
                )
                .unwrap();
                event.record(cudaStreamPerThread).unwrap();
            }
        } else {
            unsafe {
                build_merkle_subtree(
                    d_data,
                    1 << self.height,
                    &self.buf,
                    self.path_len,
                    addr_space_idx as u32,
                    self.stream.as_raw(),
                )
                .unwrap();

                if self.path_len > 0 {
                    restore_merkle_subtree_path(
                        &self.buf,
                        zero_hash,
                        self.path_len,
                        self.height + self.path_len,
                        self.stream.as_raw(),
                    )
                    .unwrap();
                }
                event.record(self.stream.as_raw()).unwrap();
            }
        }
        self.build_completion_event = Some(event);
    }
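
    // NOTE: when `build_async` returns, the subtree is not yet built; wait on
    // `build_completion_event` (as `MemoryMerkleTree::finalize` does) before reading `buf`.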

    /// Returns the bounds `[start, end)` of the layer at the given depth.
    /// These bounds correspond to the indices of the layer in the buffer.
    /// depth: 0 = root, 1 = root's children, ..., `height + path_len - 1` = deepest layer.
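    ///
    /// For example, with `path_len = 2` and `height = 3`: `layer_bounds(0) == (0, 1)` and
    /// `layer_bounds(1) == (1, 2)` are path nodes, while `layer_bounds(2) == (2, 3)`,
    /// `layer_bounds(3) == (3, 5)` and `layer_bounds(4) == (5, 9)` index heap-ordered layers.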
    pub fn layer_bounds(&self, depth: usize) -> (usize, usize) {
        let global_height = self.height + self.path_len;
        assert!(
            depth < global_height,
            "Depth {depth} out of bounds for height {global_height}",
        );
        if depth >= self.path_len {
            // depth is within the heap-ordered subtree
            let d = depth - self.path_len;
            let start = self.path_len + ((1 << d) - 1);
            let end = self.path_len + ((1 << (d + 1)) - 1);
            (start, end)
        } else {
            // vertical path layer: single node per level
            (depth, depth + 1)
        }
    }
}

/// A Memory Merkle tree composed of independent subtrees (one per address space),
/// each built asynchronously and finalized into a top-level Merkle root.
///
/// Layout:
/// - The memory is split across multiple `MemoryMerkleSubTree` instances, one per address space.
/// - The top-level tree is formed by hashing all subtree roots into a single buffer (`top_roots`).
///     - top_roots layout: \[root, hash(root_addr_space_1, root_addr_space_2),
///       hash(root_addr_space_3, root_addr_space_4), ...\]
///     - with more than 4 address spaces, top_roots is extended with the next level of hashes,
///       and so on (`2 * num_addr_spaces - 1` nodes in total).
///
/// Execution:
/// - Subtrees are built asynchronously on individual CUDA streams.
/// - The final root is computed after all subtrees complete, on the default stream.
/// - `CudaEvent`s are used to synchronize subtree completion.
pub struct MemoryMerkleTree {
    pub subtrees: Vec<MemoryMerkleSubTree>,
    pub top_roots: DeviceBuffer<H>,
    zero_hash: DeviceBuffer<H>,
    pub height: usize,
    pub hasher_buffer: SharedBuffer<F>,
    mem_config: MemoryConfig,
    pub(crate) top_roots_host: Vec<H>,
}
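
// Typical lifecycle (a sketch; it mirrors the test at the bottom of this file, and
// `d_initial_memory` is a hypothetical per-address-space device buffer):
//
//     let mut tree = MemoryMerkleTree::new(mem_config, hasher_chip);
//     for addr_space in 0..num_addr_spaces {
//         tree.build_async(&d_initial_memory[addr_space], addr_space);
//     }
//     tree.finalize(); // the overall root is now in `tree.top_roots[0]`
//     let ctx = tree.update_with_touched_blocks(unpadded_height, &d_touched_blocks, false);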

impl MemoryMerkleTree {
    /// Creates a full Merkle tree with one subtree per address space.
    /// Initializes all buffers and precomputes the zero hash chain.
    pub fn new(mem_config: MemoryConfig, hasher_chip: Arc<Poseidon2PeripheryChipGPU>) -> Self {
        let addr_space_sizes = mem_config
            .addr_spaces
            .iter()
            .map(|ashc| {
                assert!(
                    ashc.num_cells % DIGEST_WIDTH == 0,
                    "the number of cells must be divisible by `DIGEST_WIDTH`"
                );
                ashc.num_cells / DIGEST_WIDTH
            })
            .collect::<Vec<_>>();
        assert!(!addr_space_sizes.is_empty(), "Invalid config: no address spaces");

        let num_addr_spaces = addr_space_sizes.len() - ADDR_SPACE_OFFSET as usize;
        assert!(
            num_addr_spaces.is_power_of_two(),
            "Number of address spaces beyond `ADDR_SPACE_OFFSET` must be a power of two"
        );
        for &sz in addr_space_sizes.iter().take(ADDR_SPACE_OFFSET as usize) {
            assert!(
                sz == 0,
                "The first `ADDR_SPACE_OFFSET` address spaces are assumed to be empty"
            );
        }

        let label_max_bits = mem_config.pointer_max_bits - log2_ceil_usize(DIGEST_WIDTH);

        let zero_hash = DeviceBuffer::<H>::with_capacity(label_max_bits + 1);
        let top_roots = DeviceBuffer::<H>::with_capacity(2 * num_addr_spaces - 1);
        unsafe {
            calculate_zero_hash(&zero_hash, label_max_bits).unwrap();
        }
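        // A sketch of the recurrence `calculate_zero_hash` fills in, assuming `compress`
        // stands for the Poseidon2 2-to-1 compression:
        //   zero_hash[0] = digest of an all-zero block,
        //   zero_hash[i] = compress(zero_hash[i - 1], zero_hash[i - 1]), i in 1..=label_max_bits,
        // so `zero_hash[h]` is the root of an all-zero subtree with 2^h leaves.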

        Self {
            subtrees: Vec::new(),
            top_roots,
            height: label_max_bits + log2_ceil_usize(num_addr_spaces),
            zero_hash,
            hasher_buffer: hasher_chip.shared_buffer(),
            mem_config,
            top_roots_host: vec![],
        }
    }

    pub fn mem_config(&self) -> &MemoryConfig {
        &self.mem_config
    }

    /// Starts asynchronous construction of the specified address space's Merkle subtree.
    /// Uses internal zero hashes and launches kernels on the subtree's own CUDA stream.
    /// Address spaces must be built in increasing order, one subtree per call.
    ///
    /// Here `addr_space` is the _unshifted_ address space, so `addr_space = 0` is the immediate
    /// address space, which is ignored.
    pub fn build_async(&mut self, d_data: &DeviceBuffer<u8>, addr_space: usize) {
        if addr_space < ADDR_SPACE_OFFSET as usize {
            return;
        }
        let addr_space_idx = addr_space - ADDR_SPACE_OFFSET as usize;
        if addr_space < self.mem_config.addr_spaces.len() && addr_space_idx == self.subtrees.len() {
            let mut subtree = MemoryMerkleSubTree::new(
                self.mem_config.addr_spaces[addr_space].num_cells / DIGEST_WIDTH,
                1 << (self.zero_hash.len() - 1), /* label_max_bits */
            );
            subtree.build_async(d_data, addr_space_idx, &self.zero_hash);
            self.subtrees.push(subtree);
        } else {
            panic!("Invalid address space index");
        }
    }

    /// Finalizes the Merkle tree by collecting all subtree roots and computing the final root.
    /// Waits for all subtrees to complete and then performs the final hash operation.
    pub fn finalize(&mut self) {
        // Default stream waits for all subtrees to complete
        for subtree in self.subtrees.iter() {
            default_stream_wait(
                subtree
                    .build_completion_event
                    .as_ref()
                    .expect("Subtree build event does not exist"),
            )
            .unwrap();
        }

        let roots: Vec<usize> = self
            .subtrees
            .iter()
            .map(|subtree| subtree.buf.as_ptr() as usize)
            .collect();
        let d_roots = roots.to_device().unwrap();

        unsafe {
            finalize_merkle_tree(
                &d_roots,
                &self.top_roots,
                self.subtrees.len(),
                cudaStreamPerThread,
            )
            .unwrap();
        }
    }

    /// Drops all large buffers to free memory. Used at the end of an execution segment.
    ///
    /// Caution: this method destroys all subtree streams and events. For safety, we force
    /// synchronize all subtree streams and the default stream (cudaStreamPerThread) with host
    /// before deallocating buffers.
    pub fn drop_subtrees(&mut self) {
        // Make sure all streams are synchronized before destroying events
        for subtree in self.subtrees.iter() {
            subtree.stream.synchronize().unwrap();
            if let Some(_event) = subtree.build_completion_event.as_ref() {
                tracing::warn!(
                    "Dropping merkle subtree before its build_async completion event was consumed"
                );
            }
        }
        current_stream_sync().unwrap();
        // Clearing will drop streams (which calls synchronize again) and drop events (which
        // destroys them)
        self.subtrees.clear();
    }

    /// Updates the tree and returns the merkle trace.
    pub fn update_with_touched_blocks(
        &mut self,
        unpadded_height: usize,
        d_touched_blocks: &DeviceBuffer<u32>, // consists of (as, label, ts, [F; 8])
        empty_touched_blocks: bool,
    ) -> AirProvingContext<GpuBackend> {
        // Public values are the initial root followed by the final root.
        let mut public_values = self.top_roots.to_host().unwrap()[0].to_vec();
        // .to_host() calls cudaEventSynchronize on the D2H memcpy, which also means all subtree
        // events are now completed, so we can clean up the events.
        for subtree in &mut self.subtrees {
            subtree.build_completion_event = None;
        }
        let merkle_trace = {
            let width = MemoryMerkleCols::<u8, DIGEST_WIDTH>::width();
            let padded_height = next_power_of_two_or_zero(unpadded_height);
            let output = DeviceMatrix::<F>::with_capacity(padded_height, width);
            output.buffer().fill_zero().unwrap();

            let actual_heights = self.subtrees.iter().map(|s| s.height).collect::<Vec<_>>();
            let subtrees_pointers = self
                .subtrees
                .iter()
                .map(|st| st.buf.as_ptr() as usize)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap();
            unsafe {
                update_merkle_tree(
                    &output,
                    &subtrees_pointers,
                    &self.top_roots,
                    &self.zero_hash,
                    d_touched_blocks,
                    self.height - log2_ceil_usize(self.subtrees.len()),
                    &actual_heights,
                    unpadded_height,
                    &self.hasher_buffer,
                )
                .unwrap();
            }

            if empty_touched_blocks {
                // The trace is small in this case, so the host round-trip is cheap.
                let mut output_vec = output.to_host().unwrap();
                output_vec[unpadded_height - 1 + (width - 2) * padded_height] = F::ONE; // left_direction_different
                output_vec[unpadded_height - 1 + (width - 1) * padded_height] = F::ONE; // right_direction_different
                DeviceMatrix::new(
                    Arc::new(output_vec.to_device().unwrap()),
                    padded_height,
                    width,
                )
            } else {
                output
            }
        };
        self.top_roots_host = self.top_roots.to_host().unwrap();
        public_values.extend(self.top_roots_host[0]);

        AirProvingContext::new(Vec::new(), Some(merkle_trace), public_values)
    }

    /// An auxiliary function to calculate the required number of rows for the merkle trace.
    ///
    /// The trace uses two rows per touched node. The first touched leaf touches `tree_height`
    /// nodes on its path to the root; each subsequent leaf of the sorted partition touches
    /// `ilog2(x ^ y)` new nodes below the point where its path diverges from the previous
    /// leaf's path.
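    ///
    /// For example, if exactly two blocks with tree indices `x = 5` (0b101) and `y = 6` (0b110)
    /// are touched in a tree of height 3, the first contributes 3 nodes and the second
    /// `ilog2(5 ^ 6) = ilog2(3) = 1` new node, for `2 * (3 + 1) = 8` rows.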
    pub fn calculate_unpadded_height(
        &self,
        touched_memory: &TimestampedEquipartition<F, DIGEST_WIDTH>,
    ) -> usize {
        let md = self.mem_config.memory_dimensions();
        let tree_height = md.overall_height();
        let shift_address = |(sp, ptr): (u32, u32)| (sp, ptr / DIGEST_WIDTH as u32);
        2 * if touched_memory.is_empty() {
            tree_height
        } else {
            tree_height
                + (0..(touched_memory.len() - 1))
                    .into_par_iter()
                    .map(|i| {
                        let x = md.label_to_index(shift_address(touched_memory[i].0));
                        let y = md.label_to_index(shift_address(touched_memory[i + 1].0));
                        (x ^ y).ilog2() as usize
                    })
                    .sum::<usize>()
        }
    }
}

impl Drop for MemoryMerkleTree {
    fn drop(&mut self) {
        self.drop_subtrees();
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use openvm_circuit::{
        arch::{
            testing::POSEIDON2_DIRECT_BUS, vm_poseidon2_config, AddressSpaceHostLayout,
            MemoryCellType, MemoryConfig,
        },
        system::{
            memory::{
                merkle::MerkleTree,
                online::{GuestMemory, LinearMemory},
                AddressMap, TimestampedValues,
            },
            poseidon2::Poseidon2PeripheryChip,
        },
    };
    use openvm_cuda_backend::prelude::F;
    use openvm_cuda_common::{
        copy::{MemCopyD2H, MemCopyH2D},
        d_buffer::DeviceBuffer,
    };
    use openvm_instructions::{
        riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
        NATIVE_AS,
    };
    use openvm_stark_sdk::utils::create_seeded_rng;
    use p3_field::{FieldAlgebra, PrimeField32};
    use rand::Rng;

    use super::MemoryMerkleTree;
    use crate::system::cuda::{Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

    #[test]
    fn test_cuda_merkle_tree_cpu_gpu_root_equivalence() {
        let mut rng = create_seeded_rng();
        let mem_config = {
            let mut addr_spaces = MemoryConfig::empty_address_space_configs(5);
            let max_cells = 1 << 16;
            addr_spaces[RV32_REGISTER_AS as usize].num_cells = 32 * size_of::<u32>();
            addr_spaces[RV32_MEMORY_AS as usize].num_cells = max_cells;
            addr_spaces[NATIVE_AS as usize].num_cells = max_cells;
            MemoryConfig::new(2, addr_spaces, max_cells.ilog2() as usize, 29, 17, 32)
        };

        let mut initial_memory = GuestMemory::new(AddressMap::from_mem_config(&mem_config));
        for (idx, space) in mem_config.addr_spaces.iter().enumerate() {
            unsafe {
                match space.layout {
                    MemoryCellType::Null => {}
                    MemoryCellType::U8 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u8, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u8],
                            );
                        }
                    }
                    MemoryCellType::U16 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u16, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u16],
                            );
                        }
                    }
                    MemoryCellType::U32 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u32, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u32],
                            );
                        }
                    }
                    MemoryCellType::Native { .. } => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<F, 1>(
                                idx as u32,
                                i as u32,
                                [F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))],
                            );
                        }
                    }
                }
            }
        }

        let gpu_hasher_chip = Arc::new(Poseidon2PeripheryChipGPU::new(
            (mem_config
                .addr_spaces
                .iter()
                .map(|ashc| ashc.num_cells * 2 + mem_config.memory_dimensions().overall_height())
                .sum::<usize>()
                * 2)
            .next_power_of_two()
                * 2
                * DIGEST_WIDTH, // max_buffer_size
            1, // sbox_regs
        ));
        let mut gpu_merkle_tree = MemoryMerkleTree::new(mem_config.clone(), gpu_hasher_chip);
        for (i, mem) in initial_memory.memory.get_memory().iter().enumerate() {
            let mem_slice = mem.as_slice();
            gpu_merkle_tree.build_async(
                &(if !mem_slice.is_empty() {
                    mem_slice.to_device().unwrap()
                } else {
                    DeviceBuffer::new()
                }),
                i,
            );
        }
        gpu_merkle_tree.finalize();

        let cpu_hasher_chip =
            Poseidon2PeripheryChip::new(vm_poseidon2_config(), POSEIDON2_DIRECT_BUS, 3);
        let mut cpu_merkle_tree = MerkleTree::<F, DIGEST_WIDTH>::from_memory(
            &initial_memory.memory,
            &mem_config.memory_dimensions(),
            &cpu_hasher_chip,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);

        // Now we add some touched memory.
        // The memory layout doesn't matter here, because neither implementation relies on any
        // special form of the touched blocks.
        let touched_ptrs = mem_config
            .addr_spaces
            .iter()
            .enumerate()
            .flat_map(|(i, cnf)| {
                let mut ptrs = Vec::new();
                for j in 0..(cnf.num_cells / DIGEST_WIDTH) {
                    if rng.gen_bool(0.333) {
                        ptrs.push((i as u32, (j * DIGEST_WIDTH) as u32));
                    }
                }
                ptrs
            })
            .collect::<Vec<_>>();
        let new_data = touched_ptrs
            .iter()
            .map(|_| std::array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))))
            .collect::<Vec<[F; DIGEST_WIDTH]>>();
        assert!(!touched_ptrs.is_empty());
        cpu_merkle_tree.finalize(
            &cpu_hasher_chip,
            &(touched_ptrs
                .iter()
                .copied()
                .zip(new_data.iter().copied())
                .collect()),
            &mem_config.memory_dimensions(),
        );
        let touched_blocks = touched_ptrs
            .into_iter()
            .zip(new_data)
            .map(|(address, data)| {
                (
                    address,
                    TimestampedValues {
                        timestamp: rng.gen_range(0..(1u32 << mem_config.timestamp_max_bits)),
                        values: data,
                    },
                )
            })
            .collect::<Vec<_>>();
        let d_touched_blocks = touched_blocks.to_device().unwrap().as_buffer::<u32>();

        gpu_merkle_tree.update_with_touched_blocks(
            gpu_merkle_tree.calculate_unpadded_height(&touched_blocks),
            &d_touched_blocks,
            false,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);
    }
}