openvm_circuit/system/memory/merkle/mod.rs

1use std::array;
2
3use openvm_stark_backend::{
4    interaction::PermutationCheckBus, p3_field::PrimeField32, p3_maybe_rayon::prelude::*,
5};
6
7use super::{controller::dimensions::MemoryDimensions, online::LinearMemory};
8use crate::{
9    arch::AddressSpaceHostLayout,
10    system::memory::{online::PAGE_SIZE, AddressMap},
11};
12
13mod air;
14mod columns;
15pub mod public_values;
16mod trace;
17mod tree;
18
19pub use air::*;
20pub use columns::*;
21pub(super) use trace::SerialReceiver;
22pub use tree::*;
23
24#[cfg(test)]
25mod tests;
26
27pub struct MemoryMerkleChip<const CHUNK: usize, F> {
28    pub air: MemoryMerkleAir<CHUNK>,
29    final_state: Option<FinalState<CHUNK, F>>,
30    overridden_height: Option<usize>,
31    /// Used for metric collection purposes only
32    #[cfg(feature = "metrics")]
33    pub(crate) current_height: usize,
34}
35#[derive(Debug)]
36pub struct FinalState<const CHUNK: usize, F> {
37    rows: Vec<MemoryMerkleCols<F, CHUNK>>,
38    init_root: [F; CHUNK],
39    final_root: [F; CHUNK],
40}
41
42impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
43    /// `compression_bus` is the bus for direct (no-memory involved) interactions to call the
44    /// cryptographic compression function.
45    pub fn new(
46        memory_dimensions: MemoryDimensions,
47        merkle_bus: PermutationCheckBus,
48        compression_bus: PermutationCheckBus,
49    ) -> Self {
50        assert!(memory_dimensions.addr_space_height > 0);
51        assert!(memory_dimensions.address_height > 0);
52        Self {
53            air: MemoryMerkleAir {
54                memory_dimensions,
55                merkle_bus,
56                compression_bus,
57            },
58            final_state: None,
59            overridden_height: None,
60            #[cfg(feature = "metrics")]
61            current_height: 0,
62        }
63    }
64    pub fn set_overridden_height(&mut self, override_height: usize) {
65        self.overridden_height = Some(override_height);
66    }
67}
68
69#[tracing::instrument(level = "info", skip_all)]
70fn memory_to_vec_partition<F: PrimeField32, const N: usize>(
71    memory: &AddressMap,
72    md: &MemoryDimensions,
73) -> Vec<(u64, [F; N])> {
74    (0..memory.mem.len())
75        .into_par_iter()
76        .map(move |as_idx| {
77            let space_mem = memory.mem[as_idx].as_slice();
78            let addr_space_layout = memory.config[as_idx].layout;
79            let cell_size = addr_space_layout.size();
80            debug_assert_eq!(PAGE_SIZE % (cell_size * N), 0);
81
82            let num_nonzero_pages = space_mem
83                .par_chunks(PAGE_SIZE)
84                .enumerate()
85                .flat_map(|(idx, page)| {
86                    if page.iter().any(|x| *x != 0) {
87                        Some(idx + 1)
88                    } else {
89                        None
90                    }
91                })
92                .max()
93                .unwrap_or(0);
94
95            let space_mem = &space_mem[..(num_nonzero_pages * PAGE_SIZE).min(space_mem.len())];
96            let mut num_elements = space_mem.len() / (cell_size * N);
97            // virtual memory may be larger than dimensions due to rounding up to page size
98            num_elements = num_elements.min(1 << md.address_height);
99
100            (0..num_elements)
101                .into_par_iter()
102                .map(move |idx| {
103                    (
104                        md.label_to_index((as_idx as u32, idx as u32)),
105                        array::from_fn(|i| unsafe {
106                            // SAFETY: idx < num_elements = space_mem.len() / (cell_size * N) so ptr
107                            // is within bounds. We are reading one cell at a time, so alignment is
108                            // guaranteed.
109                            let ptr: *const u8 =
110                                space_mem.as_ptr().add(idx * cell_size * N + i * cell_size);
111                            addr_space_layout
112                                .to_field(&*core::ptr::slice_from_raw_parts(ptr, cell_size))
113                        }),
114                    )
115                })
116                .collect::<Vec<_>>()
117        })
118        .collect::<Vec<_>>()
119        .into_iter()
120        .flatten()
121        .collect::<Vec<_>>()
122}