use std::{ffi::c_void, sync::Arc};

use openvm_circuit::{
    arch::{MemoryConfig, ADDR_SPACE_OFFSET},
    system::memory::{merkle::MemoryMerkleCols, TimestampedEquipartition},
    utils::next_power_of_two_or_zero,
};
use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
use openvm_cuda_common::{
    copy::{cuda_memcpy, MemCopyD2H, MemCopyH2D},
    d_buffer::DeviceBuffer,
    stream::{cudaStreamPerThread, default_stream_wait, CudaEvent, CudaStream},
};
use openvm_stark_backend::{
    p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator},
    p3_util::log2_ceil_usize,
    prover::types::AirProvingContext,
};
use p3_field::FieldAlgebra;

use super::{poseidon2::SharedBuffer, Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

pub mod cuda;
use cuda::merkle_tree::*;

type H = [F; DIGEST_WIDTH];
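// One touched block serialized as u32 words: an (address space, pointer)
// pair, a timestamp, and `DIGEST_WIDTH` field values; the value 11 assumes
// `DIGEST_WIDTH == 8` (2 + 1 + 8).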
pub const TIMESTAMPED_BLOCK_WIDTH: usize = 11;

/// Merkle subtree covering a single address space, built on its own CUDA
/// stream. `buf` holds `path_len` ancestors of the subtree root (hashed
/// against zero subtrees up to the maximal address-space height), followed by
/// the subtree's `2 * size - 1` nodes in breadth-first order.
pub struct MemoryMerkleSubTree {
    pub stream: Arc<CudaStream>,
    pub event: Option<CudaEvent>,
    pub buf: DeviceBuffer<H>,
    pub height: usize,
    pub path_len: usize,
}

impl MemoryMerkleSubTree {
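    /// Creates a subtree for one address space. `addr_space_size` is the
    /// number of `DIGEST_WIDTH`-cell blocks in that address space (0 yields a
    /// dummy subtree of the correct height); `max_size` is the maximal leaf
    /// count any address space may have, which fixes the global height and
    /// hence `path_len`.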
    pub fn new(addr_space_size: usize, max_size: usize) -> Self {
        assert!(
            max_size.is_power_of_two(),
            "Max address space size must be a power of two"
        );
        let size = next_power_of_two_or_zero(addr_space_size);
        if addr_space_size == 0 {
            let mut res = MemoryMerkleSubTree::dummy();
            res.height = log2_ceil_usize(max_size);
            return res;
        }
        let height = log2_ceil_usize(size);
        let path_len = log2_ceil_usize(max_size).checked_sub(height).unwrap();
        tracing::debug!(
            "Creating a subtree buffer, size is {} (addr space size is {})",
            path_len + (2 * size - 1),
            addr_space_size
        );
        let buf = DeviceBuffer::<H>::with_capacity(path_len + (2 * size - 1));

        let created_buffer_event = CudaEvent::new().unwrap();
        unsafe {
            created_buffer_event.record(cudaStreamPerThread).unwrap();
        }

        let stream = Arc::new(CudaStream::new().unwrap());
        stream.wait(&created_buffer_event).unwrap();
        Self {
            stream,
            event: None,
            height,
            buf,
            path_len,
        }
    }

    pub fn dummy() -> Self {
        Self {
            stream: Arc::new(CudaStream::new().unwrap()),
            event: None,
            height: 0,
            buf: DeviceBuffer::new(),
            path_len: 0,
        }
    }

    pub fn build_async(
        &mut self,
        d_data: &DeviceBuffer<u8>,
        addr_space_idx: usize,
        zero_hash: &DeviceBuffer<H>,
    ) {
        let event = CudaEvent::new().unwrap();
        if self.buf.is_empty() {
            // Empty address space: the subtree collapses to the zero hash at
            // `self.height`.
            self.buf = DeviceBuffer::with_capacity(1);
            unsafe {
                cuda_memcpy::<true, true>(
                    self.buf.as_mut_raw_ptr(),
                    zero_hash.as_ptr().add(self.height) as *mut c_void,
                    size_of::<H>(),
                )
                .unwrap();
                event.record(cudaStreamPerThread).unwrap();
            }
        } else {
            unsafe {
                build_merkle_subtree(
                    d_data,
                    1 << self.height,
                    &self.buf,
                    self.path_len,
                    addr_space_idx as u32,
                    self.stream.as_raw(),
                )
                .unwrap();

                if self.path_len > 0 {
                    restore_merkle_subtree_path(
                        &self.buf,
                        zero_hash,
                        self.path_len,
                        self.height + self.path_len,
                        self.stream.as_raw(),
                    )
                    .unwrap();
                }
                event.record(self.stream.as_raw()).unwrap();
            }
        }
        self.event = Some(event);
    }

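    /// Returns the half-open range of indices in `buf` occupied by tree layer
    /// `depth` (0 is the topmost path node). A worked example with
    /// hypothetical sizes: for `path_len = 3` and `height = 2`,
    /// `layer_bounds(0) == (0, 1)` is the topmost path node,
    /// `layer_bounds(3) == (3, 4)` is the subtree root, and
    /// `layer_bounds(4) == (4, 6)` is the two-node layer below it.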
    pub fn layer_bounds(&self, depth: usize) -> (usize, usize) {
        let global_height = self.height + self.path_len;
        assert!(
            depth < global_height,
            "Depth {} out of bounds for height {}",
            depth,
            global_height
        );
        if depth >= self.path_len {
            // Inside the dense subtree: layer `d` occupies a breadth-first range.
            let d = depth - self.path_len;
            let start = self.path_len + ((1 << d) - 1);
            let end = self.path_len + ((1 << (d + 1)) - 1);
            (start, end)
        } else {
            // On the single-node path above the subtree.
            (depth, depth + 1)
        }
    }
}

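/// Device-side memory Merkle tree: one [`MemoryMerkleSubTree`] per address
/// space, with the subtree roots combined into `top_roots`. A sketch of the
/// intended lifecycle, assuming the caller has staged each address space's
/// data in a `DeviceBuffer` (the `per_addr_space_buffers` iterator below is
/// hypothetical):
///
/// ```ignore
/// let mut tree = MemoryMerkleTree::new(mem_config, hasher_chip);
/// for (addr_space, d_data) in per_addr_space_buffers {
///     tree.build_async(&d_data, addr_space);
/// }
/// tree.finalize();
/// let root = tree.top_roots.to_host().unwrap()[0];
/// ```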
pub struct MemoryMerkleTree {
    pub stream: Arc<CudaStream>,
    pub subtrees: Vec<MemoryMerkleSubTree>,
    pub top_roots: DeviceBuffer<H>,
    zero_hash: DeviceBuffer<H>,
    pub height: usize,
    pub hasher_buffer: SharedBuffer<F>,
    mem_config: MemoryConfig,
}

impl MemoryMerkleTree {
    pub fn new(mem_config: MemoryConfig, hasher_chip: Arc<Poseidon2PeripheryChipGPU>) -> Self {
        let addr_space_sizes = mem_config
            .addr_spaces
            .iter()
            .map(|ashc| {
                assert!(
                    ashc.num_cells % DIGEST_WIDTH == 0,
                    "the number of cells must be divisible by `DIGEST_WIDTH`"
                );
                ashc.num_cells / DIGEST_WIDTH
            })
            .collect::<Vec<_>>();
        assert!(!addr_space_sizes.is_empty(), "Invalid config");

        let num_addr_spaces = addr_space_sizes.len() - ADDR_SPACE_OFFSET as usize;
        assert!(
            num_addr_spaces.is_power_of_two(),
            "Number of address spaces beyond `ADDR_SPACE_OFFSET` must be a power of two"
        );
        for &sz in addr_space_sizes.iter().take(ADDR_SPACE_OFFSET as usize) {
            assert!(
                sz == 0,
                "The first `ADDR_SPACE_OFFSET` address spaces are assumed to be empty"
            );
        }

        let label_max_bits = mem_config.pointer_max_bits - log2_ceil_usize(DIGEST_WIDTH);

        let zero_hash = DeviceBuffer::<H>::with_capacity(label_max_bits + 1);
        let top_roots = DeviceBuffer::<H>::with_capacity(2 * num_addr_spaces - 1);
        unsafe {
            calculate_zero_hash(&zero_hash, label_max_bits).unwrap();
        }

        Self {
            stream: Arc::new(CudaStream::new().unwrap()),
            subtrees: Vec::new(),
            top_roots,
            height: label_max_bits + log2_ceil_usize(num_addr_spaces),
            zero_hash,
            hasher_buffer: hasher_chip.shared_buffer(),
            mem_config,
        }
    }

    pub fn mem_config(&self) -> &MemoryConfig {
        &self.mem_config
    }

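    /// Builds the Merkle subtree for `addr_space` asynchronously on that
    /// subtree's own stream. Subtrees must be built in increasing
    /// address-space order, one call per address space; `finalize` later
    /// combines their roots.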
    pub fn build_async(&mut self, d_data: &DeviceBuffer<u8>, addr_space: usize) {
        if addr_space < ADDR_SPACE_OFFSET as usize {
            return;
        }
        let addr_space_idx = addr_space - ADDR_SPACE_OFFSET as usize;
        if addr_space < self.mem_config.addr_spaces.len() && addr_space_idx == self.subtrees.len() {
            let mut subtree = MemoryMerkleSubTree::new(
                self.mem_config.addr_spaces[addr_space].num_cells / DIGEST_WIDTH,
                1 << (self.zero_hash.len() - 1),
            );
            subtree.build_async(d_data, addr_space_idx, &self.zero_hash);
            self.subtrees.push(subtree);
        } else {
            panic!("Invalid address space index");
        }
    }

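    /// Waits for every subtree build, then combines the subtree roots into
    /// `top_roots` on this tree's stream. Cross-stream ordering is enforced
    /// with CUDA events, so the host only blocks at the final synchronize.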
    pub fn finalize(&self) {
        for subtree in self.subtrees.iter() {
            self.stream.wait(subtree.event.as_ref().unwrap()).unwrap();
        }

        let we_can_gather_bufs_event = CudaEvent::new().unwrap();
        unsafe {
            we_can_gather_bufs_event
                .record(self.stream.as_raw())
                .unwrap();
        }
        default_stream_wait(&we_can_gather_bufs_event).unwrap();

        let roots: Vec<usize> = self
            .subtrees
            .iter()
            .map(|subtree| subtree.buf.as_ptr() as usize)
            .collect();
        let d_roots = roots.to_device().unwrap();
        let to_device_event = CudaEvent::new().unwrap();
        unsafe {
            to_device_event.record(cudaStreamPerThread).unwrap();
        }
        self.stream.wait(&to_device_event).unwrap();

        unsafe {
            finalize_merkle_tree(
                &d_roots,
                &self.top_roots,
                self.subtrees.len(),
                self.stream.as_raw(),
            )
            .unwrap();
        }

        self.stream.synchronize().unwrap();
    }

    /// Frees the per-address-space subtree buffers once they are no longer
    /// needed.
    pub fn drop_subtrees(&mut self) {
        self.subtrees = Vec::new();
    }

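    /// Applies the touched blocks to the tree, producing the trace of the
    /// memory Merkle AIR and updating `top_roots` in place. The public values
    /// are the old root followed by the new root.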
    pub fn update_with_touched_blocks(
        &self,
        unpadded_height: usize,
        d_touched_blocks: &DeviceBuffer<u32>,
        empty_touched_blocks: bool,
    ) -> AirProvingContext<GpuBackend> {
        let mut public_values = self.top_roots.to_host().unwrap()[0].to_vec();
        let merkle_trace = {
            let width = MemoryMerkleCols::<u8, DIGEST_WIDTH>::width();
            let padded_height = next_power_of_two_or_zero(unpadded_height);
            let output = DeviceMatrix::<F>::with_capacity(padded_height, width);
            output.buffer().fill_zero().unwrap();

            let actual_heights = self.subtrees.iter().map(|s| s.height).collect::<Vec<_>>();
            let subtrees_pointers = self
                .subtrees
                .iter()
                .map(|st| st.buf.as_ptr() as usize)
                .collect::<Vec<_>>()
                .to_device()
                .unwrap();
            unsafe {
                update_merkle_tree(
                    &output,
                    &subtrees_pointers,
                    &self.top_roots,
                    &self.zero_hash,
                    d_touched_blocks,
                    self.height - log2_ceil_usize(self.subtrees.len()),
                    &actual_heights,
                    unpadded_height,
                    &self.hasher_buffer,
                )
                .unwrap();
            }

            if empty_touched_blocks {
                // Degenerate case: no blocks were touched, so patch the last
                // two columns of the final used row on the host.
                let mut output_vec = output.to_host().unwrap();
                output_vec[unpadded_height - 1 + (width - 2) * padded_height] = F::ONE;
                output_vec[unpadded_height - 1 + (width - 1) * padded_height] = F::ONE;
                DeviceMatrix::new(
                    Arc::new(output_vec.to_device().unwrap()),
                    padded_height,
                    width,
                )
            } else {
                output
            }
        };
        public_values.extend(self.top_roots.to_host().unwrap()[0].to_vec());

        AirProvingContext::new(Vec::new(), Some(merkle_trace), public_values)
    }

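    /// Number of used trace rows: two per visited tree node. The first touched
    /// block contributes a full root-to-leaf path of `tree_height` nodes; each
    /// subsequent block adds only the nodes below its lowest common ancestor
    /// with the previous block, which sit at the lowest `(x ^ y).ilog2()`
    /// levels for consecutive leaf indices `x` and `y`.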
    pub fn calculate_unpadded_height(
        &self,
        touched_memory: &TimestampedEquipartition<F, DIGEST_WIDTH>,
    ) -> usize {
        let md = self.mem_config.memory_dimensions();
        let tree_height = md.overall_height();
        let shift_address = |(sp, ptr): (u32, u32)| (sp, ptr / DIGEST_WIDTH as u32);
        2 * if touched_memory.is_empty() {
            tree_height
        } else {
            tree_height
                + (0..(touched_memory.len() - 1))
                    .into_par_iter()
                    .map(|i| {
                        let x = md.label_to_index(shift_address(touched_memory[i].0));
                        let y = md.label_to_index(shift_address(touched_memory[i + 1].0));
                        (x ^ y).ilog2() as usize
                    })
                    .sum::<usize>()
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use openvm_circuit::{
        arch::{
            testing::POSEIDON2_DIRECT_BUS, vm_poseidon2_config, AddressSpaceHostLayout,
            MemoryCellType, MemoryConfig,
        },
        system::{
            memory::{
                merkle::MerkleTree,
                online::{GuestMemory, LinearMemory},
                AddressMap, TimestampedValues,
            },
            poseidon2::Poseidon2PeripheryChip,
        },
    };
    use openvm_cuda_backend::prelude::F;
    use openvm_cuda_common::{
        copy::{MemCopyD2H, MemCopyH2D},
        d_buffer::DeviceBuffer,
    };
    use openvm_instructions::{
        riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
        NATIVE_AS,
    };
    use openvm_stark_sdk::utils::create_seeded_rng;
    use p3_field::{FieldAlgebra, PrimeField32};
    use rand::Rng;

    use super::MemoryMerkleTree;
    use crate::system::cuda::{Poseidon2PeripheryChipGPU, DIGEST_WIDTH};

    #[test]
    fn test_cuda_merkle_tree_cpu_gpu_root_equivalence() {
        let mut rng = create_seeded_rng();
        let mem_config = {
            let mut addr_spaces = MemoryConfig::empty_address_space_configs(5);
            let max_cells = 1 << 16;
            addr_spaces[RV32_REGISTER_AS as usize].num_cells = 32 * size_of::<u32>();
            addr_spaces[RV32_MEMORY_AS as usize].num_cells = max_cells;
            addr_spaces[NATIVE_AS as usize].num_cells = max_cells;
            MemoryConfig::new(2, addr_spaces, max_cells.ilog2() as usize, 29, 17, 32)
        };

        let mut initial_memory = GuestMemory::new(AddressMap::from_mem_config(&mem_config));
        for (idx, space) in mem_config.addr_spaces.iter().enumerate() {
            unsafe {
                match space.layout {
                    MemoryCellType::Null => {}
                    MemoryCellType::U8 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u8, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u8],
                            );
                        }
                    }
                    MemoryCellType::U16 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u16, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u16],
                            );
                        }
                    }
                    MemoryCellType::U32 => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<u32, 1>(
                                idx as u32,
                                i as u32,
                                [rng.gen_range(0..space.layout.size()) as u32],
                            );
                        }
                    }
                    MemoryCellType::Native { .. } => {
                        for i in 0..space.num_cells {
                            initial_memory.write::<F, 1>(
                                idx as u32,
                                i as u32,
                                [F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))],
                            );
                        }
                    }
                }
            }
        }

489
490 let gpu_hasher_chip = Arc::new(Poseidon2PeripheryChipGPU::new(
491 (mem_config
492 .addr_spaces
493 .iter()
494 .map(|ashc| ashc.num_cells * 2 + mem_config.memory_dimensions().overall_height())
495 .sum::<usize>()
496 * 2)
497 .next_power_of_two()
498 * 2
499 * DIGEST_WIDTH, 1, ));
502 let mut gpu_merkle_tree = MemoryMerkleTree::new(mem_config.clone(), gpu_hasher_chip);
503 for (i, mem) in initial_memory.memory.get_memory().iter().enumerate() {
504 let mem_slice = mem.as_slice();
505 gpu_merkle_tree.build_async(
506 &(if !mem_slice.is_empty() {
507 mem_slice.to_device().unwrap()
508 } else {
509 DeviceBuffer::new()
510 }),
511 i,
512 );
513 }
514 gpu_merkle_tree.finalize();
515
        let cpu_hasher_chip =
            Poseidon2PeripheryChip::new(vm_poseidon2_config(), POSEIDON2_DIRECT_BUS, 3);
        let mut cpu_merkle_tree = MerkleTree::<F, DIGEST_WIDTH>::from_memory(
            &initial_memory.memory,
            &mem_config.memory_dimensions(),
            &cpu_hasher_chip,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);

        // Touch roughly a third of the DIGEST_WIDTH-aligned blocks in every
        // address space.
        let touched_ptrs = mem_config
            .addr_spaces
            .iter()
            .enumerate()
            .flat_map(|(i, cnf)| {
                let mut ptrs = Vec::new();
                for j in 0..(cnf.num_cells / DIGEST_WIDTH) {
                    if rng.gen_bool(0.333) {
                        ptrs.push((i as u32, (j * DIGEST_WIDTH) as u32));
                    }
                }
                ptrs
            })
            .collect::<Vec<_>>();
        let new_data = touched_ptrs
            .iter()
            .map(|_| std::array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..F::ORDER_U32))))
            .collect::<Vec<[F; DIGEST_WIDTH]>>();
        assert!(!touched_ptrs.is_empty());
        cpu_merkle_tree.finalize(
            &cpu_hasher_chip,
            &(touched_ptrs
                .iter()
                .copied()
                .zip(new_data.iter().copied())
                .collect()),
            &mem_config.memory_dimensions(),
        );
        let touched_blocks = touched_ptrs
            .into_iter()
            .zip(new_data)
            .map(|(address, data)| {
                (
                    address,
                    TimestampedValues {
                        timestamp: rng.gen_range(0..(1u32 << mem_config.timestamp_max_bits)),
                        values: data,
                    },
                )
            })
            .collect::<Vec<_>>();
        let d_touched_blocks = touched_blocks.to_device().unwrap().as_buffer::<u32>();

        gpu_merkle_tree.update_with_touched_blocks(
            gpu_merkle_tree.calculate_unpadded_height(&touched_blocks),
            &d_touched_blocks,
            false,
        );

        assert_eq!(
            cpu_merkle_tree.root(),
            gpu_merkle_tree.top_roots.to_host().unwrap()[0]
        );
        eprintln!("{:?}", cpu_merkle_tree.root());
        eprintln!("{:?}", gpu_merkle_tree.top_roots.to_host().unwrap()[0]);
    }
}