use alloc::vec::Vec;
use p3_field::AbstractField;
use p3_mds::MdsPermutation;
use p3_symmetric::Permutation;
use rand::distributions::{Distribution, Standard};
use rand::Rng;
/// Multiply the 4-element vector `x` in place by the fixed 4x4 matrix
///
/// ```text
/// [5 7 1 3]
/// [4 6 1 1]
/// [1 3 5 7]
/// [1 1 4 6]
/// ```
///
/// using only additions and doublings.
#[inline(always)]
fn apply_hl_mat4<AF>(x: &mut [AF; 4])
where
    AF: AbstractField,
{
    // Write x = (a, b, c, d).
    let ab = x[0].clone() + x[1].clone(); // a + b
    let cd = x[2].clone() + x[3].clone(); // c + d
    let two_b_cd = x[1].clone() + x[1].clone() + cd.clone(); // 2b + c + d
    let ab_two_d = x[3].clone() + x[3].clone() + ab.clone(); // a + b + 2d
    let row3 = cd.double().double() + ab_two_d.clone(); // a + b + 4c + 6d
    let row1 = ab.double().double() + two_b_cd.clone(); // 4a + 6b + c + d
    let row0 = ab_two_d + row1.clone(); // 5a + 7b + c + 3d
    let row2 = two_b_cd + row3.clone(); // a + 3b + 5c + 7d
    *x = [row0, row1, row2, row3];
}
/// Multiply the 4-element vector `x` in place by the fixed 4x4 matrix
///
/// ```text
/// [2 3 1 1]
/// [1 2 3 1]
/// [1 1 2 3]
/// [3 1 1 2]
/// ```
///
/// using only additions and doublings.
#[inline(always)]
fn apply_mat4<AF>(x: &mut [AF; 4])
where
    AF: AbstractField,
{
    // Write x = (a, b, c, d).
    let ab = x[0].clone() + x[1].clone(); // a + b
    let cd = x[2].clone() + x[3].clone(); // c + d
    let abcd = ab.clone() + cd.clone(); // a + b + c + d
    let abcd_b = abcd.clone() + x[1].clone(); // a + 2b + c + d
    let abcd_d = abcd + x[3].clone(); // a + b + c + 2d
    // All four rows are computed from the original entries before anything is written back,
    // so the results do not depend on the order in which entries are overwritten.
    let row3 = abcd_d.clone() + x[0].double(); // 3a + b + c + 2d
    let row1 = abcd_b.clone() + x[2].double(); // a + 2b + 3c + d
    let row0 = abcd_b + ab; // 2a + 3b + c + d
    let row2 = abcd_d + cd; // a + b + 2c + 3d
    *x = [row0, row1, row2, row3];
}
/// Zero-sized handle that applies the 4x4 matrix
/// `[[5, 7, 1, 3], [4, 6, 1, 1], [1, 3, 5, 7], [1, 1, 4, 6]]` to a `[AF; 4]`.
#[derive(Clone, Default)]
pub struct HLMDSMat4;

impl<AF: AbstractField> Permutation<[AF; 4]> for HLMDSMat4 {
    /// Return `input` multiplied by the matrix.
    #[inline(always)]
    fn permute(&self, mut input: [AF; 4]) -> [AF; 4] {
        self.permute_mut(&mut input);
        input
    }

    /// Multiply `input` by the matrix in place.
    #[inline(always)]
    fn permute_mut(&self, input: &mut [AF; 4]) {
        apply_hl_mat4(input);
    }
}

impl<AF: AbstractField> MdsPermutation<AF, 4> for HLMDSMat4 {}
/// Zero-sized handle that applies the 4x4 matrix
/// `[[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]` to a `[AF; 4]`.
#[derive(Clone, Default)]
pub struct MDSMat4;

impl<AF: AbstractField> Permutation<[AF; 4]> for MDSMat4 {
    /// Return `input` multiplied by the matrix.
    #[inline(always)]
    fn permute(&self, mut input: [AF; 4]) -> [AF; 4] {
        self.permute_mut(&mut input);
        input
    }

    /// Multiply `input` by the matrix in place.
    #[inline(always)]
    fn permute_mut(&self, input: &mut [AF; 4]) {
        apply_mat4(input);
    }
}

impl<AF: AbstractField> MdsPermutation<AF, 4> for MDSMat4 {}
/// Apply the "light" linear layer to `state` in place.
///
/// * `WIDTH == 2`: multiply by `[[2, 1], [1, 2]]`, i.e. add the sum of both entries to each.
/// * `WIDTH == 3`: multiply by `[[2, 1, 1], [1, 2, 1], [1, 1, 2]]`, i.e. add the sum of all three
///   entries to each.
/// * `WIDTH` in `{4, 8, 12, 16, 20, 24}`: first apply `mdsmat` to every consecutive 4-element
///   chunk. Then add to entry `i` the sum of all (post-`mdsmat`) entries whose index is
///   congruent to `i` modulo 4. Because that sum includes entry `i` itself, it ends up doubled.
///
/// # Panics
///
/// Panics with `"Unsupported width"` for any other `WIDTH`.
#[inline(always)]
pub fn mds_light_permutation<
    AF: AbstractField,
    MdsPerm4: MdsPermutation<AF, 4>,
    const WIDTH: usize,
>(
    state: &mut [AF; WIDTH],
    mdsmat: &MdsPerm4,
) {
    match WIDTH {
        2 => {
            let total = state[0].clone() + state[1].clone();
            state[0] += total.clone();
            state[1] += total;
        }
        3 => {
            let total = state[0].clone() + state[1].clone() + state[2].clone();
            state[0] += total.clone();
            state[1] += total.clone();
            state[2] += total;
        }
        4 | 8 | 12 | 16 | 20 | 24 => {
            // Block-diagonal step: the 4x4 matrix acts on each chunk independently.
            state.chunks_exact_mut(4).for_each(|block| {
                let block: &mut [AF; 4] = block
                    .try_into()
                    .expect("chunks_exact_mut(4) always yields length-4 slices");
                mdsmat.permute_mut(block);
            });

            // column_sums[lane] = state[lane] + state[lane + 4] + state[lane + 8] + ...
            let column_sums: [AF; 4] = core::array::from_fn(|lane| {
                state.iter().skip(lane).step_by(4).cloned().sum::<AF>()
            });

            // Position p of every chunk receives the sum of its own column.
            for block in state.chunks_exact_mut(4) {
                for (elem, column_sum) in block.iter_mut().zip(column_sums.iter()) {
                    *elem += column_sum.clone();
                }
            }
        }
        _ => panic!("Unsupported width"),
    }
}
/// Round constants for the external layer, kept as two lists:
/// the rounds run at the start (`initial`) and the rounds run at the end (`terminal`).
/// Each list holds one `[T; WIDTH]` row per round.
#[derive(Clone, Debug)]
pub struct ExternalLayerConstants<T, const WIDTH: usize> {
    /// One row of `WIDTH` constants per initial external round.
    initial: Vec<[T; WIDTH]>,
    /// One row of `WIDTH` constants per terminal external round.
    terminal: Vec<[T; WIDTH]>,
}
impl<T, const WIDTH: usize> ExternalLayerConstants<T, WIDTH> {
    /// Build the constants from explicit initial and terminal round-constant lists.
    ///
    /// # Panics
    ///
    /// Panics if `initial` and `terminal` do not have the same length.
    pub fn new(initial: Vec<[T; WIDTH]>, terminal: Vec<[T; WIDTH]>) -> Self {
        assert_eq!(
            initial.len(),
            terminal.len(),
            "The number of initial and terminal external rounds should be equal."
        );
        Self { initial, terminal }
    }

    /// Sample `external_round_number / 2` initial rows, then the same number of terminal rows,
    /// from `rng` using the [`Standard`] distribution.
    ///
    /// # Panics
    ///
    /// Panics if `external_round_number` is odd.
    pub fn new_from_rng<R: Rng>(external_round_number: usize, rng: &mut R) -> Self
    where
        Standard: Distribution<[T; WIDTH]>,
    {
        let half_f = external_round_number / 2;
        assert_eq!(
            2 * half_f,
            external_round_number,
            "The total number of external rounds should be even"
        );
        // Initial rows are drawn before terminal rows; changing this order changes the constants
        // produced for a given seed.
        let initial_constants = rng.sample_iter(Standard).take(half_f).collect();
        let terminal_constants = rng.sample_iter(Standard).take(half_f).collect();
        Self::new(initial_constants, terminal_constants)
    }

    /// Build the constants from a pair of fixed-size arrays `[initial, terminal]`, converting
    /// every row with `conversion_fn`.
    pub fn new_from_saved_array<U, const N: usize>(
        [initial, terminal]: [[[U; WIDTH]; N]; 2],
        conversion_fn: fn([U; WIDTH]) -> [T; WIDTH],
    ) -> Self {
        // `Vec::from` moves the freshly converted rows into the vector instead of cloning them
        // (as `to_vec` would), so no `T: Clone` bound is required.
        let initial_consts = Vec::from(initial.map(conversion_fn));
        let terminal_consts = Vec::from(terminal.map(conversion_fn));
        Self::new(initial_consts, terminal_consts)
    }

    /// The initial external round constants, one row per round.
    pub fn get_initial_constants(&self) -> &Vec<[T; WIDTH]> {
        &self.initial
    }

    /// The terminal external round constants, one row per round.
    pub fn get_terminal_constants(&self) -> &Vec<[T; WIDTH]> {
        &self.terminal
    }
}
/// Types that can be built from a set of external round constants.
///
/// The constants are stored over `AF::F`, the field type associated with `AF`.
pub trait ExternalLayerConstructor<AF: AbstractField, const WIDTH: usize> {
    /// Construct `Self` from the given initial and terminal external round constants.
    fn new_from_constants(external_constants: ExternalLayerConstants<AF::F, WIDTH>) -> Self;
}
/// An external layer acting on a state of `WIDTH` elements of `AF`.
///
/// It provides one method for the initial external rounds and one for the terminal rounds.
/// `D` is not referenced by either method declared here.
/// NOTE(review): `D` is presumably the S-box exponent — confirm with implementors.
pub trait ExternalLayer<AF: AbstractField, const WIDTH: usize, const D: u64>: Sync + Clone {
    /// Apply the initial external rounds to `state` in place.
    fn permute_state_initial(&self, state: &mut [AF; WIDTH]);

    /// Apply the terminal external rounds to `state` in place.
    fn permute_state_terminal(&self, state: &mut [AF; WIDTH]);
}
/// Run one external round on `state` for each row of `terminal_external_constants`.
///
/// A round does two things:
/// - Call `add_rc_and_sbox` on every state element together with its matching constant.
/// - Apply [`mds_light_permutation`] with `mat4`.
///
/// `CT` is the (copyable) type in which the constants are stored.
#[inline]
pub fn external_terminal_permute_state<
    AF: AbstractField,
    CT: Copy,
    MdsPerm4: MdsPermutation<AF, 4>,
    const WIDTH: usize,
>(
    state: &mut [AF; WIDTH],
    terminal_external_constants: &[[CT; WIDTH]],
    add_rc_and_sbox: fn(&mut AF, CT),
    mat4: &MdsPerm4,
) {
    for round_constants in terminal_external_constants {
        for (elem, &rc) in state.iter_mut().zip(round_constants.iter()) {
            add_rc_and_sbox(elem, rc);
        }
        mds_light_permutation(state, mat4);
    }
}
/// Run the initial external rounds on `state`.
///
/// This applies [`mds_light_permutation`] once with `mat4`. It then runs the same per-round
/// procedure as [`external_terminal_permute_state`], using `initial_external_constants` as
/// the round constants.
#[inline]
pub fn external_initial_permute_state<
    AF: AbstractField,
    CT: Copy,
    MdsPerm4: MdsPermutation<AF, 4>,
    const WIDTH: usize,
>(
    state: &mut [AF; WIDTH],
    initial_external_constants: &[[CT; WIDTH]],
    add_rc_and_sbox: fn(&mut AF, CT),
    mat4: &MdsPerm4,
) {
    // The leading linear layer is the only difference between the initial and terminal rounds.
    mds_light_permutation(state, mat4);
    external_terminal_permute_state::<AF, CT, MdsPerm4, WIDTH>(
        state,
        initial_external_constants,
        add_rc_and_sbox,
        mat4,
    );
}