openvm_circuit/system/memory/merkle/
tree.rs

1use openvm_stark_backend::{
2    p3_field::PrimeField32,
3    p3_maybe_rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator},
4};
5use rustc_hash::FxHashMap;
6
7use super::{FinalState, MemoryMerkleCols};
8use crate::{
9    arch::hasher::{Hasher, HasherChip},
10    system::memory::{
11        dimensions::MemoryDimensions, merkle::memory_to_vec_partition, AddressMap, Equipartition,
12    },
13};
14
15#[derive(Debug)]
16pub struct MerkleTree<F, const CHUNK: usize> {
17    /// Height of the tree -- the root is the only node at height `height`,
18    /// and the leaves are at height `0`.
19    height: usize,
20    /// Nodes corresponding to all zeroes.
21    zero_nodes: Vec<[F; CHUNK]>,
22    /// Nodes in the tree that have ever been touched.
23    nodes: FxHashMap<u64, [F; CHUNK]>,
24}
25
26impl<F: PrimeField32, const CHUNK: usize> MerkleTree<F, CHUNK> {
27    pub fn new(height: usize, hasher: &impl Hasher<CHUNK, F>) -> Self {
28        Self {
29            height,
30            zero_nodes: (0..height + 1)
31                .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| {
32                    let result = Some(*acc);
33                    *acc = hasher.compress(acc, acc);
34                    result
35                })
36                .collect(),
37            nodes: FxHashMap::default(),
38        }
39    }
40
41    pub fn root(&self) -> [F; CHUNK] {
42        self.get_node(1)
43    }
44
45    pub fn get_node(&self, index: u64) -> [F; CHUNK] {
46        self.nodes
47            .get(&index)
48            .cloned()
49            .unwrap_or(self.zero_nodes[self.height - index.ilog2() as usize])
50    }
51
52    #[allow(clippy::type_complexity)]
53    /// Shared logic for both from_memory and finalize.
54    fn process_layers<CompressFn>(
55        &mut self,
56        layer: Vec<(u64, [F; CHUNK])>,
57        md: &MemoryDimensions,
58        mut rows: Option<&mut Vec<MemoryMerkleCols<F, CHUNK>>>,
59        compress: CompressFn,
60    ) where
61        CompressFn: Fn(&[F; CHUNK], &[F; CHUNK]) -> [F; CHUNK] + Send + Sync,
62    {
63        let mut new_entries = layer;
64        let mut layer = new_entries
65            .par_iter()
66            .map(|(index, values)| {
67                let old_values = self.nodes.get(index).unwrap_or(&self.zero_nodes[0]);
68                (*index, *values, *old_values)
69            })
70            .collect::<Vec<_>>();
71        for height in 1..=self.height {
72            let new_layer = layer
73                .iter()
74                .enumerate()
75                .filter_map(|(i, (index, values, old_values))| {
76                    if i > 0 && layer[i - 1].0 ^ 1 == *index {
77                        return None;
78                    }
79
80                    let par_index = index >> 1;
81
82                    if i + 1 < layer.len() && layer[i + 1].0 == index ^ 1 {
83                        let (_, sibling_values, sibling_old_values) = &layer[i + 1];
84                        Some((
85                            par_index,
86                            Some((values, old_values)),
87                            Some((sibling_values, sibling_old_values)),
88                        ))
89                    } else if index & 1 == 0 {
90                        Some((par_index, Some((values, old_values)), None))
91                    } else {
92                        Some((par_index, None, Some((values, old_values))))
93                    }
94                })
95                .collect::<Vec<_>>();
96
97            match rows {
98                None => {
99                    layer = new_layer
100                        .into_par_iter()
101                        .map(|(par_index, left, right)| {
102                            let left = if let Some(left) = left {
103                                left.0
104                            } else {
105                                &self.get_node(2 * par_index)
106                            };
107                            let right = if let Some(right) = right {
108                                right.0
109                            } else {
110                                &self.get_node(2 * par_index + 1)
111                            };
112                            let combined = compress(left, right);
113                            let par_old_values = self.get_node(par_index);
114                            (par_index, combined, par_old_values)
115                        })
116                        .collect();
117                }
118                Some(ref mut rows) => {
119                    let label_section_height = md.address_height.saturating_sub(height);
120                    let (tmp, new_rows): (Vec<(u64, [F; CHUNK], [F; CHUNK])>, Vec<[_; 2]>) =
121                        new_layer
122                            .into_par_iter()
123                            .map(|(par_index, left, right)| {
124                                let parent_address_label =
125                                    (par_index & ((1 << label_section_height) - 1)) as u32;
126                                let parent_as_label = ((par_index & !(1 << (self.height - height)))
127                                    >> label_section_height)
128                                    as u32;
129                                let left_node;
130                                let (left, old_left, changed_left) = match left {
131                                    Some((left, old_left)) => (left, old_left, true),
132                                    None => {
133                                        left_node = self.get_node(2 * par_index);
134                                        (&left_node, &left_node, false)
135                                    }
136                                };
137                                let right_node;
138                                let (right, old_right, changed_right) = match right {
139                                    Some((right, old_right)) => (right, old_right, true),
140                                    None => {
141                                        right_node = self.get_node(2 * par_index + 1);
142                                        (&right_node, &right_node, false)
143                                    }
144                                };
145                                let combined = compress(left, right);
146                                // This is a hacky way to say:
147                                // "and we also want to record the old values"
148                                compress(old_left, old_right);
149                                let par_old_values = self.get_node(par_index);
150                                (
151                                    (par_index, combined, par_old_values),
152                                    [
153                                        MemoryMerkleCols {
154                                            expand_direction: F::ONE,
155                                            height_section: F::from_bool(
156                                                height > md.address_height,
157                                            ),
158                                            parent_height: F::from_canonical_usize(height),
159                                            is_root: F::from_bool(height == md.overall_height()),
160                                            parent_as_label: F::from_canonical_u32(parent_as_label),
161                                            parent_address_label: F::from_canonical_u32(
162                                                parent_address_label,
163                                            ),
164                                            parent_hash: par_old_values,
165                                            left_child_hash: *old_left,
166                                            right_child_hash: *old_right,
167                                            left_direction_different: F::ZERO,
168                                            right_direction_different: F::ZERO,
169                                        },
170                                        MemoryMerkleCols {
171                                            expand_direction: F::NEG_ONE,
172                                            height_section: F::from_bool(
173                                                height > md.address_height,
174                                            ),
175                                            parent_height: F::from_canonical_usize(height),
176                                            is_root: F::from_bool(height == md.overall_height()),
177                                            parent_as_label: F::from_canonical_u32(parent_as_label),
178                                            parent_address_label: F::from_canonical_u32(
179                                                parent_address_label,
180                                            ),
181                                            parent_hash: combined,
182                                            left_child_hash: *left,
183                                            right_child_hash: *right,
184                                            left_direction_different: F::from_bool(!changed_left),
185                                            right_direction_different: F::from_bool(!changed_right),
186                                        },
187                                    ],
188                                )
189                            })
190                            .unzip();
191                    rows.extend(new_rows.into_iter().flatten());
192                    layer = tmp;
193                }
194            }
195            new_entries.extend(layer.iter().map(|(idx, values, _)| (*idx, *values)));
196        }
197
198        if self.nodes.is_empty() {
199            // This, for example, should happen in every `from_memory` call
200            self.nodes = FxHashMap::from_iter(new_entries);
201        } else {
202            self.nodes.extend(new_entries);
203        }
204    }
205
206    pub fn from_memory(
207        memory: &AddressMap,
208        md: &MemoryDimensions,
209        hasher: &(impl Hasher<CHUNK, F> + Sync),
210    ) -> Self {
211        let mut tree = Self::new(md.overall_height(), hasher);
212        let layer: Vec<_> = memory_to_vec_partition(memory, md)
213            .par_iter()
214            .map(|(idx, v)| ((1 << tree.height) + idx, hasher.hash(v)))
215            .collect();
216        tree.process_layers(layer, md, None, |left, right| hasher.compress(left, right));
217        tree
218    }
219
220    pub fn finalize(
221        &mut self,
222        hasher: &impl HasherChip<CHUNK, F>,
223        touched: &Equipartition<F, CHUNK>,
224        md: &MemoryDimensions,
225    ) -> FinalState<CHUNK, F> {
226        let init_root = self.get_node(1);
227        let layer: Vec<_> = if !touched.is_empty() {
228            touched
229                .iter()
230                .map(|((addr_sp, ptr), v)| {
231                    (
232                        (1 << self.height) + md.label_to_index((*addr_sp, *ptr / CHUNK as u32)),
233                        hasher.hash(v),
234                    )
235                })
236                .collect()
237        } else {
238            let index = 1 << self.height;
239            vec![(index, self.get_node(index))]
240        };
241        let mut rows = Vec::with_capacity(if layer.is_empty() {
242            0
243        } else {
244            layer
245                .iter()
246                .zip(layer.iter().skip(1))
247                .fold(md.overall_height(), |acc, ((lhs, _), (rhs, _))| {
248                    acc + (lhs ^ rhs).ilog2() as usize
249                })
250        });
251        self.process_layers(layer, md, Some(&mut rows), |left, right| {
252            hasher.compress_and_record(left, right)
253        });
254        if touched.is_empty() {
255            // If we made an artificial touch, we need to change the direction changes for the
256            // leaves
257            rows[1].left_direction_different = F::ONE;
258            rows[1].right_direction_different = F::ONE;
259        }
260        let final_root = self.get_node(1);
261        FinalState {
262            rows,
263            init_root,
264            final_root,
265        }
266    }
267}