1use std::{
2 borrow::BorrowMut,
3 cmp::Reverse,
4 sync::{atomic::AtomicU32, Arc},
5};
6
7use openvm_stark_backend::{
8 config::{StarkGenericConfig, Val},
9 p3_field::{FieldAlgebra, PrimeField32},
10 p3_matrix::dense::RowMajorMatrix,
11 prover::types::AirProofInput,
12 AirRef, Chip, ChipUsageGetter,
13};
14use rustc_hash::FxHashSet;
15
16use crate::{
17 arch::hasher::HasherChip,
18 system::{
19 memory::{
20 controller::dimensions::MemoryDimensions,
21 merkle::{FinalState, MemoryMerkleChip, MemoryMerkleCols},
22 tree::MemoryNode::{self, NonLeaf},
23 Equipartition,
24 },
25 poseidon2::{
26 Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH,
27 },
28 },
29};
30
31impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
32 pub fn finalize(
33 &mut self,
34 initial_tree: &MemoryNode<CHUNK, F>,
35 final_memory: &Equipartition<F, CHUNK>,
36 hasher: &mut impl HasherChip<CHUNK, F>,
37 ) {
38 assert!(self.final_state.is_none(), "Merkle chip already finalized");
39 if self.touched_nodes.len() == 1 {
43 self.touch_node(1, 0, 0);
44 }
45
46 let mut rows = vec![];
47 let mut tree_helper = TreeHelper {
48 memory_dimensions: self.air.memory_dimensions,
49 final_memory,
50 touched_nodes: &self.touched_nodes,
51 trace_rows: &mut rows,
52 };
53 let final_tree = tree_helper.recur(
54 self.air.memory_dimensions.overall_height(),
55 initial_tree,
56 0,
57 0,
58 hasher,
59 );
60 self.final_state = Some(FinalState {
61 rows,
62 init_root: initial_tree.hash(),
63 final_root: final_tree.hash(),
64 });
65 }
66}
67
68impl<const CHUNK: usize, SC: StarkGenericConfig> Chip<SC> for MemoryMerkleChip<CHUNK, Val<SC>>
69where
70 Val<SC>: PrimeField32,
71{
72 fn air(&self) -> AirRef<SC> {
73 Arc::new(self.air.clone())
74 }
75
76 fn generate_air_proof_input(self) -> AirProofInput<SC> {
77 assert!(
78 self.final_state.is_some(),
79 "Merkle chip must finalize before trace generation"
80 );
81 let FinalState {
82 mut rows,
83 init_root,
84 final_root,
85 } = self.final_state.unwrap();
86 rows.sort_by_key(|row| Reverse(row.parent_height));
89
90 let width = MemoryMerkleCols::<Val<SC>, CHUNK>::width();
91 let mut height = rows.len().next_power_of_two();
92 if let Some(mut oh) = self.overridden_height {
93 oh = oh.next_power_of_two();
94 assert!(
95 oh >= height,
96 "Overridden height {oh} is less than the required height {height}"
97 );
98 height = oh;
99 }
100 let mut trace = Val::<SC>::zero_vec(width * height);
101
102 for (trace_row, row) in trace.chunks_exact_mut(width).zip(rows) {
103 *trace_row.borrow_mut() = row;
104 }
105
106 let trace = RowMajorMatrix::new(trace, width);
107 let pvs = init_root.into_iter().chain(final_root).collect();
108 AirProofInput::simple(trace, pvs)
109 }
110}
111impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for MemoryMerkleChip<CHUNK, F> {
112 fn air_name(&self) -> String {
113 "Merkle".to_string()
114 }
115
116 fn current_trace_height(&self) -> usize {
117 2 * self.num_touched_nonleaves
118 }
119
120 fn trace_width(&self) -> usize {
121 MemoryMerkleCols::<F, CHUNK>::width()
122 }
123}
124
125struct TreeHelper<'a, const CHUNK: usize, F: PrimeField32> {
126 memory_dimensions: MemoryDimensions,
127 final_memory: &'a Equipartition<F, CHUNK>,
128 touched_nodes: &'a FxHashSet<(usize, u32, u32)>,
129 trace_rows: &'a mut Vec<MemoryMerkleCols<F, CHUNK>>,
130}
131
132impl<const CHUNK: usize, F: PrimeField32> TreeHelper<'_, CHUNK, F> {
133 fn recur(
134 &mut self,
135 height: usize,
136 initial_node: &MemoryNode<CHUNK, F>,
137 as_label: u32,
138 address_label: u32,
139 hasher: &mut impl HasherChip<CHUNK, F>,
140 ) -> MemoryNode<CHUNK, F> {
141 if height == 0 {
142 let address_space = as_label + self.memory_dimensions.as_offset;
143 let leaf_values = *self
144 .final_memory
145 .get(&(address_space, address_label))
146 .unwrap_or(&[F::ZERO; CHUNK]);
147 MemoryNode::new_leaf(hasher.hash(&leaf_values))
148 } else if let NonLeaf {
149 left: initial_left_node,
150 right: initial_right_node,
151 ..
152 } = initial_node.clone()
153 {
154 hasher.compress_and_record(&initial_left_node.hash(), &initial_right_node.hash());
156
157 let is_as_section = height > self.memory_dimensions.address_height;
158
159 let (left_as_label, right_as_label) = if is_as_section {
160 (2 * as_label, 2 * as_label + 1)
161 } else {
162 (as_label, as_label)
163 };
164 let (left_address_label, right_address_label) = if is_as_section {
165 (address_label, address_label)
166 } else {
167 (2 * address_label, 2 * address_label + 1)
168 };
169
170 let left_is_final =
171 !self
172 .touched_nodes
173 .contains(&(height - 1, left_as_label, left_address_label));
174
175 let final_left_node = if left_is_final {
176 initial_left_node
177 } else {
178 Arc::new(self.recur(
179 height - 1,
180 &initial_left_node,
181 left_as_label,
182 left_address_label,
183 hasher,
184 ))
185 };
186
187 let right_is_final =
188 !self
189 .touched_nodes
190 .contains(&(height - 1, right_as_label, right_address_label));
191
192 let final_right_node = if right_is_final {
193 initial_right_node
194 } else {
195 Arc::new(self.recur(
196 height - 1,
197 &initial_right_node,
198 right_as_label,
199 right_address_label,
200 hasher,
201 ))
202 };
203
204 let final_node = MemoryNode::new_nonleaf(final_left_node, final_right_node, hasher);
205 self.add_trace_row(height, as_label, address_label, initial_node, None);
206 self.add_trace_row(
207 height,
208 as_label,
209 address_label,
210 &final_node,
211 Some([left_is_final, right_is_final]),
212 );
213 final_node
214 } else {
215 panic!("Leaf {:?} found at nonzero height {}", initial_node, height);
216 }
217 }
218
219 fn add_trace_row(
221 &mut self,
222 parent_height: usize,
223 as_label: u32,
224 address_label: u32,
225 node: &MemoryNode<CHUNK, F>,
226 direction_changes: Option<[bool; 2]>,
227 ) {
228 let [left_direction_change, right_direction_change] =
229 direction_changes.unwrap_or([false; 2]);
230 let cols = if let NonLeaf { hash, left, right } = node {
231 MemoryMerkleCols {
232 expand_direction: if direction_changes.is_none() {
233 F::ONE
234 } else {
235 F::NEG_ONE
236 },
237 height_section: F::from_bool(parent_height > self.memory_dimensions.address_height),
238 parent_height: F::from_canonical_usize(parent_height),
239 is_root: F::from_bool(parent_height == self.memory_dimensions.overall_height()),
240 parent_as_label: F::from_canonical_u32(as_label),
241 parent_address_label: F::from_canonical_u32(address_label),
242 parent_hash: *hash,
243 left_child_hash: left.hash(),
244 right_child_hash: right.hash(),
245 left_direction_different: F::from_bool(left_direction_change),
246 right_direction_different: F::from_bool(right_direction_change),
247 }
248 } else {
249 panic!("trace_rows expects node = {:?} to be NonLeaf", node);
250 };
251 self.trace_rows.push(cols);
252 }
253}
254
/// A receiver that consumes messages of type `T`, one per call, through a `&mut self` method.
pub trait SerialReceiver<T> {
    /// Hands a single `message` to the receiver.
    fn receive(&mut self, message: T);
}
258
259impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]>
260 for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
261{
262 fn receive(&mut self, perm_preimage: &'a [F]) {
265 assert!(perm_preimage.len() <= PERIPHERY_POSEIDON2_WIDTH);
266 let mut state = [F::ZERO; PERIPHERY_POSEIDON2_WIDTH];
267 state[..perm_preimage.len()].copy_from_slice(perm_preimage);
268 let count = self.records.entry(state).or_insert(AtomicU32::new(0));
269 count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270 }
271}
272
273impl<'a, F: PrimeField32> SerialReceiver<&'a [F]> for Poseidon2PeripheryChip<F> {
274 fn receive(&mut self, perm_preimage: &'a [F]) {
275 match self {
276 Poseidon2PeripheryChip::Register0(chip) => chip.receive(perm_preimage),
277 Poseidon2PeripheryChip::Register1(chip) => chip.receive(perm_preimage),
278 }
279 }
280}