// openvm_circuit/system/memory/merkle/trace.rs

1use std::{
2    borrow::BorrowMut,
3    sync::{atomic::AtomicU32, Arc},
4};
5
6use openvm_stark_backend::{
7    config::{Domain, StarkGenericConfig, Val},
8    p3_commit::PolynomialSpace,
9    p3_field::PrimeField32,
10    p3_matrix::dense::RowMajorMatrix,
11    prover::{cpu::CpuBackend, types::AirProvingContext},
12    ChipUsageGetter,
13};
14use tracing::instrument;
15
16use crate::{
17    arch::hasher::HasherChip,
18    system::{
19        memory::{
20            merkle::{tree::MerkleTree, FinalState, MemoryMerkleChip, MemoryMerkleCols},
21            Equipartition, MemoryImage,
22        },
23        poseidon2::{
24            Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH,
25        },
26    },
27};
28
29impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
30    #[instrument(name = "merkle_finalize", level = "debug", skip_all)]
31    pub(crate) fn finalize(
32        &mut self,
33        initial_memory: &MemoryImage,
34        final_memory: &Equipartition<F, CHUNK>,
35        hasher: &impl HasherChip<CHUNK, F>,
36    ) {
37        assert!(self.final_state.is_none(), "Merkle chip already finalized");
38        let mut tree = MerkleTree::from_memory(initial_memory, &self.air.memory_dimensions, hasher);
39        self.final_state = Some(tree.finalize(hasher, final_memory, &self.air.memory_dimensions));
40    }
41}
42
43impl<const CHUNK: usize, F> MemoryMerkleChip<CHUNK, F>
44where
45    F: PrimeField32,
46{
47    pub fn generate_proving_ctx<SC>(&mut self) -> AirProvingContext<CpuBackend<SC>>
48    where
49        SC: StarkGenericConfig,
50        Domain<SC>: PolynomialSpace<Val = F>,
51    {
52        assert!(
53            self.final_state.is_some(),
54            "Merkle chip must finalize before trace generation"
55        );
56        let FinalState {
57            mut rows,
58            init_root,
59            final_root,
60        } = self.final_state.take().unwrap();
61        // important that this sort be stable,
62        // because we need the initial root to be first and the final root to be second
63        rows.reverse();
64        rows.swap(0, 1);
65
66        #[cfg(feature = "metrics")]
67        {
68            self.current_height = rows.len();
69        }
70        let width = MemoryMerkleCols::<Val<SC>, CHUNK>::width();
71        let mut height = rows.len().next_power_of_two();
72        if let Some(mut oh) = self.overridden_height {
73            oh = oh.next_power_of_two();
74            assert!(
75                oh >= height,
76                "Overridden height {oh} is less than the required height {height}"
77            );
78            height = oh;
79        }
80        let mut trace = Val::<SC>::zero_vec(width * height);
81
82        for (trace_row, row) in trace.chunks_exact_mut(width).zip(rows) {
83            *trace_row.borrow_mut() = row;
84        }
85
86        let trace = Arc::new(RowMajorMatrix::new(trace, width));
87        let pvs = init_root.into_iter().chain(final_root).collect();
88        AirProvingContext::simple(trace, pvs)
89    }
90}
91impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for MemoryMerkleChip<CHUNK, F> {
92    fn air_name(&self) -> String {
93        "Merkle".to_string()
94    }
95
96    fn current_trace_height(&self) -> usize {
97        self.final_state.as_ref().map(|s| s.rows.len()).unwrap_or(0)
98    }
99
100    fn trace_width(&self) -> usize {
101        MemoryMerkleCols::<F, CHUNK>::width()
102    }
103}
104
/// A sink that accepts messages one at a time through a shared reference.
///
/// `receive` takes `&self`, so an implementor that records anything must use interior
/// mutability.
pub trait SerialReceiver<T> {
    /// Accepts a single message.
    fn receive(&self, msg: T);
}
108
109impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]>
110    for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
111{
112    /// Receives a permutation preimage, pads with zeros to the permutation width, and records.
113    /// The permutation preimage must have length at most the permutation width (panics otherwise).
114    fn receive(&self, perm_preimage: &'a [F]) {
115        assert!(perm_preimage.len() <= PERIPHERY_POSEIDON2_WIDTH);
116        let mut state = [F::ZERO; PERIPHERY_POSEIDON2_WIDTH];
117        state[..perm_preimage.len()].copy_from_slice(perm_preimage);
118        let count = self.records.entry(state).or_insert(AtomicU32::new(0));
119        count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
120    }
121}
122
123impl<'a, F: PrimeField32> SerialReceiver<&'a [F]> for Poseidon2PeripheryChip<F> {
124    fn receive(&self, perm_preimage: &'a [F]) {
125        match self {
126            Poseidon2PeripheryChip::Register0(chip) => chip.receive(perm_preimage),
127            Poseidon2PeripheryChip::Register1(chip) => chip.receive(perm_preimage),
128        }
129    }
130}