use std::{array, borrow::Borrow, cmp::min};
use openvm_circuit::{
arch::ExecutionBridge,
system::memory::{offline_checker::MemoryBridge, MemoryAddress},
};
use openvm_circuit_primitives::{
bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir,
};
use openvm_instructions::riscv::{
RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS,
};
use openvm_sha256_air::{
compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW,
SHA256_WORD_U16S, SHA256_WORD_U8S,
};
use openvm_sha256_transpiler::Rv32Sha256Opcode;
use openvm_stark_backend::{
interaction::InteractionBuilder,
p3_air::{Air, AirBuilder, BaseAir},
p3_field::{Field, FieldAlgebra},
p3_matrix::Matrix,
rap::{BaseAirWithPublicValues, PartitionedBaseAir},
};
use super::{
Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH,
SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE,
};
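/// AIR for the SHA-256 VM extension. Composes the [`Sha256Air`] sub-AIR with the VM-specific
/// constraints: memory reads of the message and the pointer/length registers, the final hash
/// write, message padding, and the execution bus interaction.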
#[derive(Clone, Debug, derive_new::new)]
pub struct Sha256VmAir {
pub execution_bridge: ExecutionBridge,
pub memory_bridge: MemoryBridge,
pub bitwise_lookup_bus: BitwiseOperationLookupBus,
pub ptr_max_bits: usize,
pub(super) offset: usize,
pub(super) sha256_subair: Sha256Air,
pub(super) padding_encoder: Encoder,
}
impl<F: Field> BaseAirWithPublicValues<F> for Sha256VmAir {}
impl<F: Field> PartitionedBaseAir<F> for Sha256VmAir {}
impl<F: Field> BaseAir<F> for Sha256VmAir {
fn width(&self) -> usize {
SHA256VM_WIDTH
}
}
impl<AB: InteractionBuilder> Air<AB> for Sha256VmAir {
fn eval(&self, builder: &mut AB) {
self.eval_padding(builder);
self.eval_transitions(builder);
self.eval_reads(builder);
self.eval_last_row(builder);
self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH);
}
}
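/// Padding status of a trace row. Exactly one flag is selected on every row via
/// `padding_encoder`; the `FirstPadding*` and `EntirePadding*` flags describe where the SHA-256
/// message padding falls within the `SHA256_READ_SIZE` bytes read on that row.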
#[allow(dead_code, non_camel_case_types)]
pub(super) enum PaddingFlags {
    /// The row is not one of the first 4 rows of a block, so padding is not applicable
    NotConsidered,
    /// The row contains only message bytes, no padding
    NotPadding,
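    /// FIRST_PADDING_i: the first row with padding, with exactly i message (non-padding) bytes
    /// before the padding starts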
FirstPadding0,
FirstPadding1,
FirstPadding2,
FirstPadding3,
FirstPadding4,
FirstPadding5,
FirstPadding6,
FirstPadding7,
FirstPadding8,
FirstPadding9,
FirstPadding10,
FirstPadding11,
FirstPadding12,
FirstPadding13,
FirstPadding14,
FirstPadding15,
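    /// FIRST_PADDING_i_LastRow: the first row with padding, with exactly i message bytes, and
    /// also the last padding row (so the row also carries the appended message length)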
FirstPadding0_LastRow,
FirstPadding1_LastRow,
FirstPadding2_LastRow,
FirstPadding3_LastRow,
FirstPadding4_LastRow,
FirstPadding5_LastRow,
FirstPadding6_LastRow,
FirstPadding7_LastRow,
    /// The entire row is padding, padding started on an earlier row, and this is the last
    /// padding row
    EntirePaddingLastRow,
    /// The entire row is padding, padding started on an earlier row, and this is not the last
    /// padding row
    EntirePadding,
}
impl PaddingFlags {
    /// The number of padding flags (including NotConsidered)
    pub const COUNT: usize = EntirePadding as usize + 1;
}
use PaddingFlags::*;
impl Sha256VmAir {
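    /// Selects exactly one padding flag on every row and delegates to the padding transition and
    /// per-row padding constraints.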
fn eval_padding<AB: InteractionBuilder>(&self, builder: &mut AB) {
let main = builder.main();
let (local, next) = (main.row_slice(0), main.row_slice(1));
let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
self.padding_encoder
.eval(builder, &local_cols.control.pad_flags);
builder.assert_one(self.padding_encoder.contains_flag_range::<AB>(
&local_cols.control.pad_flags,
NotConsidered as usize..=EntirePadding as usize,
));
        self.eval_padding_transitions(builder, local_cols, next_cols);
        self.eval_padding_row(builder, local_cols);
}
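    /// Constrains how `padding_occurred` evolves across rows and ties the next row's padding
    /// flags to whether padding starts, continues, or has not yet begun there.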
fn eval_padding_transitions<AB: InteractionBuilder>(
&self,
builder: &mut AB,
local: &Sha256VmRoundCols<AB::Var>,
next: &Sha256VmRoundCols<AB::Var>,
) {
        let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block;
        builder.assert_bool(local.control.padding_occurred);
        // The row just before the last row of a message must have padding_occurred = 1.
        builder
            .when(next_is_last_row.clone())
            .assert_one(local.control.padding_occurred);
        // The last row of a message (digest row of the last block) has padding_occurred = 0.
        builder
            .when(next_is_last_row.clone())
            .assert_zero(next.control.padding_occurred);
        // Once padding has occurred, it keeps occurring until the last row of the message.
        builder
            .when(local.control.padding_occurred - next_is_last_row.clone())
            .assert_one(next.control.padding_occurred);
        // Outside the first 4 rows of a block, padding_occurred does not change.
        builder
            .when_transition()
            .when(not::<AB::Expr>(next.inner.flags.is_first_4_rows) - next_is_last_row)
            .assert_eq(
                next.control.padding_occurred,
                local.control.padding_occurred,
            );
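        // If the next row is the first padding row, `len` must equal the number of message bytes
        // read so far: full blocks, plus full rows within this block, plus the offset of the
        // first padding byte within the row.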
let next_is_first_padding_row =
next.control.padding_occurred - local.control.padding_occurred;
let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::<AB>(
&next.inner.flags.row_idx,
&(0..4).map(|x| (x, x)).collect::<Vec<_>>(),
);
let next_padding_offset = self.padding_encoder.flag_with_val::<AB>(
&next.control.pad_flags,
&(0..16)
.map(|i| (FirstPadding0 as usize + i, i))
.collect::<Vec<_>>(),
) + self.padding_encoder.flag_with_val::<AB>(
&next.control.pad_flags,
&(0..8)
.map(|i| (FirstPadding0_LastRow as usize + i, i))
.collect::<Vec<_>>(),
);
let expected_len = next.inner.flags.local_block_idx
* next.control.padding_occurred
* AB::Expr::from_canonical_usize(SHA256_BLOCK_U8S)
+ next_row_idx * AB::Expr::from_canonical_usize(SHA256_READ_SIZE)
+ next_padding_offset;
builder.when(next_is_first_padding_row).assert_eq(
expected_len,
next.control.len * next.control.padding_occurred,
);
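        // Tie the padding flags to `padding_occurred`:
        // - `NotConsidered` exactly on rows outside the first 4 rows of a block
        // - `EntirePadding*` when padding already occurred before this row
        // - `FirstPadding*` when padding starts on this row
        // - `NotPadding` when padding has not occurred yet
        // - on the last block, the `*LastRow` flags appear exactly on the 4th row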
let is_next_first_padding = self.padding_encoder.contains_flag_range::<AB>(
&next.control.pad_flags,
FirstPadding0 as usize..=FirstPadding7_LastRow as usize,
);
let is_next_last_padding = self.padding_encoder.contains_flag_range::<AB>(
&next.control.pad_flags,
FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
);
let is_next_entire_padding = self.padding_encoder.contains_flag_range::<AB>(
&next.control.pad_flags,
EntirePaddingLastRow as usize..=EntirePadding as usize,
);
let is_next_not_considered = self
.padding_encoder
.contains_flag::<AB>(&next.control.pad_flags, &[NotConsidered as usize]);
let is_next_not_padding = self
.padding_encoder
.contains_flag::<AB>(&next.control.pad_flags, &[NotPadding as usize]);
let is_next_4th_row = self
.sha256_subair
.row_idx_encoder
.contains_flag::<AB>(&next.inner.flags.row_idx, &[3]);
builder.assert_eq(
not(next.inner.flags.is_first_4_rows),
is_next_not_considered,
);
builder.when(next.inner.flags.is_first_4_rows).assert_eq(
local.control.padding_occurred * next.control.padding_occurred,
is_next_entire_padding,
);
builder.when(next.inner.flags.is_first_4_rows).assert_eq(
not(local.control.padding_occurred) * next.control.padding_occurred,
is_next_first_padding,
);
builder
.when(next.inner.flags.is_first_4_rows)
.assert_eq(not(next.control.padding_occurred), is_next_not_padding);
builder
.when(next.inner.flags.is_last_block)
.assert_eq(is_next_4th_row, is_next_last_padding);
}
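    /// Constrains the message schedule cells on the first 4 rows of a block according to the
    /// padding flags: message bytes before the first padding byte, `0x80` at the first padding
    /// byte, zeroes after it, and the appended message length on the last padding row.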
fn eval_padding_row<AB: InteractionBuilder>(
&self,
builder: &mut AB,
local: &Sha256VmRoundCols<AB::Var>,
) {
let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)]
[i % (SHA256_WORD_U8S)]
});
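        // Composes the i-th byte of this row's message schedule from the bit columns of `w`,
        // reversing the byte order within each word to match the memory byte order.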
let get_ith_byte = |i: usize| {
let word_idx = i / SHA256_ROUNDS_PER_ROW;
let word = local.inner.message_schedule.w[word_idx].map(|x| x.into());
let byte_idx = 4 - i % 4 - 1;
compose::<AB::Expr>(&word[byte_idx * 8..(byte_idx + 1) * 8], 1)
};
let is_not_padding = self
.padding_encoder
.contains_flag::<AB>(&local.control.pad_flags, &[NotPadding as usize]);
for (i, message_byte) in message.iter().enumerate() {
let w = get_ith_byte(i);
let should_be_message = is_not_padding.clone()
+ if i < 15 {
self.padding_encoder.contains_flag_range::<AB>(
&local.control.pad_flags,
FirstPadding0 as usize + i + 1..=FirstPadding15 as usize,
)
} else {
AB::Expr::ZERO
}
+ if i < 7 {
self.padding_encoder.contains_flag_range::<AB>(
&local.control.pad_flags,
FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize,
)
} else {
AB::Expr::ZERO
};
builder
.when(should_be_message)
.assert_eq(w.clone(), *message_byte);
let should_be_zero = self
.padding_encoder
.contains_flag::<AB>(&local.control.pad_flags, &[EntirePadding as usize])
+ if i < 12 {
self.padding_encoder.contains_flag::<AB>(
&local.control.pad_flags,
&[EntirePaddingLastRow as usize],
) + if i > 0 {
self.padding_encoder.contains_flag_range::<AB>(
&local.control.pad_flags,
FirstPadding0_LastRow as usize
..=min(
FirstPadding0_LastRow as usize + i - 1,
FirstPadding7_LastRow as usize,
),
)
} else {
AB::Expr::ZERO
}
} else {
AB::Expr::ZERO
}
+ if i > 0 {
self.padding_encoder.contains_flag_range::<AB>(
&local.control.pad_flags,
FirstPadding0 as usize..=FirstPadding0 as usize + i - 1,
)
} else {
AB::Expr::ZERO
};
builder.when(should_be_zero).assert_zero(w.clone());
let should_be_128 = self
.padding_encoder
.contains_flag::<AB>(&local.control.pad_flags, &[FirstPadding0 as usize + i])
+ if i < 8 {
self.padding_encoder.contains_flag::<AB>(
&local.control.pad_flags,
&[FirstPadding0_LastRow as usize + i],
)
} else {
AB::Expr::ZERO
};
builder
.when(should_be_128)
.assert_eq(AB::Expr::from_canonical_u32(1 << 7), w);
}
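        // The low 32 bits of the appended message length, composed from the last 4 bytes of the
        // row (the length is appended in big-endian byte order).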
let appended_len = compose::<AB::Expr>(
&[
get_ith_byte(15),
get_ith_byte(14),
get_ith_byte(13),
get_ith_byte(12),
],
RV32_CELL_BITS,
);
let actual_len = local.control.len;
let is_last_padding_row = self.padding_encoder.contains_flag_range::<AB>(
&local.control.pad_flags,
FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize,
);
        // The appended length is in bits; divide by RV32_CELL_BITS (8) to get the byte length.
        builder.when(is_last_padding_row.clone()).assert_eq(
            appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(),
            actual_len,
        );
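        // The two low-order bits of the appended length and the high half of the 64-bit length
        // field (bytes 8..12 of the row) must be zero.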
builder.when(is_last_padding_row.clone()).assert_zero(
local.inner.message_schedule.w[3][0] + local.inner.message_schedule.w[3][1],
);
for i in 8..12 {
builder
.when(is_last_padding_row.clone())
.assert_zero(get_ith_byte(i));
}
}
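    /// Constrains the control columns across consecutive rows of the same message: `len` stays
    /// fixed, while `read_ptr` advances by `SHA256_READ_SIZE` and `cur_timestamp` by 1 on each of
    /// the first 4 rows of a block.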
fn eval_transitions<AB: InteractionBuilder>(&self, builder: &mut AB) {
let main = builder.main();
let (local, next) = (main.row_slice(0), main.row_slice(1));
let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
let next_cols: &Sha256VmRoundCols<AB::Var> = next[..SHA256VM_ROUND_WIDTH].borrow();
let is_last_row =
local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
builder
.when_transition()
.when(not::<AB::Expr>(is_last_row.clone()))
.assert_eq(next_cols.control.len, local_cols.control.len);
let read_ptr_delta = local_cols.inner.flags.is_first_4_rows
* AB::Expr::from_canonical_usize(SHA256_READ_SIZE);
builder
.when_transition()
.when(not::<AB::Expr>(is_last_row.clone()))
.assert_eq(
next_cols.control.read_ptr,
local_cols.control.read_ptr + read_ptr_delta,
);
let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE;
builder
.when_transition()
.when(not::<AB::Expr>(is_last_row.clone()))
.assert_eq(
next_cols.control.cur_timestamp,
local_cols.control.cur_timestamp + timestamp_delta,
);
}
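    /// Reads `SHA256_READ_SIZE` bytes of the message from memory into `carry_or_buffer` on each
    /// of the first 4 rows of a block.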
fn eval_reads<AB: InteractionBuilder>(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local_cols: &Sha256VmRoundCols<AB::Var> = local[..SHA256VM_ROUND_WIDTH].borrow();
let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| {
local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)]
[i % (SHA256_WORD_U16S * 2)]
});
self.memory_bridge
.read(
MemoryAddress::new(
AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
local_cols.control.read_ptr,
),
message,
local_cols.control.cur_timestamp,
&local_cols.read_aux,
)
.eval(builder, local_cols.inner.flags.is_first_4_rows);
}
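    /// Constraints that only apply on the last row of a message: read the `rd`, `rs1`, and `rs2`
    /// registers, range check the pointers, write the final hash to memory, interact with the
    /// execution bus, and tie the control columns back to the register values.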
fn eval_last_row<AB: InteractionBuilder>(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local_cols: &Sha256VmDigestCols<AB::Var> = local[..SHA256VM_DIGEST_WIDTH].borrow();
let timestamp: AB::Var = local_cols.from_state.timestamp;
let mut timestamp_delta: usize = 0;
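        // Returns the timestamp to use for the next access and advances the running delta.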
let mut timestamp_pp = || {
timestamp_delta += 1;
timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
};
let is_last_row =
local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row;
self.memory_bridge
.read(
MemoryAddress::new(
AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
local_cols.rd_ptr,
),
local_cols.dst_ptr,
timestamp_pp(),
&local_cols.register_reads_aux[0],
)
.eval(builder, is_last_row.clone());
self.memory_bridge
.read(
MemoryAddress::new(
AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
local_cols.rs1_ptr,
),
local_cols.src_ptr,
timestamp_pp(),
&local_cols.register_reads_aux[1],
)
.eval(builder, is_last_row.clone());
self.memory_bridge
.read(
MemoryAddress::new(
AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
local_cols.rs2_ptr,
),
local_cols.len_data,
timestamp_pp(),
&local_cols.register_reads_aux[2],
)
.eval(builder, is_last_row.clone());
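        // Range check the most significant limbs of `dst_ptr` and `src_ptr` so that both
        // pointers fit in `ptr_max_bits` bits.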
let shift = AB::Expr::from_canonical_usize(
1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits),
);
self.bitwise_lookup_bus
.send_range(
local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(),
)
.eval(builder, is_last_row.clone());
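        // `time_delta` is the number of memory reads performed for this message (4 per block);
        // `read_ptr_delta` is the total number of bytes read.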
let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE)
* AB::Expr::from_canonical_usize(4);
let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE);
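        // Reverse the byte order within each word of the final hash to match the byte order used
        // in memory.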
let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| {
local_cols.inner.final_hash[i / SHA256_WORD_U8S]
[SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1]
});
let dst_ptr_val =
compose::<AB::Expr>(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS);
self.memory_bridge
.write(
MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val),
result,
timestamp_pp() + time_delta.clone(),
&local_cols.writes_aux,
)
.eval(builder, is_last_row.clone());
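        // Send the instruction over the execution bus; the timestamp advances by the three
        // register reads and the hash write, plus one tick per message read.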
self.execution_bridge
.execute_and_increment_pc(
AB::Expr::from_canonical_usize(Rv32Sha256Opcode::SHA256 as usize + self.offset),
[
local_cols.rd_ptr.into(),
local_cols.rs1_ptr.into(),
local_cols.rs2_ptr.into(),
AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
],
local_cols.from_state,
AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(),
)
.eval(builder, is_last_row.clone());
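        // Tie the control columns to the register values: `len` equals the value read from
        // `rs2`, `read_ptr` has advanced from the `rs1` value by the total bytes read, and
        // `cur_timestamp` accounts for the three register reads plus one tick per message read.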
let len_val = compose::<AB::Expr>(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS);
builder
.when(is_last_row.clone())
.assert_eq(local_cols.control.len, len_val);
let src_val = compose::<AB::Expr>(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS);
builder
.when(is_last_row.clone())
.assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta);
builder.when(is_last_row.clone()).assert_eq(
local_cols.control.cur_timestamp,
local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta,
);
}
}