// openvm_circuit/system/memory/merkle/air.rs
use std::{borrow::Borrow, iter};
use openvm_stark_backend::{
interaction::InteractionBuilder,
p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir},
p3_field::{AbstractField, Field},
p3_matrix::Matrix,
rap::{BaseAirWithPublicValues, PartitionedBaseAir},
};
use super::{DirectCompressionBus, MemoryMerkleBus};
use crate::system::memory::merkle::{MemoryDimensions, MemoryMerkleCols, MemoryMerklePvs};
/// AIR whose rows describe nodes of the memory Merkle tree: each row carries a parent hash
/// together with its left and right child hashes.
///
/// `CHUNK` is the number of field elements in each hash value (and in each root exposed as a
/// public value).
#[derive(Clone, Debug)]
pub struct MemoryMerkleAir<const CHUNK: usize> {
    /// Supplies `address_height` and `overall_height()`, which the row constraints compare
    /// against the `parent_height` column.
    pub memory_dimensions: MemoryDimensions,
    /// Bus on which a row sends its parent node and receives its two child nodes.
    pub merkle_bus: MemoryMerkleBus,
    /// Bus on which each `(left_child_hash, right_child_hash, parent_hash)` triple is sent.
    pub compression_bus: DirectCompressionBus,
}
// No methods are overridden here: the trait's default implementations are used as-is.
impl<const CHUNK: usize, F: Field> PartitionedBaseAir<F> for MemoryMerkleAir<CHUNK> {}
impl<const CHUNK: usize, F: Field> BaseAir<F> for MemoryMerkleAir<CHUNK> {
    /// Number of trace columns, taken directly from `MemoryMerkleCols::<F, CHUNK>::width()`.
    fn width(&self) -> usize {
        MemoryMerkleCols::<F, CHUNK>::width()
    }
}
impl<const CHUNK: usize, F: Field> BaseAirWithPublicValues<F> for MemoryMerkleAir<CHUNK> {
    /// Number of public values, taken directly from `MemoryMerklePvs::<F, CHUNK>::width()`
    /// (the public values are read back in `eval` as `initial_root` and `final_root`).
    fn num_public_values(&self) -> usize {
        MemoryMerklePvs::<F, CHUNK>::width()
    }
}
impl<const CHUNK: usize, AB: InteractionBuilder + AirBuilderWithPublicValues> Air<AB>
for MemoryMerkleAir<CHUNK>
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let (local, next) = (main.row_slice(0), main.row_slice(1));
let local: &MemoryMerkleCols<_, CHUNK> = (*local).borrow();
let next: &MemoryMerkleCols<_, CHUNK> = (*next).borrow();
builder.assert_eq(
local.expand_direction,
local.expand_direction * local.expand_direction * local.expand_direction,
);
builder.assert_bool(local.left_direction_different);
builder.assert_bool(local.right_direction_different);
builder
.when_ne(local.expand_direction, AB::F::NEG_ONE)
.assert_zero(local.left_direction_different);
builder
.when_ne(local.expand_direction, AB::F::NEG_ONE)
.assert_zero(local.right_direction_different);
builder
.when_transition()
.assert_bool(local.parent_height - next.parent_height);
builder
.when_transition()
.assert_bool(local.height_section - next.height_section);
builder
.when_transition()
.assert_bool(local.is_root - next.is_root);
builder.when_first_row().assert_one(local.height_section);
builder.when_first_row().assert_one(local.is_root);
builder.when_first_row().assert_one(next.is_root);
builder.when_last_row().assert_zero(local.height_section);
builder.when_last_row().assert_zero(local.is_root);
builder
.when_transition()
.when_ne(
local.parent_height,
AB::F::from_canonical_usize(self.memory_dimensions.address_height + 1),
)
.assert_eq(local.height_section, next.height_section);
builder
.when_transition()
.when_ne(
next.parent_height,
AB::F::from_canonical_usize(self.memory_dimensions.address_height),
)
.assert_eq(local.height_section, next.height_section);
builder
.when(local.is_root)
.when(next.is_root)
.assert_eq(local.expand_direction - next.expand_direction, AB::F::TWO);
builder.when(local.is_root).assert_eq(
local.parent_height,
AB::Expr::from_canonical_usize(self.memory_dimensions.overall_height()),
);
let &MemoryMerklePvs::<_, CHUNK> {
initial_root,
final_root,
} = builder.public_values().borrow();
for i in 0..CHUNK {
builder
.when_first_row()
.assert_eq(local.parent_hash[i], initial_root[i]);
builder
.when_first_row()
.assert_eq(next.parent_hash[i], final_root[i]);
}
self.eval_interactions(builder, local);
}
}
impl<const CHUNK: usize> MemoryMerkleAir<CHUNK> {
    /// Registers the bus interactions for one row, in this order:
    ///
    /// 1. send the parent node on the Merkle bus;
    /// 2. receive the left child node on the Merkle bus;
    /// 3. receive the right child node on the Merkle bus;
    /// 4. send `(left_child_hash, right_child_hash, parent_hash)` on the compression bus.
    ///
    /// A Merkle-bus message is `(direction, height, as_label, address_label, hash...)`.
    pub fn eval_interactions<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local: &MemoryMerkleCols<AB::Var, CHUNK>,
    ) {
        let merkle_bus = self.merkle_bus.0;
        let direction = local.expand_direction;
        let section = local.height_section;

        // Parent node. The multiplicity is zero on root rows (`is_root` = 1), so those rows
        // do not send their parent on the bus.
        let parent_fields: [AB::Expr; 4] = [
            direction.into(),
            local.parent_height.into(),
            local.parent_as_label.into(),
            local.parent_address_label.into(),
        ];
        let parent_count = (AB::Expr::ONE - local.is_root) * direction;
        builder.push_send(
            merkle_bus,
            parent_fields
                .into_iter()
                .chain(local.parent_hash.into_iter().map(Into::into)),
            parent_count,
        );

        // Child labels. With `height_section` = 1 the as-label is doubled and the address
        // label is kept; with `height_section` = 0 it is the other way round. The right
        // child additionally gets +1 on whichever label was doubled.
        let child_height = local.parent_height - AB::F::ONE;
        let left_as_label = local.parent_as_label * (AB::Expr::ONE + section);
        let left_address_label = local.parent_address_label * (AB::Expr::TWO - section);
        let right_as_label = left_as_label.clone() + section;
        let right_address_label = left_address_label.clone() + (AB::Expr::ONE - section);

        let children = [
            (
                local.left_direction_different,
                left_as_label,
                left_address_label,
                local.left_child_hash,
            ),
            (
                local.right_direction_different,
                right_as_label,
                right_address_label,
                local.right_child_hash,
            ),
        ];
        for (direction_different, as_label, address_label, child_hash) in children {
            // A set `*_direction_different` flag adds 2 to the direction carried by the
            // child message.
            let child_fields: [AB::Expr; 4] = [
                direction + (direction_different * AB::F::TWO),
                child_height.clone(),
                as_label,
                address_label,
            ];
            builder.push_receive(
                merkle_bus,
                child_fields
                    .into_iter()
                    .chain(child_hash.into_iter().map(Into::into)),
                direction,
            );
        }

        // Compression triple, with multiplicity `expand_direction^2`.
        // NOTE(review): presumably consumed by a hasher that checks
        // parent = compress(left, right) — confirm against the bus's other endpoint.
        let compression_fields = iter::empty()
            .chain(local.left_child_hash)
            .chain(local.right_child_hash)
            .chain(local.parent_hash);
        builder.push_send(
            self.compression_bus.0,
            compression_fields,
            direction * direction,
        );
    }
}