openvm_circuit/system/memory/merkle/
trace.rs

1use std::{
2    borrow::BorrowMut,
3    sync::{atomic::AtomicU32, Arc},
4};
5
6use openvm_stark_backend::{
7    config::{Domain, StarkGenericConfig, Val},
8    p3_commit::PolynomialSpace,
9    p3_field::PrimeField32,
10    p3_matrix::dense::RowMajorMatrix,
11    prover::{cpu::CpuBackend, types::AirProvingContext},
12    ChipUsageGetter,
13};
14use tracing::instrument;
15
16use crate::{
17    arch::hasher::HasherChip,
18    system::{
19        memory::{
20            merkle::{tree::MerkleTree, FinalState, MemoryMerkleChip, MemoryMerkleCols},
21            Equipartition, MemoryImage,
22        },
23        poseidon2::{
24            Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH,
25        },
26    },
27};
28
29impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
30    #[instrument(name = "merkle_finalize", level = "debug", skip_all)]
31    pub(crate) fn finalize(
32        &mut self,
33        initial_memory: &MemoryImage,
34        final_memory: &Equipartition<F, CHUNK>,
35        hasher: &impl HasherChip<CHUNK, F>,
36    ) {
37        assert!(self.final_state.is_none(), "Merkle chip already finalized");
38        let memory_dimensions = &self.air.memory_dimensions;
39        let mut tree = MerkleTree::from_memory(initial_memory, memory_dimensions, hasher);
40        self.final_state = Some(tree.finalize(hasher, final_memory, memory_dimensions));
41        self.top_tree = tree.top_tree(memory_dimensions.addr_space_height);
42    }
43}
44
45impl<const CHUNK: usize, F> MemoryMerkleChip<CHUNK, F>
46where
47    F: PrimeField32,
48{
49    pub fn generate_proving_ctx<SC>(&mut self) -> AirProvingContext<CpuBackend<SC>>
50    where
51        SC: StarkGenericConfig,
52        Domain<SC>: PolynomialSpace<Val = F>,
53    {
54        assert!(
55            self.final_state.is_some(),
56            "Merkle chip must finalize before trace generation"
57        );
58        let FinalState {
59            mut rows,
60            init_root,
61            final_root,
62        } = self.final_state.take().unwrap();
63        // important that this sort be stable,
64        // because we need the initial root to be first and the final root to be second
65        rows.reverse();
66        rows.swap(0, 1);
67
68        #[cfg(feature = "metrics")]
69        {
70            self.current_height = rows.len();
71        }
72        let width = MemoryMerkleCols::<Val<SC>, CHUNK>::width();
73        let mut height = rows.len().next_power_of_two();
74        if let Some(mut oh) = self.overridden_height {
75            oh = oh.next_power_of_two();
76            assert!(
77                oh >= height,
78                "Overridden height {oh} is less than the required height {height}"
79            );
80            height = oh;
81        }
82        let mut trace = Val::<SC>::zero_vec(width * height);
83
84        for (trace_row, row) in trace.chunks_exact_mut(width).zip(rows) {
85            *trace_row.borrow_mut() = row;
86        }
87
88        let trace = Arc::new(RowMajorMatrix::new(trace, width));
89        let pvs = init_root.into_iter().chain(final_root).collect();
90        AirProvingContext::simple(trace, pvs)
91    }
92}
93impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for MemoryMerkleChip<CHUNK, F> {
94    fn air_name(&self) -> String {
95        "Merkle".to_string()
96    }
97
98    fn current_trace_height(&self) -> usize {
99        self.final_state.as_ref().map(|s| s.rows.len()).unwrap_or(0)
100    }
101
102    fn trace_width(&self) -> usize {
103        MemoryMerkleCols::<F, CHUNK>::width()
104    }
105}
106
/// A sink that accepts messages of type `T` one at a time.
///
/// `receive` takes `&self`, so an implementor that keeps state across calls has to do so
/// through interior mutability.
pub trait SerialReceiver<T> {
    /// Hands a single message `msg` to the receiver.
    fn receive(&self, msg: T);
}
110
111impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]>
112    for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
113{
114    /// Receives a permutation preimage, pads with zeros to the permutation width, and records.
115    /// The permutation preimage must have length at most the permutation width (panics otherwise).
116    fn receive(&self, perm_preimage: &'a [F]) {
117        assert!(perm_preimage.len() <= PERIPHERY_POSEIDON2_WIDTH);
118        let mut state = [F::ZERO; PERIPHERY_POSEIDON2_WIDTH];
119        state[..perm_preimage.len()].copy_from_slice(perm_preimage);
120        let count = self.records.entry(state).or_insert(AtomicU32::new(0));
121        count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
122    }
123}
124
125impl<'a, F: PrimeField32> SerialReceiver<&'a [F]> for Poseidon2PeripheryChip<F> {
126    fn receive(&self, perm_preimage: &'a [F]) {
127        match self {
128            Poseidon2PeripheryChip::Register0(chip) => chip.receive(perm_preimage),
129            Poseidon2PeripheryChip::Register1(chip) => chip.receive(perm_preimage),
130        }
131    }
132}