openvm_circuit/system/memory/merkle/trace.rs

1use std::{
2    borrow::BorrowMut,
3    cmp::Reverse,
4    sync::{atomic::AtomicU32, Arc},
5};
6
7use openvm_stark_backend::{
8    config::{StarkGenericConfig, Val},
9    p3_field::{FieldAlgebra, PrimeField32},
10    p3_matrix::dense::RowMajorMatrix,
11    prover::types::AirProofInput,
12    AirRef, Chip, ChipUsageGetter,
13};
14use rustc_hash::FxHashSet;
15
16use crate::{
17    arch::hasher::HasherChip,
18    system::{
19        memory::{
20            controller::dimensions::MemoryDimensions,
21            merkle::{FinalState, MemoryMerkleChip, MemoryMerkleCols},
22            tree::MemoryNode::{self, NonLeaf},
23            Equipartition,
24        },
25        poseidon2::{
26            Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH,
27        },
28    },
29};
30
31impl<const CHUNK: usize, F: PrimeField32> MemoryMerkleChip<CHUNK, F> {
32    pub fn finalize(
33        &mut self,
34        initial_tree: &MemoryNode<CHUNK, F>,
35        final_memory: &Equipartition<F, CHUNK>,
36        hasher: &mut impl HasherChip<CHUNK, F>,
37    ) {
38        assert!(self.final_state.is_none(), "Merkle chip already finalized");
39        // there needs to be a touched node with `height_section` = 0
40        // shouldn't be a leaf because
41        // trace generation will expect an interaction from MemoryInterfaceChip in that case
42        if self.touched_nodes.len() == 1 {
43            self.touch_node(1, 0, 0);
44        }
45
46        let mut rows = vec![];
47        let mut tree_helper = TreeHelper {
48            memory_dimensions: self.air.memory_dimensions,
49            final_memory,
50            touched_nodes: &self.touched_nodes,
51            trace_rows: &mut rows,
52        };
53        let final_tree = tree_helper.recur(
54            self.air.memory_dimensions.overall_height(),
55            initial_tree,
56            0,
57            0,
58            hasher,
59        );
60        self.final_state = Some(FinalState {
61            rows,
62            init_root: initial_tree.hash(),
63            final_root: final_tree.hash(),
64        });
65    }
66}
67
68impl<const CHUNK: usize, SC: StarkGenericConfig> Chip<SC> for MemoryMerkleChip<CHUNK, Val<SC>>
69where
70    Val<SC>: PrimeField32,
71{
72    fn air(&self) -> AirRef<SC> {
73        Arc::new(self.air.clone())
74    }
75
76    fn generate_air_proof_input(self) -> AirProofInput<SC> {
77        assert!(
78            self.final_state.is_some(),
79            "Merkle chip must finalize before trace generation"
80        );
81        let FinalState {
82            mut rows,
83            init_root,
84            final_root,
85        } = self.final_state.unwrap();
86        // important that this sort be stable,
87        // because we need the initial root to be first and the final root to be second
88        rows.sort_by_key(|row| Reverse(row.parent_height));
89
90        let width = MemoryMerkleCols::<Val<SC>, CHUNK>::width();
91        let mut height = rows.len().next_power_of_two();
92        if let Some(mut oh) = self.overridden_height {
93            oh = oh.next_power_of_two();
94            assert!(
95                oh >= height,
96                "Overridden height {oh} is less than the required height {height}"
97            );
98            height = oh;
99        }
100        let mut trace = Val::<SC>::zero_vec(width * height);
101
102        for (trace_row, row) in trace.chunks_exact_mut(width).zip(rows) {
103            *trace_row.borrow_mut() = row;
104        }
105
106        let trace = RowMajorMatrix::new(trace, width);
107        let pvs = init_root.into_iter().chain(final_root).collect();
108        AirProofInput::simple(trace, pvs)
109    }
110}
111impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for MemoryMerkleChip<CHUNK, F> {
112    fn air_name(&self) -> String {
113        "Merkle".to_string()
114    }
115
116    fn current_trace_height(&self) -> usize {
117        2 * self.num_touched_nonleaves
118    }
119
120    fn trace_width(&self) -> usize {
121        MemoryMerkleCols::<F, CHUNK>::width()
122    }
123}
124
125struct TreeHelper<'a, const CHUNK: usize, F: PrimeField32> {
126    memory_dimensions: MemoryDimensions,
127    final_memory: &'a Equipartition<F, CHUNK>,
128    touched_nodes: &'a FxHashSet<(usize, u32, u32)>,
129    trace_rows: &'a mut Vec<MemoryMerkleCols<F, CHUNK>>,
130}
131
132impl<const CHUNK: usize, F: PrimeField32> TreeHelper<'_, CHUNK, F> {
133    fn recur(
134        &mut self,
135        height: usize,
136        initial_node: &MemoryNode<CHUNK, F>,
137        as_label: u32,
138        address_label: u32,
139        hasher: &mut impl HasherChip<CHUNK, F>,
140    ) -> MemoryNode<CHUNK, F> {
141        if height == 0 {
142            let address_space = as_label + self.memory_dimensions.as_offset;
143            let leaf_values = *self
144                .final_memory
145                .get(&(address_space, address_label))
146                .unwrap_or(&[F::ZERO; CHUNK]);
147            MemoryNode::new_leaf(hasher.hash(&leaf_values))
148        } else if let NonLeaf {
149            left: initial_left_node,
150            right: initial_right_node,
151            ..
152        } = initial_node.clone()
153        {
154            // Tell the hasher about this hash.
155            hasher.compress_and_record(&initial_left_node.hash(), &initial_right_node.hash());
156
157            let is_as_section = height > self.memory_dimensions.address_height;
158
159            let (left_as_label, right_as_label) = if is_as_section {
160                (2 * as_label, 2 * as_label + 1)
161            } else {
162                (as_label, as_label)
163            };
164            let (left_address_label, right_address_label) = if is_as_section {
165                (address_label, address_label)
166            } else {
167                (2 * address_label, 2 * address_label + 1)
168            };
169
170            let left_is_final =
171                !self
172                    .touched_nodes
173                    .contains(&(height - 1, left_as_label, left_address_label));
174
175            let final_left_node = if left_is_final {
176                initial_left_node
177            } else {
178                Arc::new(self.recur(
179                    height - 1,
180                    &initial_left_node,
181                    left_as_label,
182                    left_address_label,
183                    hasher,
184                ))
185            };
186
187            let right_is_final =
188                !self
189                    .touched_nodes
190                    .contains(&(height - 1, right_as_label, right_address_label));
191
192            let final_right_node = if right_is_final {
193                initial_right_node
194            } else {
195                Arc::new(self.recur(
196                    height - 1,
197                    &initial_right_node,
198                    right_as_label,
199                    right_address_label,
200                    hasher,
201                ))
202            };
203
204            let final_node = MemoryNode::new_nonleaf(final_left_node, final_right_node, hasher);
205            self.add_trace_row(height, as_label, address_label, initial_node, None);
206            self.add_trace_row(
207                height,
208                as_label,
209                address_label,
210                &final_node,
211                Some([left_is_final, right_is_final]),
212            );
213            final_node
214        } else {
215            panic!("Leaf {:?} found at nonzero height {}", initial_node, height);
216        }
217    }
218
219    /// Expects `node` to be NonLeaf
220    fn add_trace_row(
221        &mut self,
222        parent_height: usize,
223        as_label: u32,
224        address_label: u32,
225        node: &MemoryNode<CHUNK, F>,
226        direction_changes: Option<[bool; 2]>,
227    ) {
228        let [left_direction_change, right_direction_change] =
229            direction_changes.unwrap_or([false; 2]);
230        let cols = if let NonLeaf { hash, left, right } = node {
231            MemoryMerkleCols {
232                expand_direction: if direction_changes.is_none() {
233                    F::ONE
234                } else {
235                    F::NEG_ONE
236                },
237                height_section: F::from_bool(parent_height > self.memory_dimensions.address_height),
238                parent_height: F::from_canonical_usize(parent_height),
239                is_root: F::from_bool(parent_height == self.memory_dimensions.overall_height()),
240                parent_as_label: F::from_canonical_u32(as_label),
241                parent_address_label: F::from_canonical_u32(address_label),
242                parent_hash: *hash,
243                left_child_hash: left.hash(),
244                right_child_hash: right.hash(),
245                left_direction_different: F::from_bool(left_direction_change),
246                right_direction_different: F::from_bool(right_direction_change),
247            }
248        } else {
249            panic!("trace_rows expects node = {:?} to be NonLeaf", node);
250        };
251        self.trace_rows.push(cols);
252    }
253}
254
/// A receiver that is handed messages of type `T`, one per call to
/// [`receive`](SerialReceiver::receive).
pub trait SerialReceiver<T> {
    /// Accepts a single message, mutating the receiver as needed.
    fn receive(&mut self, msg: T);
}
258
259impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]>
260    for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
261{
262    /// Receives a permutation preimage, pads with zeros to the permutation width, and records.
263    /// The permutation preimage must have length at most the permutation width (panics otherwise).
264    fn receive(&mut self, perm_preimage: &'a [F]) {
265        assert!(perm_preimage.len() <= PERIPHERY_POSEIDON2_WIDTH);
266        let mut state = [F::ZERO; PERIPHERY_POSEIDON2_WIDTH];
267        state[..perm_preimage.len()].copy_from_slice(perm_preimage);
268        let count = self.records.entry(state).or_insert(AtomicU32::new(0));
269        count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270    }
271}
272
273impl<'a, F: PrimeField32> SerialReceiver<&'a [F]> for Poseidon2PeripheryChip<F> {
274    fn receive(&mut self, perm_preimage: &'a [F]) {
275        match self {
276            Poseidon2PeripheryChip::Register0(chip) => chip.receive(perm_preimage),
277            Poseidon2PeripheryChip::Register1(chip) => chip.receive(perm_preimage),
278        }
279    }
280}