// openvm_circuit/system/memory/tree/mod.rs

pub mod public_values;

use std::{collections::BTreeMap, sync::Arc};

use openvm_stark_backend::p3_field::PrimeField32;
use MemoryNode::*;

use super::manager::dimensions::MemoryDimensions;
use crate::{
    arch::hasher::{Hasher, HasherChip},
    system::memory::Equipartition,
};

/// A node of a binary tree whose nodes each carry a `CHUNK`-element array of field elements.
///
/// Children are held behind `Arc`, so one subtree may be referenced from several places
/// (`construct_uniform` uses the same child for both sides of every internal node).
#[derive(Clone, Debug, PartialEq)]
pub enum MemoryNode<const CHUNK: usize, F: PrimeField32> {
    /// A node without children.
    Leaf {
        /// The array stored at this leaf; `hash()` returns it unchanged.
        /// `from_memory` stores the hasher's hash of the memory chunk here, not the raw chunk.
        values: [F; CHUNK],
    },
    /// An internal node with exactly two children.
    NonLeaf {
        /// The value obtained by compressing the children's hashes (left first, then right).
        hash: [F; CHUNK],
        /// Left child.
        left: Arc<MemoryNode<CHUNK, F>>,
        /// Right child.
        right: Arc<MemoryNode<CHUNK, F>>,
    },
}

impl<const CHUNK: usize, F: PrimeField32> MemoryNode<CHUNK, F> {
    /// Returns the digest carried by this node: the stored `values` for a leaf, or the
    /// stored `hash` for an internal node.
    pub fn hash(&self) -> [F; CHUNK] {
        match self {
            Leaf { values: digest } | NonLeaf { hash: digest, .. } => *digest,
        }
    }

    /// Wraps `values` in a leaf node. No hashing is performed here.
    pub fn new_leaf(values: [F; CHUNK]) -> Self {
        Self::Leaf { values }
    }

    /// Builds an internal node over `left` and `right`.
    ///
    /// The node's hash is obtained from `hasher.compress_and_record`, applied to the
    /// left child's hash followed by the right child's hash.
    pub fn new_nonleaf(
        left: Arc<MemoryNode<CHUNK, F>>,
        right: Arc<MemoryNode<CHUNK, F>>,
        hasher: &mut impl HasherChip<CHUNK, F>,
    ) -> Self {
        let hash = hasher.compress_and_record(&left.hash(), &right.hash());
        Self::NonLeaf { hash, left, right }
    }

    /// Returns a tree of height `height` in which every leaf holds `leaf_value`.
    ///
    /// The tree is grown bottom-up, one level per step. At each level both children are
    /// the same `Arc`, so only `height + 1` distinct nodes are allocated and
    /// `hasher.compress` is called exactly `height` times.
    pub fn construct_uniform(
        height: usize,
        leaf_value: [F; CHUNK],
        hasher: &impl Hasher<CHUNK, F>,
    ) -> MemoryNode<CHUNK, F> {
        (0..height).fold(Self::new_leaf(leaf_value), |subtree, _| {
            let shared = Arc::new(subtree);
            let digest = shared.hash();
            NonLeaf {
                hash: hasher.compress(&digest, &digest),
                left: Arc::clone(&shared),
                right: shared,
            }
        })
    }

    /// Builds the subtree of height `height` covering indices `from..from + 2^height`.
    ///
    /// * At height 0 the result is a leaf holding `hasher.hash` of the chunk stored at
    ///   `from`, or of an all-zero chunk when `memory` has no entry there.
    /// * If `memory` has no entry anywhere in the covered range, the subtree is a uniform
    ///   tree whose leaves all hold the hash of the all-zero chunk.
    /// * Otherwise the range is halved, both halves are built recursively (left first),
    ///   and their hashes are combined with `hasher.compress`.
    fn from_memory(
        memory: &BTreeMap<u64, [F; CHUNK]>,
        height: usize,
        from: u64,
        hasher: &impl Hasher<CHUNK, F>,
    ) -> MemoryNode<CHUNK, F> {
        // Exclusive upper bound of the index range covered by this subtree.
        let upper = from + (1 << height);

        if height == 0 {
            let chunk = memory.get(&from).copied().unwrap_or([F::ZERO; CHUNK]);
            return Self::new_leaf(hasher.hash(&chunk));
        }

        if memory.range(from..upper).next().is_none() {
            // Nothing stored under this subtree: every leaf is the hashed zero chunk.
            let zero_leaf = hasher.hash(&[F::ZERO; CHUNK]);
            return Self::construct_uniform(height, zero_leaf, hasher);
        }

        let mid = from + (1 << (height - 1));
        let left = Arc::new(Self::from_memory(memory, height - 1, from, hasher));
        let right = Arc::new(Self::from_memory(memory, height - 1, mid, hasher));
        let hash = hasher.compress(&left.hash(), &right.hash());
        NonLeaf { hash, left, right }
    }

    /// Builds the tree for the whole of `memory`.
    ///
    /// Every label in `memory` is translated with `memory_dimensions.label_to_index`, and
    /// the tree is then built with height `memory_dimensions.overall_height()` starting
    /// at index 0.
    pub fn tree_from_memory(
        memory_dimensions: MemoryDimensions,
        memory: &Equipartition<F, CHUNK>,
        hasher: &impl Hasher<CHUNK, F>,
    ) -> MemoryNode<CHUNK, F> {
        // Re-key the chunks by an index that includes the address space in the label
        // calculation, so that a single map represents the entire memory tree.
        let indexed: BTreeMap<_, _> = memory
            .iter()
            .map(|(&label, &chunk)| (memory_dimensions.label_to_index(label), chunk))
            .collect();
        let height = memory_dimensions.overall_height();
        Self::from_memory(&indexed, height, 0, hasher)
    }
}