openvm_circuit/system/memory/merkle/
trace.rs1use std::{
2 borrow::BorrowMut,
3 sync::{atomic::AtomicU32, Arc},
4};
5
6use openvm_stark_backend::{
7 config::{Domain, StarkGenericConfig, Val},
8 p3_commit::PolynomialSpace,
9 p3_field::PrimeField32,
10 p3_matrix::dense::RowMajorMatrix,
11 prover::{cpu::CpuBackend, types::AirProvingContext},
12 ChipUsageGetter,
13};
14use tracing::instrument;
15
16use crate::{
17 arch::hasher::HasherChip,
18 system::{
19 memory::{
20 merkle::{tree::MerkleTree, FinalState, MemoryMerkleChip, MemoryMerkleCols},
21 Equipartition, MemoryImage,
22 },
23 poseidon2::{
24 Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH,
25 },
26 },
27};
28
29impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
30 #[instrument(name = "merkle_finalize", level = "debug", skip_all)]
31 pub(crate) fn finalize(
32 &mut self,
33 initial_memory: &MemoryImage,
34 final_memory: &Equipartition<F, CHUNK>,
35 hasher: &impl HasherChip<CHUNK, F>,
36 ) {
37 assert!(self.final_state.is_none(), "Merkle chip already finalized");
38 let mut tree = MerkleTree::from_memory(initial_memory, &self.air.memory_dimensions, hasher);
39 self.final_state = Some(tree.finalize(hasher, final_memory, &self.air.memory_dimensions));
40 }
41}
42
43impl<const CHUNK: usize, F> MemoryMerkleChip<CHUNK, F>
44where
45 F: PrimeField32,
46{
47 pub fn generate_proving_ctx<SC>(&mut self) -> AirProvingContext<CpuBackend<SC>>
48 where
49 SC: StarkGenericConfig,
50 Domain<SC>: PolynomialSpace<Val = F>,
51 {
52 assert!(
53 self.final_state.is_some(),
54 "Merkle chip must finalize before trace generation"
55 );
56 let FinalState {
57 mut rows,
58 init_root,
59 final_root,
60 } = self.final_state.take().unwrap();
61 rows.reverse();
64 rows.swap(0, 1);
65
66 #[cfg(feature = "metrics")]
67 {
68 self.current_height = rows.len();
69 }
70 let width = MemoryMerkleCols::<Val<SC>, CHUNK>::width();
71 let mut height = rows.len().next_power_of_two();
72 if let Some(mut oh) = self.overridden_height {
73 oh = oh.next_power_of_two();
74 assert!(
75 oh >= height,
76 "Overridden height {oh} is less than the required height {height}"
77 );
78 height = oh;
79 }
80 let mut trace = Val::<SC>::zero_vec(width * height);
81
82 for (trace_row, row) in trace.chunks_exact_mut(width).zip(rows) {
83 *trace_row.borrow_mut() = row;
84 }
85
86 let trace = Arc::new(RowMajorMatrix::new(trace, width));
87 let pvs = init_root.into_iter().chain(final_root).collect();
88 AirProvingContext::simple(trace, pvs)
89 }
90}
91impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for MemoryMerkleChip<CHUNK, F> {
92 fn air_name(&self) -> String {
93 "Merkle".to_string()
94 }
95
96 fn current_trace_height(&self) -> usize {
97 self.final_state.as_ref().map(|s| s.rows.len()).unwrap_or(0)
98 }
99
100 fn trace_width(&self) -> usize {
101 MemoryMerkleCols::<F, CHUNK>::width()
102 }
103}
104
/// A sink that accepts messages of type `T` one at a time.
///
/// `receive` takes `&self`, so an implementor that records messages has to
/// do so through interior mutability.
pub trait SerialReceiver<T> {
    /// Hands a single message to the receiver.
    fn receive(&self, msg: T);
}
108
109impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]>
110 for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
111{
112 fn receive(&self, perm_preimage: &'a [F]) {
115 assert!(perm_preimage.len() <= PERIPHERY_POSEIDON2_WIDTH);
116 let mut state = [F::ZERO; PERIPHERY_POSEIDON2_WIDTH];
117 state[..perm_preimage.len()].copy_from_slice(perm_preimage);
118 let count = self.records.entry(state).or_insert(AtomicU32::new(0));
119 count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
120 }
121}
122
123impl<'a, F: PrimeField32> SerialReceiver<&'a [F]> for Poseidon2PeripheryChip<F> {
124 fn receive(&self, perm_preimage: &'a [F]) {
125 match self {
126 Poseidon2PeripheryChip::Register0(chip) => chip.receive(perm_preimage),
127 Poseidon2PeripheryChip::Register1(chip) => chip.receive(perm_preimage),
128 }
129 }
130}