// openvm_circuit/system/memory/merkle/air.rs

1use std::{borrow::Borrow, iter};
2
3use openvm_stark_backend::{
4    interaction::{InteractionBuilder, PermutationCheckBus},
5    p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir},
6    p3_field::{Field, FieldAlgebra},
7    p3_matrix::Matrix,
8    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
9};
10
11use crate::system::memory::merkle::{MemoryDimensions, MemoryMerkleCols, MemoryMerklePvs};
12
13#[derive(Clone, Debug)]
14pub struct MemoryMerkleAir<const CHUNK: usize> {
15    pub memory_dimensions: MemoryDimensions,
16    pub merkle_bus: PermutationCheckBus,
17    pub compression_bus: PermutationCheckBus,
18}
19
20impl<const CHUNK: usize, F: Field> PartitionedBaseAir<F> for MemoryMerkleAir<CHUNK> {}
21impl<const CHUNK: usize, F: Field> BaseAir<F> for MemoryMerkleAir<CHUNK> {
22    fn width(&self) -> usize {
23        MemoryMerkleCols::<F, CHUNK>::width()
24    }
25}
26impl<const CHUNK: usize, F: Field> BaseAirWithPublicValues<F> for MemoryMerkleAir<CHUNK> {
27    fn num_public_values(&self) -> usize {
28        MemoryMerklePvs::<F, CHUNK>::width()
29    }
30}
31
32impl<const CHUNK: usize, AB: InteractionBuilder + AirBuilderWithPublicValues> Air<AB>
33    for MemoryMerkleAir<CHUNK>
34{
35    fn eval(&self, builder: &mut AB) {
36        let main = builder.main();
37        let (local, next) = (main.row_slice(0), main.row_slice(1));
38        let local: &MemoryMerkleCols<_, CHUNK> = (*local).borrow();
39        let next: &MemoryMerkleCols<_, CHUNK> = (*next).borrow();
40
41        // `expand_direction` should be -1, 0, 1
42        builder.assert_eq(
43            local.expand_direction,
44            local.expand_direction * local.expand_direction * local.expand_direction,
45        );
46
47        builder.assert_bool(local.left_direction_different);
48        builder.assert_bool(local.right_direction_different);
49
50        // if `expand_direction` != -1, then `*_direction_different` should be 0
51        builder
52            .when_ne(local.expand_direction, AB::F::NEG_ONE)
53            .assert_zero(local.left_direction_different);
54        builder
55            .when_ne(local.expand_direction, AB::F::NEG_ONE)
56            .assert_zero(local.right_direction_different);
57
58        // rows should be sorted in descending order
59        // independently by `parent_height`, `height_section`, `is_root`
60        builder
61            .when_transition()
62            .assert_bool(local.parent_height - next.parent_height);
63        builder
64            .when_transition()
65            .assert_bool(local.height_section - next.height_section);
66        builder
67            .when_transition()
68            .assert_bool(local.is_root - next.is_root);
69
70        // row with greatest height should have `height_section` = 1
71        builder.when_first_row().assert_one(local.height_section);
72        // two rows with greatest height should have `is_root` = 1
73        builder.when_first_row().assert_one(local.is_root);
74        builder.when_first_row().assert_one(next.is_root);
75        // row with least height should have `height_section` = 0, `is_root` = 0
76        builder.when_last_row().assert_zero(local.height_section);
77        builder.when_last_row().assert_zero(local.is_root);
78        // `height_section` changes from 0 to 1 only when `parent_height` changes from `address_height` to `address_height` + 1
79        builder
80            .when_transition()
81            .when_ne(
82                local.parent_height,
83                AB::F::from_canonical_usize(self.memory_dimensions.address_height + 1),
84            )
85            .assert_eq(local.height_section, next.height_section);
86        builder
87            .when_transition()
88            .when_ne(
89                next.parent_height,
90                AB::F::from_canonical_usize(self.memory_dimensions.address_height),
91            )
92            .assert_eq(local.height_section, next.height_section);
93        // two adjacent rows with `is_root` = 1 should have
94        // the first `expand_direction` = 1, the second `expand_direction` = -1
95        builder
96            .when(local.is_root)
97            .when(next.is_root)
98            .assert_eq(local.expand_direction - next.expand_direction, AB::F::TWO);
99
100        // roots should have correct height
101        builder.when(local.is_root).assert_eq(
102            local.parent_height,
103            AB::Expr::from_canonical_usize(self.memory_dimensions.overall_height()),
104        );
105
106        // constrain public values
107        let &MemoryMerklePvs::<_, CHUNK> {
108            initial_root,
109            final_root,
110        } = builder.public_values().borrow();
111        for i in 0..CHUNK {
112            builder
113                .when_first_row()
114                .assert_eq(local.parent_hash[i], initial_root[i]);
115            builder
116                .when_first_row()
117                .assert_eq(next.parent_hash[i], final_root[i]);
118        }
119
120        self.eval_interactions(builder, local);
121    }
122}
123
124impl<const CHUNK: usize> MemoryMerkleAir<CHUNK> {
125    pub fn eval_interactions<AB: InteractionBuilder>(
126        &self,
127        builder: &mut AB,
128        local: &MemoryMerkleCols<AB::Var, CHUNK>,
129    ) {
130        // interaction does not occur for first two rows;
131        // for those, parent hash value comes from public values
132        self.merkle_bus.interact(
133            builder,
134            [
135                local.expand_direction.into(),
136                local.parent_height.into(),
137                local.parent_as_label.into(),
138                local.parent_address_label.into(),
139            ]
140            .into_iter()
141            .chain(local.parent_hash.into_iter().map(Into::into)),
142            // count can probably be made degree 1 if necessary
143            (AB::Expr::ONE - local.is_root) * local.expand_direction,
144        );
145
146        self.merkle_bus.interact(
147            builder,
148            [
149                local.expand_direction + (local.left_direction_different * AB::F::TWO),
150                local.parent_height - AB::F::ONE,
151                local.parent_as_label * (AB::Expr::ONE + local.height_section),
152                local.parent_address_label * (AB::Expr::TWO - local.height_section),
153            ]
154            .into_iter()
155            .chain(local.left_child_hash.into_iter().map(Into::into)),
156            -local.expand_direction.into(),
157        );
158
159        self.merkle_bus.interact(
160            builder,
161            [
162                local.expand_direction + (local.right_direction_different * AB::F::TWO),
163                local.parent_height - AB::F::ONE,
164                (local.parent_as_label * (AB::Expr::ONE + local.height_section))
165                    + local.height_section,
166                (local.parent_address_label * (AB::Expr::TWO - local.height_section))
167                    + (AB::Expr::ONE - local.height_section),
168            ]
169            .into_iter()
170            .chain(local.right_child_hash.into_iter().map(Into::into)),
171            -local.expand_direction.into(),
172        );
173
174        let compress_fields = iter::empty()
175            .chain(local.left_child_hash)
176            .chain(local.right_child_hash)
177            .chain(local.parent_hash);
178        self.compression_bus.interact(
179            builder,
180            compress_fields,
181            local.expand_direction * local.expand_direction,
182        );
183    }
184}