use core::borrow::Borrow;
use core::marker::PhantomData;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, Field};
use p3_matrix::Matrix;
use p3_poseidon2::GenericPoseidon2LinearLayers;
use crate::columns::{num_cols, Poseidon2Cols};
use crate::constants::RoundConstants;
use crate::{FullRound, PartialRound, SBox};
#[derive(Debug)]
pub struct Poseidon2Air<
F: Field,
LinearLayers,
const WIDTH: usize,
const SBOX_DEGREE: u64,
const SBOX_REGISTERS: usize,
const HALF_FULL_ROUNDS: usize,
const PARTIAL_ROUNDS: usize,
> {
constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
_phantom: PhantomData<LinearLayers>,
}
impl<
F: Field,
LinearLayers,
const WIDTH: usize,
const SBOX_DEGREE: u64,
const SBOX_REGISTERS: usize,
const HALF_FULL_ROUNDS: usize,
const PARTIAL_ROUNDS: usize,
>
Poseidon2Air<
F,
LinearLayers,
WIDTH,
SBOX_DEGREE,
SBOX_REGISTERS,
HALF_FULL_ROUNDS,
PARTIAL_ROUNDS,
>
{
pub fn new(constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>) -> Self {
Self {
constants,
_phantom: PhantomData,
}
}
}
impl<
F: Field,
LinearLayers: Sync,
const WIDTH: usize,
const SBOX_DEGREE: u64,
const SBOX_REGISTERS: usize,
const HALF_FULL_ROUNDS: usize,
const PARTIAL_ROUNDS: usize,
> BaseAir<F>
for Poseidon2Air<
F,
LinearLayers,
WIDTH,
SBOX_DEGREE,
SBOX_REGISTERS,
HALF_FULL_ROUNDS,
PARTIAL_ROUNDS,
>
{
fn width(&self) -> usize {
num_cols::<WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>()
}
}
pub(crate) fn eval<
AB: AirBuilder,
LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
const WIDTH: usize,
const SBOX_DEGREE: u64,
const SBOX_REGISTERS: usize,
const HALF_FULL_ROUNDS: usize,
const PARTIAL_ROUNDS: usize,
>(
air: &Poseidon2Air<
AB::F,
LinearLayers,
WIDTH,
SBOX_DEGREE,
SBOX_REGISTERS,
HALF_FULL_ROUNDS,
PARTIAL_ROUNDS,
>,
builder: &mut AB,
local: &Poseidon2Cols<
AB::Var,
WIDTH,
SBOX_DEGREE,
SBOX_REGISTERS,
HALF_FULL_ROUNDS,
PARTIAL_ROUNDS,
>,
) {
let mut state: [AB::Expr; WIDTH] = local.inputs.map(|x| x.into());
LinearLayers::external_linear_layer(&mut state);
for round in 0..HALF_FULL_ROUNDS {
eval_full_round::<AB, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
&mut state,
&local.beginning_full_rounds[round],
&air.constants.beginning_full_round_constants[round],
builder,
);
}
for round in 0..PARTIAL_ROUNDS {
eval_partial_round::<AB, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
&mut state,
&local.partial_rounds[round],
&air.constants.partial_round_constants[round],
builder,
);
}
for round in 0..HALF_FULL_ROUNDS {
eval_full_round::<AB, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
&mut state,
&local.ending_full_rounds[round],
&air.constants.ending_full_round_constants[round],
builder,
);
}
}
impl<
AB: AirBuilder,
LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
const WIDTH: usize,
const SBOX_DEGREE: u64,
const SBOX_REGISTERS: usize,
const HALF_FULL_ROUNDS: usize,
const PARTIAL_ROUNDS: usize,
> Air<AB>
for Poseidon2Air<
AB::F,
LinearLayers,
WIDTH,
SBOX_DEGREE,
SBOX_REGISTERS,
HALF_FULL_ROUNDS,
PARTIAL_ROUNDS,
>
{
#[inline]
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local: &Poseidon2Cols<
AB::Var,
WIDTH,
SBOX_DEGREE,
SBOX_REGISTERS,
HALF_FULL_ROUNDS,
PARTIAL_ROUNDS,
> = (*local).borrow();
eval::<
AB,
LinearLayers,
WIDTH,
SBOX_DEGREE,
SBOX_REGISTERS,
HALF_FULL_ROUNDS,
PARTIAL_ROUNDS,
>(self, builder, local);
}
}
#[inline]
fn eval_full_round<
AB: AirBuilder,
LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
const WIDTH: usize,
const SBOX_DEGREE: u64,
const SBOX_REGISTERS: usize,
>(
state: &mut [AB::Expr; WIDTH],
full_round: &FullRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
round_constants: &[AB::F; WIDTH],
builder: &mut AB,
) {
for (i, (s, r)) in state.iter_mut().zip(round_constants.iter()).enumerate() {
*s = s.clone() + *r;
eval_sbox(&full_round.sbox[i], s, builder);
}
LinearLayers::external_linear_layer(state);
for (state_i, post_i) in state.iter_mut().zip(full_round.post) {
builder.assert_eq(state_i.clone(), post_i);
*state_i = post_i.into();
}
}
#[inline]
fn eval_partial_round<
AB: AirBuilder,
LinearLayers: GenericPoseidon2LinearLayers<AB::Expr, WIDTH>,
const WIDTH: usize,
const SBOX_DEGREE: u64,
const SBOX_REGISTERS: usize,
>(
state: &mut [AB::Expr; WIDTH],
partial_round: &PartialRound<AB::Var, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
round_constant: &AB::F,
builder: &mut AB,
) {
state[0] = state[0].clone() + *round_constant;
eval_sbox(&partial_round.sbox, &mut state[0], builder);
builder.assert_eq(state[0].clone(), partial_round.post_sbox);
state[0] = partial_round.post_sbox.into();
LinearLayers::internal_linear_layer(state);
}
#[inline]
fn eval_sbox<AB, const DEGREE: u64, const REGISTERS: usize>(
sbox: &SBox<AB::Var, DEGREE, REGISTERS>,
x: &mut AB::Expr,
builder: &mut AB,
) where
AB: AirBuilder,
{
*x = match (DEGREE, REGISTERS) {
(3, 0) => x.cube(),
(5, 0) => x.exp_const_u64::<5>(),
(7, 0) => x.exp_const_u64::<7>(),
(5, 1) => {
let committed_x3 = sbox.0[0].into();
let x2 = x.square();
builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
committed_x3 * x2
}
(7, 1) => {
let committed_x3 = sbox.0[0].into();
builder.assert_eq(committed_x3.clone(), x.cube());
committed_x3.square() * x.clone()
}
(11, 2) => {
let committed_x3 = sbox.0[0].into();
let committed_x9 = sbox.0[1].into();
let x2 = x.square();
builder.assert_eq(committed_x3.clone(), x2.clone() * x.clone());
builder.assert_eq(committed_x9.clone(), committed_x3.cube());
committed_x9 * x2
}
_ => panic!(
"Unexpected (DEGREE, REGISTERS) of ({}, {})",
DEGREE, REGISTERS
),
}
}