openvm_circuit/system/memory/tree/mod.rs

1pub mod public_values;
2
3use std::{ops::Range, sync::Arc};
4
5use openvm_stark_backend::{p3_field::PrimeField32, p3_maybe_rayon::prelude::*};
6use MemoryNode::*;
7
8use super::controller::dimensions::MemoryDimensions;
9use crate::{
10    arch::hasher::{Hasher, HasherChip},
11    system::memory::MemoryImage,
12};
13
14#[derive(Clone, Debug, PartialEq)]
15pub enum MemoryNode<const CHUNK: usize, F: PrimeField32> {
16    Leaf {
17        values: [F; CHUNK],
18    },
19    NonLeaf {
20        hash: [F; CHUNK],
21        left: Arc<MemoryNode<CHUNK, F>>,
22        right: Arc<MemoryNode<CHUNK, F>>,
23    },
24}
25
26impl<const CHUNK: usize, F: PrimeField32> MemoryNode<CHUNK, F> {
27    pub fn hash(&self) -> [F; CHUNK] {
28        match self {
29            Leaf { values: hash } => *hash,
30            NonLeaf { hash, .. } => *hash,
31        }
32    }
33
34    pub fn new_leaf(values: [F; CHUNK]) -> Self {
35        Leaf { values }
36    }
37
38    pub fn new_nonleaf(
39        left: Arc<MemoryNode<CHUNK, F>>,
40        right: Arc<MemoryNode<CHUNK, F>>,
41        hasher: &mut impl HasherChip<CHUNK, F>,
42    ) -> Self {
43        NonLeaf {
44            hash: hasher.compress_and_record(&left.hash(), &right.hash()),
45            left,
46            right,
47        }
48    }
49
50    /// Returns a tree of height `height` with all leaves set to `leaf_value`.
51    pub fn construct_uniform(
52        height: usize,
53        leaf_value: [F; CHUNK],
54        hasher: &impl Hasher<CHUNK, F>,
55    ) -> MemoryNode<CHUNK, F> {
56        if height == 0 {
57            Self::new_leaf(leaf_value)
58        } else {
59            let child = Arc::new(Self::construct_uniform(height - 1, leaf_value, hasher));
60            NonLeaf {
61                hash: hasher.compress(&child.hash(), &child.hash()),
62                left: child.clone(),
63                right: child,
64            }
65        }
66    }
67
68    fn from_memory(
69        memory: &[(u64, F)],
70        lookup_range: Range<usize>,
71        length: u64,
72        from: u64,
73        hasher: &(impl Hasher<CHUNK, F> + Sync),
74        zero_leaf: &MemoryNode<CHUNK, F>,
75    ) -> MemoryNode<CHUNK, F> {
76        if length == CHUNK as u64 {
77            if lookup_range.is_empty() {
78                zero_leaf.clone()
79            } else {
80                debug_assert_eq!(memory[lookup_range.start].0, from);
81                let mut values = [F::ZERO; CHUNK];
82                for (index, value) in memory[lookup_range].iter() {
83                    values[(index % CHUNK as u64) as usize] = *value;
84                }
85                MemoryNode::new_leaf(hasher.hash(&values))
86            }
87        } else if lookup_range.is_empty() {
88            let leaf_value = hasher.hash(&[F::ZERO; CHUNK]);
89            MemoryNode::construct_uniform(
90                (length / CHUNK as u64).trailing_zeros() as usize,
91                leaf_value,
92                hasher,
93            )
94        } else {
95            let midpoint = from + length / 2;
96            let mid = {
97                let mut left = lookup_range.start;
98                let mut right = lookup_range.end;
99                if memory[left].0 >= midpoint {
100                    left
101                } else {
102                    while left + 1 < right {
103                        let mid = left + (right - left) / 2;
104                        if memory[mid].0 < midpoint {
105                            left = mid;
106                        } else {
107                            right = mid;
108                        }
109                    }
110                    right
111                }
112            };
113            let (left, right) = join(
114                || {
115                    Self::from_memory(
116                        memory,
117                        lookup_range.start..mid,
118                        length >> 1,
119                        from,
120                        hasher,
121                        zero_leaf,
122                    )
123                },
124                || {
125                    Self::from_memory(
126                        memory,
127                        mid..lookup_range.end,
128                        length >> 1,
129                        midpoint,
130                        hasher,
131                        zero_leaf,
132                    )
133                },
134            );
135            NonLeaf {
136                hash: hasher.compress(&left.hash(), &right.hash()),
137                left: Arc::new(left),
138                right: Arc::new(right),
139            }
140        }
141    }
142
143    pub fn tree_from_memory(
144        memory_dimensions: MemoryDimensions,
145        memory: &MemoryImage<F>,
146        hasher: &(impl Hasher<CHUNK, F> + Sync),
147    ) -> MemoryNode<CHUNK, F> {
148        // Construct a Vec that includes the address space in the label calculation,
149        // representing the entire memory tree.
150        let memory_items = memory
151            .items()
152            .filter(|((_, ptr), _)| *ptr as usize / CHUNK < (1 << memory_dimensions.address_height))
153            .map(|((address_space, pointer), value)| {
154                (
155                    memory_dimensions.label_to_index((address_space, pointer / CHUNK as u32))
156                        * CHUNK as u64
157                        + (pointer % CHUNK as u32) as u64,
158                    value,
159                )
160            })
161            .collect::<Vec<_>>();
162        debug_assert!(memory_items.is_sorted_by_key(|(addr, _)| addr));
163        debug_assert!(
164            memory_items.last().map_or(0, |(addr, _)| *addr)
165                < ((CHUNK as u64) << memory_dimensions.overall_height())
166        );
167        let zero_leaf = MemoryNode::new_leaf(hasher.hash(&[F::ZERO; CHUNK]));
168        Self::from_memory(
169            &memory_items,
170            0..memory_items.len(),
171            (CHUNK as u64) << memory_dimensions.overall_height(),
172            0,
173            hasher,
174            &zero_leaf,
175        )
176    }
177}