openvm_circuit/system/memory/merkle/
tree.rs

1use openvm_stark_backend::{
2    p3_field::PrimeField32,
3    p3_maybe_rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator},
4};
5use rustc_hash::FxHashMap;
6
7use super::{FinalState, MemoryMerkleCols};
8use crate::{
9    arch::hasher::{Hasher, HasherChip},
10    system::memory::{
11        dimensions::MemoryDimensions, merkle::memory_to_vec_partition, AddressMap, Equipartition,
12    },
13};
14
15#[derive(Debug)]
16pub struct MerkleTree<F, const CHUNK: usize> {
17    /// Height of the tree -- the root is the only node at height `height`,
18    /// and the leaves are at height `0`.
19    height: usize,
20    /// Nodes corresponding to all zeroes.
21    zero_nodes: Vec<[F; CHUNK]>,
22    /// Nodes in the tree that have ever been touched.
23    nodes: FxHashMap<u64, [F; CHUNK]>,
24}
25
26impl<F: PrimeField32, const CHUNK: usize> MerkleTree<F, CHUNK> {
27    pub fn new(height: usize, hasher: &impl Hasher<CHUNK, F>) -> Self {
28        Self {
29            height,
30            zero_nodes: (0..height + 1)
31                .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| {
32                    let result = Some(*acc);
33                    *acc = hasher.compress(acc, acc);
34                    result
35                })
36                .collect(),
37            nodes: FxHashMap::default(),
38        }
39    }
40
41    pub fn root(&self) -> [F; CHUNK] {
42        self.get_node(1)
43    }
44
45    pub fn get_node(&self, index: u64) -> [F; CHUNK] {
46        self.nodes
47            .get(&index)
48            .cloned()
49            .unwrap_or(self.zero_nodes[self.height - index.ilog2() as usize])
50    }
51
52    #[allow(clippy::type_complexity)]
53    /// Shared logic for both from_memory and finalize.
54    fn process_layers<CompressFn>(
55        &mut self,
56        layer: Vec<(u64, [F; CHUNK])>,
57        md: &MemoryDimensions,
58        mut rows: Option<&mut Vec<MemoryMerkleCols<F, CHUNK>>>,
59        compress: CompressFn,
60    ) where
61        CompressFn: Fn(&[F; CHUNK], &[F; CHUNK]) -> [F; CHUNK] + Send + Sync,
62    {
63        let mut new_entries = layer;
64        let mut layer = new_entries
65            .par_iter()
66            .map(|(index, values)| {
67                let old_values = self.nodes.get(index).unwrap_or(&self.zero_nodes[0]);
68                (*index, *values, *old_values)
69            })
70            .collect::<Vec<_>>();
71        for height in 1..=self.height {
72            let new_layer = layer
73                .iter()
74                .enumerate()
75                .filter_map(|(i, (index, values, old_values))| {
76                    if i > 0 && layer[i - 1].0 ^ 1 == *index {
77                        return None;
78                    }
79
80                    let par_index = index >> 1;
81
82                    if i + 1 < layer.len() && layer[i + 1].0 == index ^ 1 {
83                        let (_, sibling_values, sibling_old_values) = &layer[i + 1];
84                        Some((
85                            par_index,
86                            Some((values, old_values)),
87                            Some((sibling_values, sibling_old_values)),
88                        ))
89                    } else if index & 1 == 0 {
90                        Some((par_index, Some((values, old_values)), None))
91                    } else {
92                        Some((par_index, None, Some((values, old_values))))
93                    }
94                })
95                .collect::<Vec<_>>();
96
97            match rows {
98                None => {
99                    layer = new_layer
100                        .into_par_iter()
101                        .map(|(par_index, left, right)| {
102                            let left = if let Some(left) = left {
103                                left.0
104                            } else {
105                                &self.get_node(2 * par_index)
106                            };
107                            let right = if let Some(right) = right {
108                                right.0
109                            } else {
110                                &self.get_node(2 * par_index + 1)
111                            };
112                            let combined = compress(left, right);
113                            let par_old_values = self.get_node(par_index);
114                            (par_index, combined, par_old_values)
115                        })
116                        .collect();
117                }
118                Some(ref mut rows) => {
119                    let label_section_height = md.address_height.saturating_sub(height);
120                    let (tmp, new_rows): (Vec<(u64, [F; CHUNK], [F; CHUNK])>, Vec<[_; 2]>) =
121                        new_layer
122                            .into_par_iter()
123                            .map(|(par_index, left, right)| {
124                                let parent_address_label =
125                                    (par_index & ((1 << label_section_height) - 1)) as u32;
126                                let parent_as_label = ((par_index & !(1 << (self.height - height)))
127                                    >> label_section_height)
128                                    as u32;
129                                let left_node;
130                                let (left, old_left, changed_left) = match left {
131                                    Some((left, old_left)) => (left, old_left, true),
132                                    None => {
133                                        left_node = self.get_node(2 * par_index);
134                                        (&left_node, &left_node, false)
135                                    }
136                                };
137                                let right_node;
138                                let (right, old_right, changed_right) = match right {
139                                    Some((right, old_right)) => (right, old_right, true),
140                                    None => {
141                                        right_node = self.get_node(2 * par_index + 1);
142                                        (&right_node, &right_node, false)
143                                    }
144                                };
145                                let combined = compress(left, right);
146                                // This is a hacky way to say:
147                                // "and we also want to record the old values"
148                                compress(old_left, old_right);
149                                let par_old_values = self.get_node(par_index);
150                                (
151                                    (par_index, combined, par_old_values),
152                                    [
153                                        MemoryMerkleCols {
154                                            expand_direction: F::ONE,
155                                            height_section: F::from_bool(
156                                                height > md.address_height,
157                                            ),
158                                            parent_height: F::from_usize(height),
159                                            is_root: F::from_bool(height == md.overall_height()),
160                                            parent_as_label: F::from_u32(parent_as_label),
161                                            parent_address_label: F::from_u32(parent_address_label),
162                                            parent_hash: par_old_values,
163                                            left_child_hash: *old_left,
164                                            right_child_hash: *old_right,
165                                            left_direction_different: F::ZERO,
166                                            right_direction_different: F::ZERO,
167                                        },
168                                        MemoryMerkleCols {
169                                            expand_direction: F::NEG_ONE,
170                                            height_section: F::from_bool(
171                                                height > md.address_height,
172                                            ),
173                                            parent_height: F::from_usize(height),
174                                            is_root: F::from_bool(height == md.overall_height()),
175                                            parent_as_label: F::from_u32(parent_as_label),
176                                            parent_address_label: F::from_u32(parent_address_label),
177                                            parent_hash: combined,
178                                            left_child_hash: *left,
179                                            right_child_hash: *right,
180                                            left_direction_different: F::from_bool(!changed_left),
181                                            right_direction_different: F::from_bool(!changed_right),
182                                        },
183                                    ],
184                                )
185                            })
186                            .unzip();
187                    rows.extend(new_rows.into_iter().flatten());
188                    layer = tmp;
189                }
190            }
191            new_entries.extend(layer.iter().map(|(idx, values, _)| (*idx, *values)));
192        }
193
194        if self.nodes.is_empty() {
195            // This, for example, should happen in every `from_memory` call
196            self.nodes = FxHashMap::from_iter(new_entries);
197        } else {
198            self.nodes.extend(new_entries);
199        }
200    }
201
202    pub fn from_memory(
203        memory: &AddressMap,
204        md: &MemoryDimensions,
205        hasher: &(impl Hasher<CHUNK, F> + Sync),
206    ) -> Self {
207        let mut tree = Self::new(md.overall_height(), hasher);
208        let layer: Vec<_> = memory_to_vec_partition(memory, md)
209            .par_iter()
210            .map(|(idx, v)| ((1 << tree.height) + idx, hasher.hash(v)))
211            .collect();
212        tree.process_layers(layer, md, None, |left, right| hasher.compress(left, right));
213        tree
214    }
215
216    pub fn finalize(
217        &mut self,
218        hasher: &impl HasherChip<CHUNK, F>,
219        touched: &Equipartition<F, CHUNK>,
220        md: &MemoryDimensions,
221    ) -> FinalState<CHUNK, F> {
222        let init_root = self.get_node(1);
223        let layer: Vec<_> = if !touched.is_empty() {
224            touched
225                .iter()
226                .map(|((addr_sp, ptr), v)| {
227                    (
228                        (1 << self.height) + md.label_to_index((*addr_sp, *ptr / CHUNK as u32)),
229                        hasher.hash(v),
230                    )
231                })
232                .collect()
233        } else {
234            let index = 1 << self.height;
235            vec![(index, self.get_node(index))]
236        };
237        let mut rows = Vec::with_capacity(if layer.is_empty() {
238            0
239        } else {
240            layer
241                .iter()
242                .zip(layer.iter().skip(1))
243                .fold(md.overall_height(), |acc, ((lhs, _), (rhs, _))| {
244                    acc + (lhs ^ rhs).ilog2() as usize
245                })
246        });
247        self.process_layers(layer, md, Some(&mut rows), |left, right| {
248            hasher.compress_and_record(left, right)
249        });
250        if touched.is_empty() {
251            // If we made an artificial touch, we need to change the direction changes for the
252            // leaves
253            rows[1].left_direction_different = F::ONE;
254            rows[1].right_direction_different = F::ONE;
255        }
256        let final_root = self.get_node(1);
257        FinalState {
258            rows,
259            init_root,
260            final_root,
261        }
262    }
263
264    pub fn top_tree(&self, top_height: usize) -> Vec<[F; CHUNK]> {
265        // tree root is at index 1
266        (0..(2 << top_height) - 1)
267            .map(|i| self.get_node(i + 1))
268            .collect()
269    }
270}