openvm_circuit/system/memory/tree/
mod.rs
1pub mod public_values;
2
3use std::{ops::Range, sync::Arc};
4
5use openvm_stark_backend::{p3_field::PrimeField32, p3_maybe_rayon::prelude::*};
6use MemoryNode::*;
7
8use super::controller::dimensions::MemoryDimensions;
9use crate::{
10 arch::hasher::{Hasher, HasherChip},
11 system::memory::MemoryImage,
12};
13
14#[derive(Clone, Debug, PartialEq)]
15pub enum MemoryNode<const CHUNK: usize, F: PrimeField32> {
16 Leaf {
17 values: [F; CHUNK],
18 },
19 NonLeaf {
20 hash: [F; CHUNK],
21 left: Arc<MemoryNode<CHUNK, F>>,
22 right: Arc<MemoryNode<CHUNK, F>>,
23 },
24}
25
26impl<const CHUNK: usize, F: PrimeField32> MemoryNode<CHUNK, F> {
27 pub fn hash(&self) -> [F; CHUNK] {
28 match self {
29 Leaf { values: hash } => *hash,
30 NonLeaf { hash, .. } => *hash,
31 }
32 }
33
34 pub fn new_leaf(values: [F; CHUNK]) -> Self {
35 Leaf { values }
36 }
37
38 pub fn new_nonleaf(
39 left: Arc<MemoryNode<CHUNK, F>>,
40 right: Arc<MemoryNode<CHUNK, F>>,
41 hasher: &mut impl HasherChip<CHUNK, F>,
42 ) -> Self {
43 NonLeaf {
44 hash: hasher.compress_and_record(&left.hash(), &right.hash()),
45 left,
46 right,
47 }
48 }
49
50 pub fn construct_uniform(
52 height: usize,
53 leaf_value: [F; CHUNK],
54 hasher: &impl Hasher<CHUNK, F>,
55 ) -> MemoryNode<CHUNK, F> {
56 if height == 0 {
57 Self::new_leaf(leaf_value)
58 } else {
59 let child = Arc::new(Self::construct_uniform(height - 1, leaf_value, hasher));
60 NonLeaf {
61 hash: hasher.compress(&child.hash(), &child.hash()),
62 left: child.clone(),
63 right: child,
64 }
65 }
66 }
67
68 fn from_memory(
69 memory: &[(u64, F)],
70 lookup_range: Range<usize>,
71 length: u64,
72 from: u64,
73 hasher: &(impl Hasher<CHUNK, F> + Sync),
74 zero_leaf: &MemoryNode<CHUNK, F>,
75 ) -> MemoryNode<CHUNK, F> {
76 if length == CHUNK as u64 {
77 if lookup_range.is_empty() {
78 zero_leaf.clone()
79 } else {
80 debug_assert_eq!(memory[lookup_range.start].0, from);
81 let mut values = [F::ZERO; CHUNK];
82 for (index, value) in memory[lookup_range].iter() {
83 values[(index % CHUNK as u64) as usize] = *value;
84 }
85 MemoryNode::new_leaf(hasher.hash(&values))
86 }
87 } else if lookup_range.is_empty() {
88 let leaf_value = hasher.hash(&[F::ZERO; CHUNK]);
89 MemoryNode::construct_uniform(
90 (length / CHUNK as u64).trailing_zeros() as usize,
91 leaf_value,
92 hasher,
93 )
94 } else {
95 let midpoint = from + length / 2;
96 let mid = {
97 let mut left = lookup_range.start;
98 let mut right = lookup_range.end;
99 if memory[left].0 >= midpoint {
100 left
101 } else {
102 while left + 1 < right {
103 let mid = left + (right - left) / 2;
104 if memory[mid].0 < midpoint {
105 left = mid;
106 } else {
107 right = mid;
108 }
109 }
110 right
111 }
112 };
113 let (left, right) = join(
114 || {
115 Self::from_memory(
116 memory,
117 lookup_range.start..mid,
118 length >> 1,
119 from,
120 hasher,
121 zero_leaf,
122 )
123 },
124 || {
125 Self::from_memory(
126 memory,
127 mid..lookup_range.end,
128 length >> 1,
129 midpoint,
130 hasher,
131 zero_leaf,
132 )
133 },
134 );
135 NonLeaf {
136 hash: hasher.compress(&left.hash(), &right.hash()),
137 left: Arc::new(left),
138 right: Arc::new(right),
139 }
140 }
141 }
142
143 pub fn tree_from_memory(
144 memory_dimensions: MemoryDimensions,
145 memory: &MemoryImage<F>,
146 hasher: &(impl Hasher<CHUNK, F> + Sync),
147 ) -> MemoryNode<CHUNK, F> {
148 let memory_items = memory
151 .items()
152 .filter(|((_, ptr), _)| *ptr as usize / CHUNK < (1 << memory_dimensions.address_height))
153 .map(|((address_space, pointer), value)| {
154 (
155 memory_dimensions.label_to_index((address_space, pointer / CHUNK as u32))
156 * CHUNK as u64
157 + (pointer % CHUNK as u32) as u64,
158 value,
159 )
160 })
161 .collect::<Vec<_>>();
162 debug_assert!(memory_items.is_sorted_by_key(|(addr, _)| addr));
163 debug_assert!(
164 memory_items.last().map_or(0, |(addr, _)| *addr)
165 < ((CHUNK as u64) << memory_dimensions.overall_height())
166 );
167 let zero_leaf = MemoryNode::new_leaf(hasher.hash(&[F::ZERO; CHUNK]));
168 Self::from_memory(
169 &memory_items,
170 0..memory_items.len(),
171 (CHUNK as u64) << memory_dimensions.overall_height(),
172 0,
173 hasher,
174 &zero_leaf,
175 )
176 }
177}