openvm_circuit/system/memory/merkle/mod.rs

1use std::array;
2
3use openvm_stark_backend::{
4    interaction::PermutationCheckBus, p3_field::PrimeField32, p3_maybe_rayon::prelude::*,
5};
6
7use super::{controller::dimensions::MemoryDimensions, online::LinearMemory};
8use crate::{
9    arch::AddressSpaceHostLayout,
10    system::memory::{online::PAGE_SIZE, AddressMap},
11};
12
13mod air;
14mod columns;
15pub mod public_values;
16mod trace;
17mod tree;
18
19pub use air::*;
20pub use columns::*;
21pub(super) use trace::SerialReceiver;
22pub use tree::*;
23
24#[cfg(test)]
25mod tests;
26
27pub struct MemoryMerkleChip<const CHUNK: usize, F> {
28    pub air: MemoryMerkleAir<CHUNK>,
29    final_state: Option<FinalState<CHUNK, F>>,
30    overridden_height: Option<usize>,
31    pub(crate) top_tree: Vec<[F; CHUNK]>,
32    /// Used for metric collection purposes only
33    #[cfg(feature = "metrics")]
34    pub(crate) current_height: usize,
35}
36#[derive(Debug)]
37pub struct FinalState<const CHUNK: usize, F> {
38    rows: Vec<MemoryMerkleCols<F, CHUNK>>,
39    init_root: [F; CHUNK],
40    final_root: [F; CHUNK],
41}
42
43impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
44    /// `compression_bus` is the bus for direct (no-memory involved) interactions to call the
45    /// cryptographic compression function.
46    pub fn new(
47        memory_dimensions: MemoryDimensions,
48        merkle_bus: PermutationCheckBus,
49        compression_bus: PermutationCheckBus,
50    ) -> Self {
51        assert!(memory_dimensions.addr_space_height > 0);
52        assert!(memory_dimensions.address_height > 0);
53        Self {
54            air: MemoryMerkleAir {
55                memory_dimensions,
56                merkle_bus,
57                compression_bus,
58            },
59            final_state: None,
60            overridden_height: None,
61            top_tree: vec![],
62            #[cfg(feature = "metrics")]
63            current_height: 0,
64        }
65    }
66    pub fn set_overridden_height(&mut self, override_height: usize) {
67        self.overridden_height = Some(override_height);
68    }
69}
70
71#[tracing::instrument(level = "info", skip_all)]
72fn memory_to_vec_partition<F: PrimeField32, const N: usize>(
73    memory: &AddressMap,
74    md: &MemoryDimensions,
75) -> Vec<(u64, [F; N])> {
76    (0..memory.mem.len())
77        .into_par_iter()
78        .map(move |as_idx| {
79            let space_mem = memory.mem[as_idx].as_slice();
80            let addr_space_layout = memory.config[as_idx].layout;
81            let cell_size = addr_space_layout.size();
82            debug_assert_eq!(PAGE_SIZE % (cell_size * N), 0);
83
84            let num_nonzero_pages = space_mem
85                .par_chunks(PAGE_SIZE)
86                .enumerate()
87                .flat_map(|(idx, page)| {
88                    if page.iter().any(|x| *x != 0) {
89                        Some(idx + 1)
90                    } else {
91                        None
92                    }
93                })
94                .max()
95                .unwrap_or(0);
96
97            let space_mem = &space_mem[..(num_nonzero_pages * PAGE_SIZE).min(space_mem.len())];
98            let mut num_elements = space_mem.len() / (cell_size * N);
99            // virtual memory may be larger than dimensions due to rounding up to page size
100            num_elements = num_elements.min(1 << md.address_height);
101
102            (0..num_elements)
103                .into_par_iter()
104                .map(move |idx| {
105                    (
106                        md.label_to_index((as_idx as u32, idx as u32)),
107                        array::from_fn(|i| unsafe {
108                            // SAFETY: idx < num_elements = space_mem.len() / (cell_size * N) so ptr
109                            // is within bounds. We are reading one cell at a time, so alignment is
110                            // guaranteed.
111                            let ptr: *const u8 =
112                                space_mem.as_ptr().add(idx * cell_size * N + i * cell_size);
113                            addr_space_layout
114                                .to_field(&*core::ptr::slice_from_raw_parts(ptr, cell_size))
115                        }),
116                    )
117                })
118                .collect::<Vec<_>>()
119        })
120        .collect::<Vec<_>>()
121        .into_iter()
122        .flatten()
123        .collect::<Vec<_>>()
124}