use core::borrow::Borrow;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
use p3_matrix::Matrix;
use crate::columns::{KeccakCols, NUM_KECCAK_COLS};
use crate::constants::rc_value_bit;
use crate::logic::{andn_gen, xor3_gen, xor_gen};
use crate::round_flags::eval_round_flags;
use crate::{BITS_PER_LIMB, NUM_ROUNDS, U64_LIMBS};
/// AIR for the Keccak-f[1600] permutation over a 5 x 5 x 64-bit state, with one
/// permutation round per trace row.
///
/// The type carries no configuration. The trace layout is given by `KeccakCols`,
/// and its width is `NUM_KECCAK_COLS`.
///
/// NOTE(review): state lanes are packed into single field elements as limbs of
/// `BITS_PER_LIMB` bits. The field therefore has to represent `2^BITS_PER_LIMB - 1`
/// without wrapping. Confirm the minimum field size this AIR is meant to support.
#[derive(Debug, Clone, Copy, Default)]
pub struct KeccakAir {}
/// The trace width does not depend on the field: it is always the column count of
/// `KeccakCols`.
impl<Val> BaseAir<Val> for KeccakAir {
    /// Number of main-trace columns, `NUM_KECCAK_COLS`.
    fn width(&self) -> usize {
        NUM_KECCAK_COLS
    }
}
impl<AB: AirBuilder> Air<AB> for KeccakAir {
#[inline]
fn eval(&self, builder: &mut AB) {
eval_round_flags(builder);
let main = builder.main();
let (local, next) = (main.row_slice(0), main.row_slice(1));
let local: &KeccakCols<AB::Var> = (*local).borrow();
let next: &KeccakCols<AB::Var> = (*next).borrow();
let first_step = local.step_flags[0];
let final_step = local.step_flags[NUM_ROUNDS - 1];
let not_final_step = AB::Expr::ONE - final_step;
for y in 0..5 {
for x in 0..5 {
for limb in 0..U64_LIMBS {
builder
.when(first_step)
.assert_eq(local.preimage[y][x][limb], local.a[y][x][limb]);
}
}
}
builder.assert_bool(local.export);
builder
.when(not_final_step.clone())
.assert_zero(local.export);
for y in 0..5 {
for x in 0..5 {
for limb in 0..U64_LIMBS {
builder
.when(not_final_step.clone())
.when_transition()
.assert_eq(local.preimage[y][x][limb], next.preimage[y][x][limb]);
}
}
}
for x in 0..5 {
for z in 0..64 {
builder.assert_bool(local.c[x][z]);
let xor = xor3_gen::<AB::Expr>(
local.c[x][z].into(),
local.c[(x + 4) % 5][z].into(),
local.c[(x + 1) % 5][(z + 63) % 64].into(),
);
let c_prime = local.c_prime[x][z];
builder.assert_eq(c_prime, xor);
}
}
for y in 0..5 {
for x in 0..5 {
let get_bit = |z| {
let a_prime: AB::Var = local.a_prime[y][x][z];
let c: AB::Var = local.c[x][z];
let c_prime: AB::Var = local.c_prime[x][z];
xor3_gen::<AB::Expr>(a_prime.into(), c.into(), c_prime.into())
};
for limb in 0..U64_LIMBS {
let a_limb = local.a[y][x][limb];
let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
.rev()
.fold(AB::Expr::ZERO, |acc, z| {
builder.assert_bool(local.a_prime[y][x][z]);
acc.double() + get_bit(z)
});
builder.assert_eq(computed_limb, a_limb);
}
}
}
for x in 0..5 {
for z in 0..64 {
let sum: AB::Expr = (0..5).map(|y| local.a_prime[y][x][z].into()).sum();
let diff = sum - local.c_prime[x][z];
let four = AB::Expr::from_canonical_u8(4);
builder.assert_zero(diff.clone() * (diff.clone() - AB::Expr::TWO) * (diff - four));
}
}
for y in 0..5 {
for x in 0..5 {
let get_bit = |z| {
let andn = andn_gen::<AB::Expr>(
local.b((x + 1) % 5, y, z).into(),
local.b((x + 2) % 5, y, z).into(),
);
xor_gen::<AB::Expr>(local.b(x, y, z).into(), andn)
};
for limb in 0..U64_LIMBS {
let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
.rev()
.fold(AB::Expr::ZERO, |acc, z| acc.double() + get_bit(z));
builder.assert_eq(computed_limb, local.a_prime_prime[y][x][limb]);
}
}
}
for limb in 0..U64_LIMBS {
let computed_a_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
..(limb + 1) * BITS_PER_LIMB)
.rev()
.fold(AB::Expr::ZERO, |acc, z| {
builder.assert_bool(local.a_prime_prime_0_0_bits[z]);
acc.double() + local.a_prime_prime_0_0_bits[z]
});
let a_prime_prime_0_0_limb = local.a_prime_prime[0][0][limb];
builder.assert_eq(computed_a_prime_prime_0_0_limb, a_prime_prime_0_0_limb);
}
let get_xored_bit = |i| {
let mut rc_bit_i = AB::Expr::ZERO;
for r in 0..NUM_ROUNDS {
let this_round = local.step_flags[r];
let this_round_constant = AB::Expr::from_canonical_u8(rc_value_bit(r, i));
rc_bit_i += this_round * this_round_constant;
}
xor_gen::<AB::Expr>(local.a_prime_prime_0_0_bits[i].into(), rc_bit_i)
};
for limb in 0..U64_LIMBS {
let a_prime_prime_prime_0_0_limb = local.a_prime_prime_prime_0_0_limbs[limb];
let computed_a_prime_prime_prime_0_0_limb = (limb * BITS_PER_LIMB
..(limb + 1) * BITS_PER_LIMB)
.rev()
.fold(AB::Expr::ZERO, |acc, z| acc.double() + get_xored_bit(z));
builder.assert_eq(
computed_a_prime_prime_prime_0_0_limb,
a_prime_prime_prime_0_0_limb,
);
}
for x in 0..5 {
for y in 0..5 {
for limb in 0..U64_LIMBS {
let output = local.a_prime_prime_prime(y, x, limb);
let input = next.a[y][x][limb];
builder
.when_transition()
.when(not_final_step.clone())
.assert_eq(output, input);
}
}
}
}
}