openvm_circuit/system/memory/merkle/
trace.rs1use std::{
2 borrow::BorrowMut,
3 sync::{atomic::AtomicU32, Arc},
4};
5
6use openvm_stark_backend::{
7 config::{Domain, StarkGenericConfig, Val},
8 p3_commit::PolynomialSpace,
9 p3_field::PrimeField32,
10 p3_matrix::dense::RowMajorMatrix,
11 prover::{cpu::CpuBackend, types::AirProvingContext},
12 ChipUsageGetter,
13};
14use tracing::instrument;
15
16use crate::{
17 arch::hasher::HasherChip,
18 system::{
19 memory::{
20 merkle::{tree::MerkleTree, FinalState, MemoryMerkleChip, MemoryMerkleCols},
21 Equipartition, MemoryImage,
22 },
23 poseidon2::{
24 Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH,
25 },
26 },
27};
28
29impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
30 #[instrument(name = "merkle_finalize", level = "debug", skip_all)]
31 pub(crate) fn finalize(
32 &mut self,
33 initial_memory: &MemoryImage,
34 final_memory: &Equipartition<F, CHUNK>,
35 hasher: &impl HasherChip<CHUNK, F>,
36 ) {
37 assert!(self.final_state.is_none(), "Merkle chip already finalized");
38 let memory_dimensions = &self.air.memory_dimensions;
39 let mut tree = MerkleTree::from_memory(initial_memory, memory_dimensions, hasher);
40 self.final_state = Some(tree.finalize(hasher, final_memory, memory_dimensions));
41 self.top_tree = tree.top_tree(memory_dimensions.addr_space_height);
42 }
43}
44
45impl<const CHUNK: usize, F> MemoryMerkleChip<CHUNK, F>
46where
47 F: PrimeField32,
48{
49 pub fn generate_proving_ctx<SC>(&mut self) -> AirProvingContext<CpuBackend<SC>>
50 where
51 SC: StarkGenericConfig,
52 Domain<SC>: PolynomialSpace<Val = F>,
53 {
54 assert!(
55 self.final_state.is_some(),
56 "Merkle chip must finalize before trace generation"
57 );
58 let FinalState {
59 mut rows,
60 init_root,
61 final_root,
62 } = self.final_state.take().unwrap();
63 rows.reverse();
66 rows.swap(0, 1);
67
68 #[cfg(feature = "metrics")]
69 {
70 self.current_height = rows.len();
71 }
72 let width = MemoryMerkleCols::<Val<SC>, CHUNK>::width();
73 let mut height = rows.len().next_power_of_two();
74 if let Some(mut oh) = self.overridden_height {
75 oh = oh.next_power_of_two();
76 assert!(
77 oh >= height,
78 "Overridden height {oh} is less than the required height {height}"
79 );
80 height = oh;
81 }
82 let mut trace = Val::<SC>::zero_vec(width * height);
83
84 for (trace_row, row) in trace.chunks_exact_mut(width).zip(rows) {
85 *trace_row.borrow_mut() = row;
86 }
87
88 let trace = Arc::new(RowMajorMatrix::new(trace, width));
89 let pvs = init_root.into_iter().chain(final_root).collect();
90 AirProvingContext::simple(trace, pvs)
91 }
92}
93impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for MemoryMerkleChip<CHUNK, F> {
94 fn air_name(&self) -> String {
95 "Merkle".to_string()
96 }
97
98 fn current_trace_height(&self) -> usize {
99 self.final_state.as_ref().map(|s| s.rows.len()).unwrap_or(0)
100 }
101
102 fn trace_width(&self) -> usize {
103 MemoryMerkleCols::<F, CHUNK>::width()
104 }
105}
106
/// A sink that is handed messages one call at a time.
///
/// `receive` only gets `&self`, so implementors that need to record anything must rely on
/// interior mutability (the Poseidon2 implementations, for example, bump atomic counters).
pub trait SerialReceiver<Msg> {
    /// Delivers a single message to the receiver.
    fn receive(&self, msg: Msg);
}
110
111impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]>
112 for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
113{
114 fn receive(&self, perm_preimage: &'a [F]) {
117 assert!(perm_preimage.len() <= PERIPHERY_POSEIDON2_WIDTH);
118 let mut state = [F::ZERO; PERIPHERY_POSEIDON2_WIDTH];
119 state[..perm_preimage.len()].copy_from_slice(perm_preimage);
120 let count = self.records.entry(state).or_insert(AtomicU32::new(0));
121 count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
122 }
123}
124
125impl<'a, F: PrimeField32> SerialReceiver<&'a [F]> for Poseidon2PeripheryChip<F> {
126 fn receive(&self, perm_preimage: &'a [F]) {
127 match self {
128 Poseidon2PeripheryChip::Register0(chip) => chip.receive(perm_preimage),
129 Poseidon2PeripheryChip::Register1(chip) => chip.receive(perm_preimage),
130 }
131 }
132}