// openvm_circuit/system/memory/adapter/air.rs

1use std::{borrow::Borrow, mem::size_of};
2
3use openvm_circuit_primitives::{
4    is_less_than::{IsLessThanIo, IsLtSubAir},
5    SubAir,
6};
7use openvm_stark_backend::{
8    interaction::InteractionBuilder,
9    p3_air::{Air, AirBuilder, BaseAir},
10    p3_field::FieldAlgebra,
11    p3_matrix::Matrix,
12    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
13};
14
15use crate::system::memory::{
16    adapter::columns::AccessAdapterCols, offline_checker::MemoryBus, MemoryAddress,
17};
18
19#[derive(Clone, Debug)]
20pub struct AccessAdapterAir<const N: usize> {
21    pub memory_bus: MemoryBus,
22    pub lt_air: IsLtSubAir,
23}
24
25impl<T, const N: usize> BaseAirWithPublicValues<T> for AccessAdapterAir<N> {}
26impl<T, const N: usize> PartitionedBaseAir<T> for AccessAdapterAir<N> {}
27impl<T, const N: usize> BaseAir<T> for AccessAdapterAir<N> {
28    fn width(&self) -> usize {
29        size_of::<AccessAdapterCols<u8, N>>()
30    }
31}
32
33impl<const N: usize, AB: InteractionBuilder> Air<AB> for AccessAdapterAir<N> {
34    fn eval(&self, builder: &mut AB) {
35        let main = builder.main();
36
37        let local = main.row_slice(0);
38        let local: &AccessAdapterCols<AB::Var, N> = (*local).borrow();
39
40        builder.assert_bool(local.is_split);
41        builder.assert_bool(local.is_valid);
42        builder.assert_bool(local.is_right_larger);
43
44        // Timestamp constraints:
45        // - if `is_split`, then all timestamps are equal.
46        // - if `is_merge`, then parent_timestamp = max(left_timestamp, right_timestamp)
47
48        builder
49            .when(local.is_split)
50            .assert_eq(local.left_timestamp, local.right_timestamp);
51
52        self.lt_air.eval(
53            builder,
54            (
55                IsLessThanIo {
56                    x: local.left_timestamp.into(),
57                    y: local.right_timestamp.into(),
58                    out: local.is_right_larger.into(),
59                    count: local.is_valid.into(),
60                },
61                &local.lt_aux,
62            ),
63        );
64
65        let parent_timestamp = local.is_right_larger * local.right_timestamp
66            + (AB::Expr::ONE - local.is_right_larger) * local.left_timestamp;
67
68        // assuming valid:
69        // Split = 1 => direction = 1 => receive parent with count 1, send left/right with count 1
70        // Split = 0 => direction = -1 => receive parent with count -1, send left/right with count -1
71        let direction = local.is_valid * (AB::Expr::TWO * local.is_split - AB::Expr::ONE);
72
73        self.memory_bus
74            .receive(local.address, local.values.to_vec(), parent_timestamp)
75            .eval(builder, direction.clone());
76
77        self.memory_bus
78            .send(
79                local.address,
80                local.values[..N / 2].to_vec(),
81                local.left_timestamp,
82            )
83            .eval(builder, direction.clone());
84
85        self.memory_bus
86            .send(
87                MemoryAddress::new(
88                    local.address.address_space,
89                    local.address.pointer + AB::Expr::from_canonical_usize(N / 2),
90                ),
91                local.values[N / 2..].to_vec(),
92                local.right_timestamp,
93            )
94            .eval(builder, direction.clone());
95    }
96}