openvm_circuit/system/cuda/merkle_tree/mod.rs

use std::{ffi::c_void, sync::Arc};

use openvm_circuit::{
    arch::{MemoryConfig, ADDR_SPACE_OFFSET},
    system::memory::{merkle::MemoryMerkleCols, TimestampedEquipartition},
    utils::next_power_of_two_or_zero,
};
use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
use openvm_cuda_common::{
    copy::{cuda_memcpy, MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
    stream::{cudaStreamPerThread, default_stream_wait, CudaEvent, CudaStream},
};
use openvm_stark_backend::{
    p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator},
    p3_util::log2_ceil_usize,
    prover::types::AirProvingContext,
};
use p3_field::FieldAlgebra;

use super::{poseidon2::SharedBuffer, Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

pub mod cuda;
use cuda::merkle_tree::*;

type H = [F; DIGEST_WIDTH];
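/// One timestamped block record: 1 `u32` word for the address space, 1 for the label, 1 for the
/// timestamp, plus `DIGEST_WIDTH` (= 8) words of data, giving 11 words in total (this reading
/// follows the `(as, label, ts, [F; 8])` comment on `update_with_touched_blocks`).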
pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;

/// A Merkle subtree stored in a single flat buffer, combining a vertical path and a heap-ordered
/// binary tree.
///
/// Memory layout:
/// - The first `path_len` elements form a vertical path (one node per level), used when the actual
///   size is smaller than the max size.
/// - The remaining elements store the subtree nodes in heap order (breadth-first), with `size`
///   leaves and `2 * size - 1` total nodes.
///
/// The buffer is filled asynchronously on a dedicated stream; the `Option<CudaEvent>` is used to
/// wait for completion.
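///
/// For illustration (sizes assumed, not taken from any particular config): with `max_size = 16`
/// and `addr_space_size = 4`, we get `height = 2` and `path_len = 2`, so the buffer holds
/// `2 + (2 * 4 - 1) = 9` digests:
///
/// ```text
/// [ path[0], path[1] | heap root | heap level 1 (2 nodes) | heap leaves (4 nodes) ]
///   indices 0..2       index 2     indices 3..5             indices 5..9
/// ```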
pub struct MemoryMerkleSubTree {
    pub stream: Arc<CudaStream>,
    pub event: Option<CudaEvent>,
    pub buf: DeviceBuffer<H>,
    pub height: usize,
    pub path_len: usize,
}

impl MemoryMerkleSubTree {
    /// Constructs a new Merkle subtree with a vertical path and heap-ordered tree.
    /// The buffer is sized based on the actual address space and the maximum size.
    ///
    /// `addr_space_size` is the number of leaf digest nodes necessary for this address space. The
    /// `max_size` is the number of leaf digest nodes in the full balanced tree dictated by
    /// `addr_space_height` from the `MemoryConfig`.
    pub fn new(addr_space_size: usize, max_size: usize) -> Self {
        assert!(
            max_size.is_power_of_two(),
            "Max address space size must be a power of two"
        );
        let size = next_power_of_two_or_zero(addr_space_size);
        if addr_space_size == 0 {
            let mut res = MemoryMerkleSubTree::dummy();
            res.height = log2_ceil_usize(max_size);
            return res;
        }
        let height = log2_ceil_usize(size);
        let path_len = log2_ceil_usize(max_size).checked_sub(height).unwrap();
        tracing::debug!(
            "Creating a subtree buffer, size is {} (addr space size is {})",
            path_len + (2 * size - 1),
            addr_space_size
        );
        let buf = DeviceBuffer::<H>::with_capacity(path_len + (2 * size - 1));

        let created_buffer_event = CudaEvent::new().unwrap();
        unsafe {
            created_buffer_event.record(cudaStreamPerThread).unwrap();
        }

        let stream = Arc::new(CudaStream::new().unwrap());
        stream.wait(&created_buffer_event).unwrap();
        Self {
            stream,
            event: None,
            height,
            buf,
            path_len,
        }
    }

    pub fn dummy() -> Self {
        Self {
            stream: Arc::new(CudaStream::new().unwrap()),
            event: None,
            height: 0,
            buf: DeviceBuffer::new(),
            path_len: 0,
        }
    }

    /// Asynchronously builds the Merkle subtree on its dedicated CUDA stream.
    /// Also reconstructs the vertical path if `path_len > 0`, and records a completion event.
    ///
    /// Here `addr_space_idx` is the address space _shifted_ by `ADDR_SPACE_OFFSET = 1`.
    pub fn build_async(
        &mut self,
        d_data: &DeviceBuffer<u8>,
        addr_space_idx: usize,
        zero_hash: &DeviceBuffer<H>,
    ) {
        let event = CudaEvent::new().unwrap();
        if self.buf.is_empty() {
            // TODO: this branch is not actually asynchronous.
            self.buf = DeviceBuffer::with_capacity(1);
            unsafe {
                cuda_memcpy::<true, true>(
                    self.buf.as_mut_raw_ptr(),
                    zero_hash.as_ptr().add(self.height) as *mut c_void,
                    size_of::<H>(),
                )
                .unwrap();
                event.record(cudaStreamPerThread).unwrap();
            }
        } else {
            unsafe {
                build_merkle_subtree(
                    d_data,
                    1 << self.height,
                    &self.buf,
                    self.path_len,
                    addr_space_idx as u32,
                    self.stream.as_raw(),
                )
                .unwrap();

                if self.path_len > 0 {
                    restore_merkle_subtree_path(
                        &self.buf,
                        zero_hash,
                        self.path_len,
                        self.height + self.path_len,
                        self.stream.as_raw(),
                    )
                    .unwrap();
                }
                event.record(self.stream.as_raw()).unwrap();
            }
        }
        self.event = Some(event);
    }

    /// Returns the bounds `[start, end)` of the layer at the given depth.
    /// These bounds correspond to the indices of the layer in the buffer.
    /// depth: 0 = root, 1 = root's children, ..., height-1 = leaves
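    ///
    /// A minimal sketch of the index arithmetic (marked `ignore` because constructing a subtree
    /// requires a CUDA device; the concrete sizes are assumed for illustration):
    ///
    /// ```ignore
    /// // height = 2, path_len = 2: buffer is [path0, path1, heap root, heap L1 x2, leaves x4].
    /// assert_eq!(subtree.layer_bounds(0), (0, 1)); // vertical path node at depth 0
    /// assert_eq!(subtree.layer_bounds(1), (1, 2)); // vertical path node at depth 1
    /// assert_eq!(subtree.layer_bounds(2), (2, 3)); // heap root
    /// assert_eq!(subtree.layer_bounds(3), (3, 5)); // heap level 1 (2 nodes)
    /// ```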
    pub fn layer_bounds(&self, depth: usize) -> (usize, usize) {
        let global_height = self.height + self.path_len;
        assert!(
            depth < global_height,
            "Depth {} out of bounds for height {}",
            depth,
            global_height
        );
        if depth >= self.path_len {
            // depth is within the heap-ordered subtree
            let d = depth - self.path_len;
            let start = self.path_len + ((1 << d) - 1);
            let end = self.path_len + ((1 << (d + 1)) - 1);
            (start, end)
        } else {
            // vertical path layer: single node per level
            (depth, depth + 1)
        }
    }
}

/// A memory Merkle tree composed of independent subtrees (one per address space),
/// each built asynchronously and finalized into a top-level Merkle root.
///
/// Layout:
/// - The memory is split across multiple `MemoryMerkleSubTree` instances, one per address space.
/// - The top-level tree is formed by hashing all subtree roots into a single buffer (`top_roots`).
///     - `top_roots` layout: \[root, hash(root_addr_space_1, root_addr_space_2),
///       hash(root_addr_space_3, root_addr_space_4), ...\]
///     - if we have > 4 address spaces, `top_roots` will be extended with the next hash, etc.
///
/// Execution:
/// - Subtrees are built asynchronously on individual CUDA streams.
/// - The final root is computed after all subtrees complete, on a shared stream.
/// - `CudaEvent`s are used to synchronize subtree completion.
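///
/// A usage sketch, mirroring the test at the bottom of this file (`mem_config`, `hasher`, and
/// `device_memories` are assumed to exist; marked `ignore` since it needs a CUDA device):
///
/// ```ignore
/// let mut tree = MemoryMerkleTree::new(mem_config.clone(), hasher);
/// // Launch one asynchronous subtree build per address space (index 0 is skipped internally).
/// for (addr_space, d_mem) in device_memories.iter().enumerate() {
///     tree.build_async(d_mem, addr_space);
/// }
/// // Wait for all subtrees and hash their roots into `top_roots[0]`.
/// tree.finalize();
/// let root = tree.top_roots.to_host().unwrap()[0];
/// ```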
pub struct MemoryMerkleTree {
    pub stream: Arc<CudaStream>,
    pub subtrees: Vec<MemoryMerkleSubTree>,
    pub top_roots: DeviceBuffer<H>,
    zero_hash: DeviceBuffer<H>,
    pub height: usize,
    pub hasher_buffer: SharedBuffer<F>,
    mem_config: MemoryConfig,
}

impl MemoryMerkleTree {
    /// Creates a full Merkle tree with one subtree per address space.
    /// Initializes all buffers and precomputes the zero hash chain.
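    ///
    /// As a worked example (with assumed config values): `pointer_max_bits = 29` and
    /// `DIGEST_WIDTH = 8` give `label_max_bits = 29 - 3 = 26`; with 4 non-empty address spaces
    /// the overall height is `26 + log2(4) = 28`.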
    pub fn new(mem_config: MemoryConfig, hasher_chip: Arc<Poseidon2PeripheryChipGPU>) -> Self {
        let addr_space_sizes = mem_config
            .addr_spaces
            .iter()
            .map(|ashc| {
                assert!(
                    ashc.num_cells % DIGEST_WIDTH == 0,
                    "the number of cells must be divisible by `DIGEST_WIDTH`"
                );
                ashc.num_cells / DIGEST_WIDTH
            })
            .collect::<Vec<_>>();
        assert!(!addr_space_sizes.is_empty(), "Invalid config");

        let num_addr_spaces = addr_space_sizes.len() - ADDR_SPACE_OFFSET as usize;
        assert!(
            num_addr_spaces.is_power_of_two(),
            "Number of address spaces beyond `ADDR_SPACE_OFFSET` must be a power of two"
        );
        for &sz in addr_space_sizes.iter().take(ADDR_SPACE_OFFSET as usize) {
            assert!(
                sz == 0,
                "The first `ADDR_SPACE_OFFSET` address spaces are assumed to be empty"
            );
        }

        let label_max_bits = mem_config.pointer_max_bits - log2_ceil_usize(DIGEST_WIDTH);

        let zero_hash = DeviceBuffer::<H>::with_capacity(label_max_bits + 1);
        let top_roots = DeviceBuffer::<H>::with_capacity(2 * num_addr_spaces - 1);
        unsafe {
            calculate_zero_hash(&zero_hash, label_max_bits).unwrap();
        }

        Self {
            stream: Arc::new(CudaStream::new().unwrap()),
            subtrees: Vec::new(),
            top_roots,
            height: label_max_bits + log2_ceil_usize(num_addr_spaces),
            zero_hash,
            hasher_buffer: hasher_chip.shared_buffer(),
            mem_config,
        }
    }

    pub fn mem_config(&self) -> &MemoryConfig {
        &self.mem_config
    }

    /// Starts asynchronous construction of the specified address space's Merkle subtree.
    /// Uses internal zero hashes and launches kernels on the subtree's own CUDA stream.
    ///
    /// Here `addr_space` is the _unshifted_ address space, so `addr_space = 0` is the immediate
    /// address space, which should be ignored.
    pub fn build_async(&mut self, d_data: &DeviceBuffer<u8>, addr_space: usize) {
        if addr_space < ADDR_SPACE_OFFSET as usize {
            return;
        }
        let addr_space_idx = addr_space - ADDR_SPACE_OFFSET as usize;
        if addr_space < self.mem_config.addr_spaces.len() && addr_space_idx == self.subtrees.len() {
            let mut subtree = MemoryMerkleSubTree::new(
                self.mem_config.addr_spaces[addr_space].num_cells / DIGEST_WIDTH,
                1 << (self.zero_hash.len() - 1), /* label_max_bits */
            );
            subtree.build_async(d_data, addr_space_idx, &self.zero_hash);
            self.subtrees.push(subtree);
        } else {
            panic!("Invalid address space index");
        }
    }

    /// Finalizes the Merkle tree by collecting all subtree roots and computing the final root.
    /// Waits for all subtrees to complete and then performs the final hash operation.
    pub fn finalize(&self) {
        for subtree in self.subtrees.iter() {
            self.stream.wait(subtree.event.as_ref().unwrap()).unwrap();
        }

        let we_can_gather_bufs_event = CudaEvent::new().unwrap();
        unsafe {
            we_can_gather_bufs_event
                .record(self.stream.as_raw())
                .unwrap();
        }
        default_stream_wait(&we_can_gather_bufs_event).unwrap();

        let roots: Vec<usize> = self
            .subtrees
            .iter()
            .map(|subtree| subtree.buf.as_ptr() as usize)
            .collect();
        let d_roots = roots.to_device().unwrap();
        let to_device_event = CudaEvent::new().unwrap();
        unsafe {
            to_device_event.record(cudaStreamPerThread).unwrap();
        }
        self.stream.wait(&to_device_event).unwrap();

        unsafe {
            finalize_merkle_tree(
                &d_roots,
                &self.top_roots,
                self.subtrees.len(),
                self.stream.as_raw(),
            )
            .unwrap();
        }

        self.stream.synchronize().unwrap();
    }

    /// Drops all massive buffers to free memory. Used at the end of an execution segment.
    pub fn drop_subtrees(&mut self) {
        self.subtrees = Vec::new();
    }

    /// Updates the tree and returns the Merkle trace.
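    ///
    /// Each touched block in `d_touched_blocks` occupies `TIMESTAMPED_BLOCK_WIDTH = 11` `u32`
    /// words: the address space, the label, the timestamp, and `DIGEST_WIDTH = 8` data values
    /// (see the parameter comment below).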
    pub fn update_with_touched_blocks(
        &self,
        unpadded_height: usize,
        d_touched_blocks: &DeviceBuffer<u32>, // consists of (as, label, ts, [F; 8])
        empty_touched_blocks: bool,
    ) -> AirProvingContext<GpuBackend> {
        let mut public_values = self.top_roots.to_host().unwrap()[0].to_vec();
        let merkle_trace = {
            let width = MemoryMerkleCols::<u8, DIGEST_WIDTH>::width();
            let padded_height = next_power_of_two_or_zero(unpadded_height);
            let output = DeviceMatrix::<F>::with_capacity(padded_height, width);
            output.buffer().fill_zero().unwrap();

            let actual_heights = self.subtrees.iter().map(|s| s.height).collect::<Vec<_>>();
            let subtrees_pointers = self
                .subtrees
                .iter()
                .map(|st| st.buf.as_ptr() as usize)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap();
            unsafe {
                update_merkle_tree(
                    &output,
                    &subtrees_pointers,
                    &self.top_roots,
                    &self.zero_hash,
                    d_touched_blocks,
                    self.height - log2_ceil_usize(self.subtrees.len()),
                    &actual_heights,
                    unpadded_height,
                    &self.hasher_buffer,
                )
                .unwrap();
            }

            if empty_touched_blocks {
                // The trace is small in this case
                let mut output_vec = output.to_host().unwrap();
                output_vec[unpadded_height - 1 + (width - 2) * padded_height] = F::ONE; // left_direction_different
                output_vec[unpadded_height - 1 + (width - 1) * padded_height] = F::ONE; // right_direction_different
                DeviceMatrix::new(
                    Arc::new(output_vec.to_device().unwrap()),
                    padded_height,
                    width,
                )
            } else {
                output
            }
        };
        public_values.extend(self.top_roots.to_host().unwrap()[0].to_vec());

        AirProvingContext::new(Vec::new(), Some(merkle_trace), public_values)
    }

    /// An auxiliary function to calculate the required number of rows for the Merkle trace.
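    ///
    /// The result is `2 * (tree_height + Σ ilog2(x_i ^ x_{i+1}))` over consecutive touched leaf
    /// indices, i.e. each adjacent pair contributes the height at which their paths to the root
    /// diverge. A worked instance with assumed numbers: `tree_height = 26` and two touched
    /// blocks whose leaf indices first differ at bit 4 give `2 * (26 + 4) = 60` rows.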
    pub fn calculate_unpadded_height(
        &self,
        touched_memory: &TimestampedEquipartition<F, DIGEST_WIDTH>,
    ) -> usize {
        let md = self.mem_config.memory_dimensions();
        let tree_height = md.overall_height();
        let shift_address = |(sp, ptr): (u32, u32)| (sp, ptr / DIGEST_WIDTH as u32);
        2 * if touched_memory.is_empty() {
            tree_height
        } else {
            tree_height
                + (0..(touched_memory.len() - 1))
                    .into_par_iter()
                    .map(|i| {
                        let x = md.label_to_index(shift_address(touched_memory[i].0));
                        let y = md.label_to_index(shift_address(touched_memory[i + 1].0));
                        (x ^ y).ilog2() as usize
                    })
                    .sum::<usize>()
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use openvm_circuit::{
        arch::{
            testing::POSEIDON2_DIRECT_BUS, vm_poseidon2_config, AddressSpaceHostLayout,
            MemoryCellType, MemoryConfig,
        },
        system::{
            memory::{
                merkle::MerkleTree,
                online::{GuestMemory, LinearMemory},
                AddressMap, TimestampedValues,
            },
            poseidon2::Poseidon2PeripheryChip,
        },
    };
    use openvm_cuda_backend::prelude::F;
    use openvm_cuda_common::{
        copy::{MemCopyD2H, MemCopyH2D},
        d_buffer::DeviceBuffer,
    };
    use openvm_instructions::{
        riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
        NATIVE_AS,
    };
    use openvm_stark_sdk::utils::create_seeded_rng;
    use p3_field::{FieldAlgebra, PrimeField32};
    use rand::Rng;

    use super::MemoryMerkleTree;
    use crate::system::cuda::{Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

    #[test]
    fn test_cuda_merkle_tree_cpu_gpu_root_equivalence() {
        let mut rng = create_seeded_rng();
        let mem_config = {
            let mut addr_spaces = MemoryConfig::empty_address_space_configs(5);
            let max_cells = 1 << 16;
            addr_spaces[RV32_REGISTER_AS as usize].num_cells = 32 * size_of::<u32>();
            addr_spaces[RV32_MEMORY_AS as usize].num_cells = max_cells;
            addr_spaces[NATIVE_AS as usize].num_cells = max_cells;
            MemoryConfig::new(2, addr_spaces, max_cells.ilog2() as usize, 29, 17, 32)
        };

        let mut initial_memory = GuestMemory::new(AddressMap::from_mem_config(&mem_config));
        for (idx, space) in mem_config.addr_spaces.iter().enumerate() {
            unsafe {
                match space.layout {
                    MemoryCellType::Null => {}
                    MemoryCellType::U8 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u8, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u8],
                            );
                        }
                    }
                    MemoryCellType::U16 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u16, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u16],
                            );
                        }
                    }
                    MemoryCellType::U32 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u32, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u32],
                            );
                        }
                    }
                    MemoryCellType::Native { .. } => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<F, 1>(
                                idx as u32,
                                i as u32,
                                [F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))],
                            );
                        }
                    }
                }
            }
        }

        let gpu_hasher_chip = Arc::new(Poseidon2PeripheryChipGPU::new(
            (mem_config
                .addr_spaces
                .iter()
                .map(|ashc| ashc.num_cells * 2 + mem_config.memory_dimensions().overall_height())
                .sum::<usize>()
                * 2)
            .next_power_of_two()
                * 2
                * DIGEST_WIDTH, // max_buffer_size
            1, // sbox_regs
        ));
        let mut gpu_merkle_tree = MemoryMerkleTree::new(mem_config.clone(), gpu_hasher_chip);
        for (i, mem) in initial_memory.memory.get_memory().iter().enumerate() {
            let mem_slice = mem.as_slice();
            gpu_merkle_tree.build_async(
                &(if !mem_slice.is_empty() {
                    mem_slice.to_device().unwrap()
                } else {
                    DeviceBuffer::new()
                }),
                i,
            );
        }
        gpu_merkle_tree.finalize();

        let cpu_hasher_chip =
            Poseidon2PeripheryChip::new(vm_poseidon2_config(), POSEIDON2_DIRECT_BUS, 3);
        let mut cpu_merkle_tree = MerkleTree::<F, DIGEST_WIDTH>::from_memory(
            &initial_memory.memory,
            &mem_config.memory_dimensions(),
            &cpu_hasher_chip,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);

        // Now we add some touched memory.
        // We don't care about the memory layout here, because neither implementation relies on
        // any special form of the touched blocks.
        let touched_ptrs = mem_config
            .addr_spaces
            .iter()
            .enumerate()
            .flat_map(|(i, cnf)| {
                let mut ptrs = Vec::new();
                for j in 0..(cnf.num_cells / DIGEST_WIDTH) {
                    if rng.gen_bool(0.333) {
                        ptrs.push((i as u32, (j * DIGEST_WIDTH) as u32));
                    }
                }
                ptrs
            })
            .collect::<Vec<_>>();
        let new_data = touched_ptrs
            .iter()
            .map(|_| std::array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))))
            .collect::<Vec<[F; DIGEST_WIDTH]>>();
        assert!(!touched_ptrs.is_empty());
        cpu_merkle_tree.finalize(
            &cpu_hasher_chip,
            &(touched_ptrs
                .iter()
                .copied()
                .zip(new_data.iter().copied())
                .collect()),
            &mem_config.memory_dimensions(),
        );
        let touched_blocks = touched_ptrs
            .into_iter()
            .zip(new_data)
            .map(|(address, data)| {
                (
                    address,
                    TimestampedValues {
                        timestamp: rng.gen_range(0..(1u32 << mem_config.timestamp_max_bits)),
                        values: data,
                    },
                )
            })
            .collect::<Vec<_>>();
        let d_touched_blocks = touched_blocks.to_device().unwrap().as_buffer::<u32>();

        gpu_merkle_tree.update_with_touched_blocks(
            gpu_merkle_tree.calculate_unpadded_height(&touched_blocks),
            &d_touched_blocks,
            false,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);
    }
}