// Source: openvm_circuit/system/memory/merkle/air.rs

use std::{borrow::Borrow, iter};

use openvm_stark_backend::{
    interaction::{InteractionBuilder, PermutationCheckBus},
    p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir},
    p3_field::{Field, PrimeCharacteristicRing},
    p3_matrix::Matrix,
    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
};

use crate::system::memory::merkle::{MemoryDimensions, MemoryMerkleCols, MemoryMerklePvs};

13#[derive(Clone, Debug)]
14pub struct MemoryMerkleAir<const CHUNK: usize> {
15    pub memory_dimensions: MemoryDimensions,
16    pub merkle_bus: PermutationCheckBus,
17    pub compression_bus: PermutationCheckBus,
18}
19
20impl<const CHUNK: usize, F: Field> PartitionedBaseAir<F> for MemoryMerkleAir<CHUNK> {}
21impl<const CHUNK: usize, F: Field> BaseAir<F> for MemoryMerkleAir<CHUNK> {
22    fn width(&self) -> usize {
23        MemoryMerkleCols::<F, CHUNK>::width()
24    }
25}
26impl<const CHUNK: usize, F: Field> BaseAirWithPublicValues<F> for MemoryMerkleAir<CHUNK> {
27    fn num_public_values(&self) -> usize {
28        MemoryMerklePvs::<F, CHUNK>::width()
29    }
30}
31
32impl<const CHUNK: usize, AB: InteractionBuilder + AirBuilderWithPublicValues> Air<AB>
33    for MemoryMerkleAir<CHUNK>
34{
35    fn eval(&self, builder: &mut AB) {
36        let main = builder.main();
37        let (local, next) = (
38            main.row_slice(0).expect("window should have two elements"),
39            main.row_slice(1).expect("window should have two elements"),
40        );
41        let local: &MemoryMerkleCols<_, CHUNK> = (*local).borrow();
42        let next: &MemoryMerkleCols<_, CHUNK> = (*next).borrow();
43
44        // `expand_direction` should be -1, 0, 1
45        builder.assert_eq(
46            local.expand_direction,
47            local.expand_direction * local.expand_direction * local.expand_direction,
48        );
49
50        builder.assert_bool(local.left_direction_different);
51        builder.assert_bool(local.right_direction_different);
52
53        // if `expand_direction` != -1, then `*_direction_different` should be 0
54        builder
55            .when_ne(local.expand_direction, AB::F::NEG_ONE)
56            .assert_zero(local.left_direction_different);
57        builder
58            .when_ne(local.expand_direction, AB::F::NEG_ONE)
59            .assert_zero(local.right_direction_different);
60
61        // rows should be sorted in descending order
62        // independently by `parent_height`, `height_section`, `is_root`
63        builder
64            .when_transition()
65            .assert_bool(local.parent_height - next.parent_height);
66        builder
67            .when_transition()
68            .assert_bool(local.height_section - next.height_section);
69        builder
70            .when_transition()
71            .assert_bool(local.is_root - next.is_root);
72
73        // row with greatest height should have `height_section` = 1
74        builder.when_first_row().assert_one(local.height_section);
75        // two rows with greatest height should have `is_root` = 1
76        builder.when_first_row().assert_one(local.is_root);
77        builder.when_first_row().assert_one(next.is_root);
78        // row with least height should have `height_section` = 0, `is_root` = 0
79        builder.when_last_row().assert_zero(local.height_section);
80        builder.when_last_row().assert_zero(local.is_root);
81        // `height_section` changes from 0 to 1 only when `parent_height` changes from
82        // `address_height` to `address_height` + 1
83        builder
84            .when_transition()
85            .when_ne(
86                local.parent_height,
87                AB::F::from_usize(self.memory_dimensions.address_height + 1),
88            )
89            .assert_eq(local.height_section, next.height_section);
90        builder
91            .when_transition()
92            .when_ne(
93                next.parent_height,
94                AB::F::from_usize(self.memory_dimensions.address_height),
95            )
96            .assert_eq(local.height_section, next.height_section);
97        // two adjacent rows with `is_root` = 1 should have
98        // the first `expand_direction` = 1, the second `expand_direction` = -1
99        builder
100            .when(local.is_root)
101            .when(next.is_root)
102            .assert_eq(local.expand_direction - next.expand_direction, AB::F::TWO);
103
104        // roots should have correct height
105        builder.when(local.is_root).assert_eq(
106            local.parent_height,
107            AB::Expr::from_usize(self.memory_dimensions.overall_height()),
108        );
109
110        // constrain public values
111        let &MemoryMerklePvs::<_, CHUNK> {
112            initial_root,
113            final_root,
114        } = builder.public_values().borrow();
115        for i in 0..CHUNK {
116            builder
117                .when_first_row()
118                .assert_eq(local.parent_hash[i], initial_root[i]);
119            builder
120                .when_first_row()
121                .assert_eq(next.parent_hash[i], final_root[i]);
122        }
123
124        self.eval_interactions(builder, local);
125    }
126}
127
128impl<const CHUNK: usize> MemoryMerkleAir<CHUNK> {
129    pub fn eval_interactions<AB: InteractionBuilder>(
130        &self,
131        builder: &mut AB,
132        local: &MemoryMerkleCols<AB::Var, CHUNK>,
133    ) {
134        // interaction does not occur for first two rows;
135        // for those, parent hash value comes from public values
136        self.merkle_bus.interact(
137            builder,
138            [
139                local.expand_direction.into(),
140                local.parent_height.into(),
141                local.parent_as_label.into(),
142                local.parent_address_label.into(),
143            ]
144            .into_iter()
145            .chain(local.parent_hash.into_iter().map(Into::into)),
146            // count can probably be made degree 1 if necessary
147            (AB::Expr::ONE - local.is_root) * local.expand_direction,
148        );
149
150        self.merkle_bus.interact(
151            builder,
152            [
153                local.expand_direction + (local.left_direction_different * AB::F::TWO),
154                local.parent_height - AB::F::ONE,
155                local.parent_as_label * (AB::Expr::ONE + local.height_section),
156                local.parent_address_label * (AB::Expr::TWO - local.height_section),
157            ]
158            .into_iter()
159            .chain(local.left_child_hash.into_iter().map(Into::into)),
160            -local.expand_direction.into(),
161        );
162
163        self.merkle_bus.interact(
164            builder,
165            [
166                local.expand_direction + (local.right_direction_different * AB::F::TWO),
167                local.parent_height - AB::F::ONE,
168                (local.parent_as_label * (AB::Expr::ONE + local.height_section))
169                    + local.height_section,
170                (local.parent_address_label * (AB::Expr::TWO - local.height_section))
171                    + (AB::Expr::ONE - local.height_section),
172            ]
173            .into_iter()
174            .chain(local.right_child_hash.into_iter().map(Into::into)),
175            -local.expand_direction.into(),
176        );
177
178        let compress_fields = iter::empty()
179            .chain(local.left_child_hash)
180            .chain(local.right_child_hash)
181            .chain(local.parent_hash);
182        self.compression_bus.interact(
183            builder,
184            compress_fields,
185            local.expand_direction * local.expand_direction,
186        );
187    }
188}