use std::{ffi::c_void, sync::Arc};

use openvm_circuit::{
    arch::{MemoryConfig, ADDR_SPACE_OFFSET},
    system::memory::{merkle::MemoryMerkleCols, TimestampedEquipartition},
    utils::next_power_of_two_or_zero,
};
use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
use openvm_cuda_common::{
    copy::{cuda_memcpy, MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
    stream::{
        cudaStreamPerThread, current_stream_sync, default_stream_wait, CudaEvent, CudaStream,
    },
};
use openvm_stark_backend::{
    p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator},
    p3_util::log2_ceil_usize,
    prover::types::AirProvingContext,
};
use p3_field::FieldAlgebra;

use super::{poseidon2::SharedBuffer, Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

pub mod cuda;
use cuda::merkle_tree::*;

type H = [F; DIGEST_WIDTH];
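/// Number of u32 words in one flattened touched-block record:
/// 2 address words (address space, pointer) + 1 timestamp word + `DIGEST_WIDTH`
/// value words.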
pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;

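/// Merkle subtree for a single address space, built on its own CUDA stream.
///
/// `buf` holds `path_len` hashes that lift the subtree root up to the full tree
/// height, followed by the `2^(height + 1) - 1` nodes of the subtree itself.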
pub struct MemoryMerkleSubTree {
    pub stream: Arc<CudaStream>,
    #[allow(dead_code)]
    created_buffer_event: Option<CudaEvent>,
    build_completion_event: Option<CudaEvent>,
    pub buf: DeviceBuffer<H>,
    pub height: usize,
    pub path_len: usize,
}

impl MemoryMerkleSubTree {
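    /// Allocates a subtree over `addr_space_size` digest-width blocks, padded with
    /// `path_len` extra levels so its root sits at the height of a `max_size`-leaf
    /// tree. For an empty address space this returns a dummy subtree of the right
    /// height whose buffer is filled lazily in `build_async`.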
    pub fn new(addr_space_size: usize, max_size: usize) -> Self {
        assert!(
            max_size.is_power_of_two(),
            "Max address space size must be a power of two"
        );
        let size = next_power_of_two_or_zero(addr_space_size);
        if addr_space_size == 0 {
            let mut res = MemoryMerkleSubTree::dummy();
            res.height = log2_ceil_usize(max_size);
            return res;
        }
        let height = log2_ceil_usize(size);
        let path_len = log2_ceil_usize(max_size).checked_sub(height).unwrap();
        tracing::debug!(
            "Creating a subtree buffer of size {} (addr space size {})",
            path_len + (2 * size - 1),
            addr_space_size
        );
        let buf = DeviceBuffer::<H>::with_capacity(path_len + (2 * size - 1));

        let created_buffer_event = CudaEvent::new().unwrap();
        unsafe {
            created_buffer_event.record(cudaStreamPerThread).unwrap();
        }

        let stream = Arc::new(CudaStream::new().unwrap());
        stream.wait(&created_buffer_event).unwrap();
        Self {
            stream,
            created_buffer_event: Some(created_buffer_event),
            build_completion_event: None,
            height,
            buf,
            path_len,
        }
    }

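    /// Creates an empty placeholder subtree (used for empty address spaces); its
    /// buffer is later replaced by a single zero-hash node in `build_async`.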
    pub fn dummy() -> Self {
        Self {
            stream: Arc::new(CudaStream::new().unwrap()),
            created_buffer_event: None,
            build_completion_event: None,
            height: 0,
            buf: DeviceBuffer::new(),
            path_len: 0,
        }
    }

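    /// Launches the subtree build on this subtree's stream without blocking the host.
    /// For an empty subtree the root is just the precomputed zero hash at `self.height`;
    /// otherwise a kernel hashes `d_data` into `2^height` leaves, builds the subtree,
    /// and (if `path_len > 0`) extends the root along the zero-padded path levels.
    /// Completion is tracked by `build_completion_event`.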
    pub fn build_async(
        &mut self,
        d_data: &DeviceBuffer<u8>,
        addr_space_idx: usize,
        zero_hash: &DeviceBuffer<H>,
    ) {
        let event = CudaEvent::new().unwrap();
        if self.buf.is_empty() {
            self.buf = DeviceBuffer::with_capacity(1);
            unsafe {
                cuda_memcpy::<true, true>(
                    self.buf.as_mut_raw_ptr(),
                    zero_hash.as_ptr().add(self.height) as *mut c_void,
                    size_of::<H>(),
                )
                .unwrap();
                event.record(cudaStreamPerThread).unwrap();
            }
        } else {
            unsafe {
                build_merkle_subtree(
                    d_data,
                    1 << self.height,
                    &self.buf,
                    self.path_len,
                    addr_space_idx as u32,
                    self.stream.as_raw(),
                )
                .unwrap();

                if self.path_len > 0 {
                    restore_merkle_subtree_path(
                        &self.buf,
                        zero_hash,
                        self.path_len,
                        self.height + self.path_len,
                        self.stream.as_raw(),
                    )
                    .unwrap();
                }
                event.record(self.stream.as_raw()).unwrap();
            }
        }
        self.build_completion_event = Some(event);
    }

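    /// Returns the `[start, end)` index range in `buf` of the nodes at `depth`,
    /// measured from the top of the lifted (full-height) tree. Depths inside the
    /// padded path hold a single node each; deeper levels hold
    /// `2^(depth - path_len)` nodes.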
    pub fn layer_bounds(&self, depth: usize) -> (usize, usize) {
        let global_height = self.height + self.path_len;
        assert!(
            depth < global_height,
            "Depth {depth} out of bounds for height {global_height}",
        );
        if depth >= self.path_len {
            let d = depth - self.path_len;
            let start = self.path_len + ((1 << d) - 1);
            let end = self.path_len + ((1 << (d + 1)) - 1);
            (start, end)
        } else {
            (depth, depth + 1)
        }
    }
}

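/// GPU-backed memory Merkle tree: one [`MemoryMerkleSubTree`] per address space,
/// whose roots are combined into `top_roots` (a small binary tree of
/// `2 * num_addr_spaces - 1` nodes, with the overall root at index 0).
///
/// Typical flow (a sketch; construction of the device buffers is elided, and the
/// loop variable names are illustrative):
///
/// ```ignore
/// let mut tree = MemoryMerkleTree::new(mem_config, hasher_chip);
/// for (addr_space, d_data) in per_addr_space_device_buffers {
///     tree.build_async(&d_data, addr_space); // enqueued on per-subtree streams
/// }
/// tree.finalize(); // waits for subtree builds, combines roots into `top_roots`
/// let ctx = tree.update_with_touched_blocks(unpadded_height, &d_touched_blocks, false);
/// ```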
pub struct MemoryMerkleTree {
    pub subtrees: Vec<MemoryMerkleSubTree>,
    pub top_roots: DeviceBuffer<H>,
    zero_hash: DeviceBuffer<H>,
    pub height: usize,
    pub hasher_buffer: SharedBuffer<F>,
    mem_config: MemoryConfig,
    pub(crate) top_roots_host: Vec<H>,
}

impl MemoryMerkleTree {
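    /// Validates the memory config, precomputes the zero-hash ladder on device
    /// (`zero_hash[i]` being the root of an empty subtree of height `i`), and
    /// allocates the `top_roots` buffer. Subtrees are created later, one per
    /// `build_async` call.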
    pub fn new(mem_config: MemoryConfig, hasher_chip: Arc<Poseidon2PeripheryChipGPU>) -> Self {
        let addr_space_sizes = mem_config
            .addr_spaces
            .iter()
            .map(|ashc| {
                assert!(
                    ashc.num_cells % DIGEST_WIDTH == 0,
                    "the number of cells must be divisible by `DIGEST_WIDTH`"
                );
                ashc.num_cells / DIGEST_WIDTH
            })
            .collect::<Vec<_>>();
        assert!(
            !addr_space_sizes.is_empty(),
            "memory config must define at least one address space"
        );

        let num_addr_spaces = addr_space_sizes.len() - ADDR_SPACE_OFFSET as usize;
        assert!(
            num_addr_spaces.is_power_of_two(),
            "Number of address spaces beyond the first `ADDR_SPACE_OFFSET` must be a power of two"
        );
        for &sz in addr_space_sizes.iter().take(ADDR_SPACE_OFFSET as usize) {
            assert!(
                sz == 0,
                "The first `ADDR_SPACE_OFFSET` address spaces are assumed to be empty"
            );
        }

        let label_max_bits = mem_config.pointer_max_bits - log2_ceil_usize(DIGEST_WIDTH);

        let zero_hash = DeviceBuffer::<H>::with_capacity(label_max_bits + 1);
        let top_roots = DeviceBuffer::<H>::with_capacity(2 * num_addr_spaces - 1);
        unsafe {
            calculate_zero_hash(&zero_hash, label_max_bits).unwrap();
        }

        Self {
            subtrees: Vec::new(),
            top_roots,
            height: label_max_bits + log2_ceil_usize(num_addr_spaces),
            zero_hash,
            hasher_buffer: hasher_chip.shared_buffer(),
            mem_config,
            top_roots_host: vec![],
        }
    }

    pub fn mem_config(&self) -> &MemoryConfig {
        &self.mem_config
    }

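    /// Creates the subtree for `addr_space` and starts building it asynchronously
    /// on the subtree's own stream. Must be called once per address space, in
    /// increasing order starting from `ADDR_SPACE_OFFSET`; calls for spaces below
    /// the offset are no-ops.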
    pub fn build_async(&mut self, d_data: &DeviceBuffer<u8>, addr_space: usize) {
        if addr_space < ADDR_SPACE_OFFSET as usize {
            return;
        }
        let addr_space_idx = addr_space - ADDR_SPACE_OFFSET as usize;
        if addr_space < self.mem_config.addr_spaces.len() && addr_space_idx == self.subtrees.len() {
            let mut subtree = MemoryMerkleSubTree::new(
                self.mem_config.addr_spaces[addr_space].num_cells / DIGEST_WIDTH,
                1 << (self.zero_hash.len() - 1),
            );
            subtree.build_async(d_data, addr_space_idx, &self.zero_hash);
            self.subtrees.push(subtree);
        } else {
            panic!("Invalid address space index");
        }
    }

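    /// Makes the per-thread default stream wait on every subtree's completion
    /// event, then launches a kernel that combines the subtree roots into
    /// `top_roots` (overall root at index 0).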
    pub fn finalize(&mut self) {
        for subtree in self.subtrees.iter() {
            default_stream_wait(
                subtree
                    .build_completion_event
                    .as_ref()
                    .expect("Subtree build event does not exist"),
            )
            .unwrap();
        }

        let roots: Vec<usize> = self
            .subtrees
            .iter()
            .map(|subtree| subtree.buf.as_ptr() as usize)
            .collect();
        let d_roots = roots.to_device().unwrap();

        unsafe {
            finalize_merkle_tree(
                &d_roots,
                &self.top_roots,
                self.subtrees.len(),
                cudaStreamPerThread,
            )
            .unwrap();
        }
    }

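    /// Synchronizes every subtree stream (and the current stream) before freeing
    /// the subtree buffers, so no kernel is still reading them.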
    pub fn drop_subtrees(&mut self) {
        for subtree in self.subtrees.iter() {
            subtree.stream.synchronize().unwrap();
            if subtree.build_completion_event.is_some() {
                tracing::warn!(
                    "Dropping merkle subtree while its build_async completion event is still alive"
                );
            }
        }
        current_stream_sync().unwrap();
        self.subtrees.clear();
    }

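    /// Applies the touched blocks to the tree on device, producing the
    /// `MemoryMerkleCols` trace and the public values `[initial_root, final_root]`.
    /// If `empty_touched_blocks` is set, the last used trace row is patched on the
    /// host so that its final two columns are one.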
    pub fn update_with_touched_blocks(
        &mut self,
        unpadded_height: usize,
        d_touched_blocks: &DeviceBuffer<u32>,
        empty_touched_blocks: bool,
    ) -> AirProvingContext<GpuBackend> {
        let mut public_values = self.top_roots.to_host().unwrap()[0].to_vec();
        for subtree in &mut self.subtrees {
            subtree.build_completion_event = None;
        }
        let merkle_trace = {
            let width = MemoryMerkleCols::<u8, DIGEST_WIDTH>::width();
            let padded_height = next_power_of_two_or_zero(unpadded_height);
            let output = DeviceMatrix::<F>::with_capacity(padded_height, width);
            output.buffer().fill_zero().unwrap();

            let actual_heights = self.subtrees.iter().map(|s| s.height).collect::<Vec<_>>();
            let subtrees_pointers = self
                .subtrees
                .iter()
                .map(|st| st.buf.as_ptr() as usize)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap();
            unsafe {
                update_merkle_tree(
                    &output,
                    &subtrees_pointers,
                    &self.top_roots,
                    &self.zero_hash,
                    d_touched_blocks,
                    self.height - log2_ceil_usize(self.subtrees.len()),
                    &actual_heights,
                    unpadded_height,
                    &self.hasher_buffer,
                )
                .unwrap();
            }

            if empty_touched_blocks {
                // The trace is stored column-major: patch the last two columns of
                // the last used row on host, then copy the matrix back to device.
                let mut output_vec = output.to_host().unwrap();
                output_vec[unpadded_height - 1 + (width - 2) * padded_height] = F::ONE;
                output_vec[unpadded_height - 1 + (width - 1) * padded_height] = F::ONE;
                DeviceMatrix::new(
                    Arc::new(output_vec.to_device().unwrap()),
                    padded_height,
                    width,
                )
            } else {
                output
            }
        };
        self.top_roots_host = self.top_roots.to_host().unwrap();
        public_values.extend(self.top_roots_host[0]);

        AirProvingContext::new(Vec::new(), Some(merkle_trace), public_values)
    }

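    /// Number of trace rows the update needs: the first touched leaf contributes a
    /// full root-to-leaf path of `tree_height` nodes, and each subsequent leaf adds
    /// only the nodes below its lowest common ancestor with the previous leaf
    /// (`(x ^ y).ilog2()` on the leaf indices); every node yields two rows.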
    pub fn calculate_unpadded_height(
        &self,
        touched_memory: &TimestampedEquipartition<F, DIGEST_WIDTH>,
    ) -> usize {
        let md = self.mem_config.memory_dimensions();
        let tree_height = md.overall_height();
        let shift_address = |(sp, ptr): (u32, u32)| (sp, ptr / DIGEST_WIDTH as u32);
        2 * if touched_memory.is_empty() {
            tree_height
        } else {
            tree_height
                + (0..(touched_memory.len() - 1))
                    .into_par_iter()
                    .map(|i| {
                        let x = md.label_to_index(shift_address(touched_memory[i].0));
                        let y = md.label_to_index(shift_address(touched_memory[i + 1].0));
                        (x ^ y).ilog2() as usize
                    })
                    .sum::<usize>()
        }
    }
}

impl Drop for MemoryMerkleTree {
    fn drop(&mut self) {
        self.drop_subtrees();
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use openvm_circuit::{
        arch::{
            testing::POSEIDON2_DIRECT_BUS, vm_poseidon2_config, AddressSpaceHostLayout,
            MemoryCellType, MemoryConfig,
        },
        system::{
            memory::{
                merkle::MerkleTree,
                online::{GuestMemory, LinearMemory},
                AddressMap, TimestampedValues,
            },
            poseidon2::Poseidon2PeripheryChip,
        },
    };
    use openvm_cuda_backend::prelude::F;
    use openvm_cuda_common::{
        copy::{MemCopyD2H, MemCopyH2D},
        d_buffer::DeviceBuffer,
    };
    use openvm_instructions::{
        riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
        NATIVE_AS,
    };
    use openvm_stark_sdk::utils::create_seeded_rng;
    use p3_field::{FieldAlgebra, PrimeField32};
    use rand::Rng;

    use super::MemoryMerkleTree;
    use crate::system::cuda::{Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

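    /// Builds identical initial memory into the CPU `MerkleTree` and the GPU
    /// `MemoryMerkleTree`, asserts the roots match, then applies the same random
    /// set of touched blocks to both and asserts the updated roots match again.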
    #[test]
    fn test_cuda_merkle_tree_cpu_gpu_root_equivalence() {
        let mut rng = create_seeded_rng();
        let mem_config = {
            let mut addr_spaces = MemoryConfig::empty_address_space_configs(5);
            let max_cells = 1 << 16;
            addr_spaces[RV32_REGISTER_AS as usize].num_cells = 32 * size_of::<u32>();
            addr_spaces[RV32_MEMORY_AS as usize].num_cells = max_cells;
            addr_spaces[NATIVE_AS as usize].num_cells = max_cells;
            MemoryConfig::new(2, addr_spaces, max_cells.ilog2() as usize, 29, 17, 32)
        };

        let mut initial_memory = GuestMemory::new(AddressMap::from_mem_config(&mem_config));
        for (idx, space) in mem_config.addr_spaces.iter().enumerate() {
            unsafe {
                match space.layout {
                    MemoryCellType::Null => {}
                    MemoryCellType::U8 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u8, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u8],
                            );
                        }
                    }
                    MemoryCellType::U16 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u16, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u16],
                            );
                        }
                    }
                    MemoryCellType::U32 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u32, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u32],
                            );
                        }
                    }
                    MemoryCellType::Native { .. } => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<F, 1>(
                                idx as u32,
                                i as u32,
                                [F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))],
                            );
                        }
                    }
                }
            }
        }

        let gpu_hasher_chip = Arc::new(Poseidon2PeripheryChipGPU::new(
            (mem_config
                .addr_spaces
                .iter()
                .map(|ashc| ashc.num_cells * 2 + mem_config.memory_dimensions().overall_height())
                .sum::<usize>()
                * 2)
            .next_power_of_two()
                * 2
                * DIGEST_WIDTH,
            1,
        ));
        let mut gpu_merkle_tree = MemoryMerkleTree::new(mem_config.clone(), gpu_hasher_chip);
        for (i, mem) in initial_memory.memory.get_memory().iter().enumerate() {
            let mem_slice = mem.as_slice();
            gpu_merkle_tree.build_async(
                &(if !mem_slice.is_empty() {
                    mem_slice.to_device().unwrap()
                } else {
                    DeviceBuffer::new()
                }),
                i,
            );
        }
        gpu_merkle_tree.finalize();

        let cpu_hasher_chip =
            Poseidon2PeripheryChip::new(vm_poseidon2_config(), POSEIDON2_DIRECT_BUS, 3);
        let mut cpu_merkle_tree = MerkleTree::<F, DIGEST_WIDTH>::from_memory(
            &initial_memory.memory,
            &mem_config.memory_dimensions(),
            &cpu_hasher_chip,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);

        // Touch roughly a third of the digest-aligned blocks in every address space.
        let touched_ptrs = mem_config
            .addr_spaces
            .iter()
            .enumerate()
            .flat_map(|(i, cnf)| {
                let mut ptrs = Vec::new();
                for j in 0..(cnf.num_cells / DIGEST_WIDTH) {
                    if rng.gen_bool(0.333) {
                        ptrs.push((i as u32, (j * DIGEST_WIDTH) as u32));
                    }
                }
                ptrs
            })
            .collect::<Vec<_>>();
        let new_data = touched_ptrs
            .iter()
            .map(|_| std::array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))))
            .collect::<Vec<[F; DIGEST_WIDTH]>>();
        assert!(!touched_ptrs.is_empty());
        cpu_merkle_tree.finalize(
            &cpu_hasher_chip,
            &(touched_ptrs
                .iter()
                .copied()
                .zip(new_data.iter().copied())
                .collect()),
            &mem_config.memory_dimensions(),
        );
        let touched_blocks = touched_ptrs
            .into_iter()
            .zip(new_data)
            .map(|(address, data)| {
                (
                    address,
                    TimestampedValues {
                        timestamp: rng.gen_range(0..(1u32 << mem_config.timestamp_max_bits)),
                        values: data,
                    },
                )
            })
            .collect::<Vec<_>>();
        let d_touched_blocks = touched_blocks.to_device().unwrap().as_buffer::<u32>();

        gpu_merkle_tree.update_with_touched_blocks(
            gpu_merkle_tree.calculate_unpadded_height(&touched_blocks),
            &d_touched_blocks,
            false,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);
    }
}