openvm_circuit/system/memory/merkle/
tree.rs1use openvm_stark_backend::{
2 p3_field::PrimeField32,
3 p3_maybe_rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator},
4};
5use rustc_hash::FxHashMap;
6
7use super::{FinalState, MemoryMerkleCols};
8use crate::{
9 arch::hasher::{Hasher, HasherChip},
10 system::memory::{
11 dimensions::MemoryDimensions, merkle::memory_to_vec_partition, AddressMap, Equipartition,
12 },
13};
14
15#[derive(Debug)]
16pub struct MerkleTree<F, const CHUNK: usize> {
17 height: usize,
20 zero_nodes: Vec<[F; CHUNK]>,
22 nodes: FxHashMap<u64, [F; CHUNK]>,
24}
25
26impl<F: PrimeField32, const CHUNK: usize> MerkleTree<F, CHUNK> {
27 pub fn new(height: usize, hasher: &impl Hasher<CHUNK, F>) -> Self {
28 Self {
29 height,
30 zero_nodes: (0..height + 1)
31 .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| {
32 let result = Some(*acc);
33 *acc = hasher.compress(acc, acc);
34 result
35 })
36 .collect(),
37 nodes: FxHashMap::default(),
38 }
39 }
40
41 pub fn root(&self) -> [F; CHUNK] {
42 self.get_node(1)
43 }
44
45 pub fn get_node(&self, index: u64) -> [F; CHUNK] {
46 self.nodes
47 .get(&index)
48 .cloned()
49 .unwrap_or(self.zero_nodes[self.height - index.ilog2() as usize])
50 }
51
52 #[allow(clippy::type_complexity)]
53 fn process_layers<CompressFn>(
55 &mut self,
56 layer: Vec<(u64, [F; CHUNK])>,
57 md: &MemoryDimensions,
58 mut rows: Option<&mut Vec<MemoryMerkleCols<F, CHUNK>>>,
59 compress: CompressFn,
60 ) where
61 CompressFn: Fn(&[F; CHUNK], &[F; CHUNK]) -> [F; CHUNK] + Send + Sync,
62 {
63 let mut new_entries = layer;
64 let mut layer = new_entries
65 .par_iter()
66 .map(|(index, values)| {
67 let old_values = self.nodes.get(index).unwrap_or(&self.zero_nodes[0]);
68 (*index, *values, *old_values)
69 })
70 .collect::<Vec<_>>();
71 for height in 1..=self.height {
72 let new_layer = layer
73 .iter()
74 .enumerate()
75 .filter_map(|(i, (index, values, old_values))| {
76 if i > 0 && layer[i - 1].0 ^ 1 == *index {
77 return None;
78 }
79
80 let par_index = index >> 1;
81
82 if i + 1 < layer.len() && layer[i + 1].0 == index ^ 1 {
83 let (_, sibling_values, sibling_old_values) = &layer[i + 1];
84 Some((
85 par_index,
86 Some((values, old_values)),
87 Some((sibling_values, sibling_old_values)),
88 ))
89 } else if index & 1 == 0 {
90 Some((par_index, Some((values, old_values)), None))
91 } else {
92 Some((par_index, None, Some((values, old_values))))
93 }
94 })
95 .collect::<Vec<_>>();
96
97 match rows {
98 None => {
99 layer = new_layer
100 .into_par_iter()
101 .map(|(par_index, left, right)| {
102 let left = if let Some(left) = left {
103 left.0
104 } else {
105 &self.get_node(2 * par_index)
106 };
107 let right = if let Some(right) = right {
108 right.0
109 } else {
110 &self.get_node(2 * par_index + 1)
111 };
112 let combined = compress(left, right);
113 let par_old_values = self.get_node(par_index);
114 (par_index, combined, par_old_values)
115 })
116 .collect();
117 }
118 Some(ref mut rows) => {
119 let label_section_height = md.address_height.saturating_sub(height);
120 let (tmp, new_rows): (Vec<(u64, [F; CHUNK], [F; CHUNK])>, Vec<[_; 2]>) =
121 new_layer
122 .into_par_iter()
123 .map(|(par_index, left, right)| {
124 let parent_address_label =
125 (par_index & ((1 << label_section_height) - 1)) as u32;
126 let parent_as_label = ((par_index & !(1 << (self.height - height)))
127 >> label_section_height)
128 as u32;
129 let left_node;
130 let (left, old_left, changed_left) = match left {
131 Some((left, old_left)) => (left, old_left, true),
132 None => {
133 left_node = self.get_node(2 * par_index);
134 (&left_node, &left_node, false)
135 }
136 };
137 let right_node;
138 let (right, old_right, changed_right) = match right {
139 Some((right, old_right)) => (right, old_right, true),
140 None => {
141 right_node = self.get_node(2 * par_index + 1);
142 (&right_node, &right_node, false)
143 }
144 };
145 let combined = compress(left, right);
146 compress(old_left, old_right);
149 let par_old_values = self.get_node(par_index);
150 (
151 (par_index, combined, par_old_values),
152 [
153 MemoryMerkleCols {
154 expand_direction: F::ONE,
155 height_section: F::from_bool(
156 height > md.address_height,
157 ),
158 parent_height: F::from_canonical_usize(height),
159 is_root: F::from_bool(height == md.overall_height()),
160 parent_as_label: F::from_canonical_u32(parent_as_label),
161 parent_address_label: F::from_canonical_u32(
162 parent_address_label,
163 ),
164 parent_hash: par_old_values,
165 left_child_hash: *old_left,
166 right_child_hash: *old_right,
167 left_direction_different: F::ZERO,
168 right_direction_different: F::ZERO,
169 },
170 MemoryMerkleCols {
171 expand_direction: F::NEG_ONE,
172 height_section: F::from_bool(
173 height > md.address_height,
174 ),
175 parent_height: F::from_canonical_usize(height),
176 is_root: F::from_bool(height == md.overall_height()),
177 parent_as_label: F::from_canonical_u32(parent_as_label),
178 parent_address_label: F::from_canonical_u32(
179 parent_address_label,
180 ),
181 parent_hash: combined,
182 left_child_hash: *left,
183 right_child_hash: *right,
184 left_direction_different: F::from_bool(!changed_left),
185 right_direction_different: F::from_bool(!changed_right),
186 },
187 ],
188 )
189 })
190 .unzip();
191 rows.extend(new_rows.into_iter().flatten());
192 layer = tmp;
193 }
194 }
195 new_entries.extend(layer.iter().map(|(idx, values, _)| (*idx, *values)));
196 }
197
198 if self.nodes.is_empty() {
199 self.nodes = FxHashMap::from_iter(new_entries);
201 } else {
202 self.nodes.extend(new_entries);
203 }
204 }
205
206 pub fn from_memory(
207 memory: &AddressMap,
208 md: &MemoryDimensions,
209 hasher: &(impl Hasher<CHUNK, F> + Sync),
210 ) -> Self {
211 let mut tree = Self::new(md.overall_height(), hasher);
212 let layer: Vec<_> = memory_to_vec_partition(memory, md)
213 .par_iter()
214 .map(|(idx, v)| ((1 << tree.height) + idx, hasher.hash(v)))
215 .collect();
216 tree.process_layers(layer, md, None, |left, right| hasher.compress(left, right));
217 tree
218 }
219
220 pub fn finalize(
221 &mut self,
222 hasher: &impl HasherChip<CHUNK, F>,
223 touched: &Equipartition<F, CHUNK>,
224 md: &MemoryDimensions,
225 ) -> FinalState<CHUNK, F> {
226 let init_root = self.get_node(1);
227 let layer: Vec<_> = if !touched.is_empty() {
228 touched
229 .iter()
230 .map(|((addr_sp, ptr), v)| {
231 (
232 (1 << self.height) + md.label_to_index((*addr_sp, *ptr / CHUNK as u32)),
233 hasher.hash(v),
234 )
235 })
236 .collect()
237 } else {
238 let index = 1 << self.height;
239 vec![(index, self.get_node(index))]
240 };
241 let mut rows = Vec::with_capacity(if layer.is_empty() {
242 0
243 } else {
244 layer
245 .iter()
246 .zip(layer.iter().skip(1))
247 .fold(md.overall_height(), |acc, ((lhs, _), (rhs, _))| {
248 acc + (lhs ^ rhs).ilog2() as usize
249 })
250 });
251 self.process_layers(layer, md, Some(&mut rows), |left, right| {
252 hasher.compress_and_record(left, right)
253 });
254 if touched.is_empty() {
255 rows[1].left_direction_different = F::ONE;
258 rows[1].right_direction_different = F::ONE;
259 }
260 let final_root = self.get_node(1);
261 FinalState {
262 rows,
263 init_root,
264 final_root,
265 }
266 }
267}