use std::convert::TryInto;
use std::fmt;
use std::iter;
use std::marker::PhantomData;
use ff::{FromUniformBytes, PrimeField};
pub(crate) mod grain;
pub(crate) mod mds;
mod fields;
#[macro_use]
mod binops;
#[cfg(test)]
pub(crate) mod bn256;
#[cfg(test)]
pub(crate) mod pasta;
mod p128pow5t3;
mod p128pow5t3_compact;
pub use p128pow5t3::P128Pow5T3;
#[allow(unused_imports)]
pub(crate) use p128pow5t3::P128Pow5T3Constants;
pub use p128pow5t3_compact::P128Pow5T3Compact;
use grain::SboxType;
/// The Poseidon permutation state: `T` field elements.
pub(crate) type State<F, const T: usize> = [F; T];
/// The rate portion of the sponge. A `None` slot has not been filled yet (while absorbing)
/// or has already been handed out (while squeezing).
pub(crate) type SpongeRate<F, const RATE: usize> = [Option<F>; RATE];
/// A `T x T` matrix of field elements, indexed `[row][column]` when applied in `permute`.
pub(crate) type Mds<F, const T: usize> = [[F; T]; T];
/// Parameters of a Poseidon permutation over the field `F` with state width `T` and
/// sponge rate `RATE`.
pub trait Spec<F: PrimeField, const T: usize, const RATE: usize>: fmt::Debug {
    /// Total number of full rounds.
    ///
    /// `permute` runs `full_rounds() / 2` of them before the partial rounds and the same
    /// number after them, so an odd value effectively loses one round.
    fn full_rounds() -> usize;

    /// Number of partial rounds, run between the two halves of the full rounds.
    fn partial_rounds() -> usize;

    /// The S-box. `permute` applies it to every state word in a full round and only to
    /// `state[0]` in a partial round.
    fn sbox(val: F) -> F;

    /// Value forwarded to `mds::generate_mds` when building the MDS matrix.
    ///
    /// NOTE(review): presumably selects which generated candidate matrix is accepted —
    /// confirm against the `mds` module.
    fn secure_mds() -> usize;

    /// Generates `(round_constants, mds, mds_inv)` for this specification.
    ///
    /// One row of `T` round constants is drawn per round (`full_rounds() + partial_rounds()`
    /// rows in total) from a `grain::Grain` instance. The two MDS matrices are then drawn
    /// from that same instance.
    ///
    /// NOTE(review): the Grain instance is always initialised with `SboxType::Pow`,
    /// regardless of what `sbox` actually computes.
    fn constants() -> (Vec<[F; T]>, Mds<F, T>, Mds<F, T>)
    where
        F: FromUniformBytes<64> + Ord,
    {
        let full = Self::full_rounds();
        let partial = Self::partial_rounds();
        let mut grain = grain::Grain::new(SboxType::Pow, T as u16, full as u16, partial as u16);

        let mut round_constants = Vec::with_capacity(full + partial);
        for _ in 0..full + partial {
            let mut row = [F::ZERO; T];
            for word in row.iter_mut() {
                *word = grain.next_field_element();
            }
            round_constants.push(row);
        }

        let (mds, mds_inv) = mds::generate_mds::<F, T>(&mut grain, Self::secure_mds());
        (round_constants, mds, mds_inv)
    }
}
/// Applies the Poseidon permutation to `state` in place.
///
/// The round schedule is:
/// 1. `S::full_rounds() / 2` full rounds;
/// 2. `S::partial_rounds()` partial rounds;
/// 3. another `S::full_rounds() / 2` full rounds.
///
/// Round `i` first adds `round_constants[i]` to the state. It then applies `S::sbox`: to
/// every word in a full round, and only to `state[0]` in a partial round. Finally it
/// multiplies the state by `mds`.
///
/// Rows of `round_constants` beyond the number of rounds are ignored.
///
/// # Panics
///
/// Panics if `round_constants` has fewer rows than there are rounds. Running silently
/// with fewer rounds would produce a weakened permutation.
pub(crate) fn permute<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
    state: &mut State<F, T>,
    mds: &Mds<F, T>,
    round_constants: &[[F; T]],
) {
    let r_f = S::full_rounds() / 2;
    let r_p = S::partial_rounds();
    let num_rounds = 2 * r_f + r_p;
    assert!(
        round_constants.len() >= num_rounds,
        "permute: expected at least {} rows of round constants, got {}",
        num_rounds,
        round_constants.len()
    );

    // state <- mds * state
    let apply_mds = |state: &mut State<F, T>| {
        let mut new_state = [F::ZERO; T];
        for (acc, row) in new_state.iter_mut().zip(mds.iter()) {
            for (coeff, word) in row.iter().zip(state.iter()) {
                *acc += *coeff * *word;
            }
        }
        *state = new_state;
    };

    // Add round constants, S-box every word, mix.
    let full_round = |state: &mut State<F, T>, rcs: &[F; T]| {
        for (word, rc) in state.iter_mut().zip(rcs.iter()) {
            *word = S::sbox(*word + rc);
        }
        apply_mds(state);
    };

    // Add round constants, S-box only the first word, mix.
    let part_round = |state: &mut State<F, T>, rcs: &[F; T]| {
        for (word, rc) in state.iter_mut().zip(rcs.iter()) {
            *word += rc;
        }
        state[0] = S::sbox(state[0]);
        apply_mds(state);
    };

    // The assert above guarantees these splits are in bounds.
    let (first_full, rest) = round_constants.split_at(r_f);
    let (partial, rest) = rest.split_at(r_p);
    let last_full = &rest[..r_f];

    for rcs in first_full {
        full_round(state, rcs);
    }
    for rcs in partial {
        part_round(state, rcs);
    }
    for rcs in last_full {
        full_round(state, rcs);
    }
}
/// Runs one sponge step and returns a fresh squeezing block.
///
/// The step optionally adds an absorbed rate block into the state, applies the
/// permutation, and returns the first `RATE` state words (always starting at `state[0]`).
///
/// `input`, when present, is the block to absorb together with the offset in `state` at
/// which its words are added.
///
/// # Panics
///
/// Panics if the offset exceeds `T - RATE`, or if any word of the absorbed block is
/// `None`. Callers must pad the input to a full block first.
fn poseidon_sponge<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
    state: &mut State<F, T>,
    input: Option<(&Absorbing<F, RATE>, usize)>,
    mds_matrix: &Mds<F, T>,
    round_constants: &[[F; T]],
) -> Squeezing<F, RATE> {
    if let Some((block, offset)) = input {
        assert!(offset <= T - RATE);
        let targets = state.iter_mut().skip(offset);
        for (target, word) in targets.zip(block.0.iter()) {
            *target += word.expect("poseidon_sponge is called with a padded input");
        }
    }

    permute::<F, S, T, RATE>(state, mds_matrix, round_constants);

    let mut squeezed = [None; RATE];
    for (word, out) in state.iter().zip(squeezed.iter_mut()) {
        *out = Some(*word);
    }
    Squeezing(squeezed)
}
// Sealed-trait pattern: `SealedSpongeMode` lives in a private module, so code outside the
// parent module cannot name it. That code therefore cannot implement `SpongeMode` either.
mod private {
    pub trait SealedSpongeMode {}
    impl<F, const RATE: usize> SealedSpongeMode for super::Absorbing<F, RATE> {}
    impl<F, const RATE: usize> SealedSpongeMode for super::Squeezing<F, RATE> {}
}
/// The mode of a `Sponge`: either `Absorbing` or `Squeezing`.
///
/// This trait is sealed; those two types are its only implementors.
pub trait SpongeMode: private::SealedSpongeMode {}
/// Absorbing mode of a `Sponge`: rate words buffered since the last permutation
/// (`None` = slot not yet filled).
#[derive(Debug)]
pub struct Absorbing<F, const RATE: usize>(pub(crate) SpongeRate<F, RATE>);
/// Squeezing mode of a `Sponge`: output words not yet returned by `squeeze`
/// (`None` = already returned).
#[derive(Debug)]
pub struct Squeezing<F, const RATE: usize>(pub(crate) SpongeRate<F, RATE>);
impl<F, const RATE: usize> SpongeMode for Absorbing<F, RATE> {}
impl<F, const RATE: usize> SpongeMode for Squeezing<F, RATE> {}
impl<F: fmt::Debug, const RATE: usize> Absorbing<F, RATE> {
pub(crate) fn init_with(val: F) -> Self {
Self(
iter::once(Some(val))
.chain((1..RATE).map(|_| None))
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
}
}
/// A Poseidon sponge, parameterised by its current mode `M` (`Absorbing` or `Squeezing`).
pub(crate) struct Sponge<
    F: PrimeField,
    S: Spec<F, T, RATE>,
    M: SpongeMode,
    const T: usize,
    const RATE: usize,
> {
    // Rate words buffered in the current mode (pending input or pending output).
    mode: M,
    // The full permutation state.
    state: State<F, T>,
    // MDS matrix passed to every permutation call.
    mds_matrix: Mds<F, T>,
    // Round constants, one row per permutation round.
    round_constants: Vec<[F; T]>,
    // Offset in `state` at which absorbed words are added. The capacity word is at
    // `state[(RATE + layout) % T]`.
    layout: usize,
    _marker: PhantomData<S>,
}
impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
    Sponge<F, S, Absorbing<F, RATE>, T, RATE>
{
    /// Constructs a sponge in absorbing mode.
    ///
    /// All state words start at zero except `state[(RATE + layout) % T]`, which is set to
    /// `initial_capacity_element`. `layout` is the offset at which absorbed words are added
    /// to the state. It must satisfy `layout <= T - RATE`, which is asserted when the first
    /// block is permuted.
    pub(crate) fn new(initial_capacity_element: F, layout: usize) -> Self
    where
        F: FromUniformBytes<64> + Ord,
    {
        let (round_constants, mds_matrix, _) = S::constants();
        let mut state = [F::ZERO; T];
        state[(RATE + layout) % T] = initial_capacity_element;
        Sponge {
            mode: Absorbing([None; RATE]),
            state,
            mds_matrix,
            round_constants,
            layout,
            _marker: PhantomData,
        }
    }

    /// Adds `capacity_element` to the capacity word `state[(RATE + layout) % T]`.
    pub(crate) fn update_capacity(&mut self, capacity_element: F) {
        self.state[(RATE + self.layout) % T] += capacity_element;
    }

    /// Absorbs one field element.
    ///
    /// The element goes into the first free rate slot. If the rate is already full, the
    /// buffered block is first permuted into the state and a new block holding `value` is
    /// started.
    pub(crate) fn absorb(&mut self, value: F) {
        if let Some(slot) = self.mode.0.iter_mut().find(|slot| slot.is_none()) {
            *slot = Some(value);
            return;
        }

        // The rate is full: fold it into the state. The squeezed output is not needed
        // while absorbing.
        let _ = poseidon_sponge::<F, S, T, RATE>(
            &mut self.state,
            Some((&self.mode, self.layout)),
            &self.mds_matrix,
            &self.round_constants,
        );
        self.mode = Absorbing::init_with(value);
    }

    /// Permutes the final buffered block into the state and switches to squeezing mode.
    ///
    /// # Panics
    ///
    /// Panics if the buffered block is not completely filled (the input was not padded).
    pub(crate) fn finish_absorbing(mut self) -> Sponge<F, S, Squeezing<F, RATE>, T, RATE> {
        let mode = poseidon_sponge::<F, S, T, RATE>(
            &mut self.state,
            Some((&self.mode, self.layout)),
            &self.mds_matrix,
            &self.round_constants,
        );

        Sponge {
            mode,
            state: self.state,
            mds_matrix: self.mds_matrix,
            round_constants: self.round_constants,
            layout: self.layout,
            _marker: PhantomData,
        }
    }
}
impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
    Sponge<F, S, Squeezing<F, RATE>, T, RATE>
{
    /// Returns the next output element.
    ///
    /// Whenever the buffered output has been used up, the state is permuted (without
    /// absorbing anything) to refill the buffer.
    pub(crate) fn squeeze(&mut self) -> F {
        loop {
            if let Some(word) = self.mode.0.iter_mut().find_map(Option::take) {
                return word;
            }
            self.mode = poseidon_sponge::<F, S, T, RATE>(
                &mut self.state,
                None,
                &self.mds_matrix,
                &self.round_constants,
            );
        }
    }
}
/// A domain in which a Poseidon hash function is used.
///
/// A domain fixes the padding, the initial capacity value and the absorption layout.
pub trait Domain<F: PrimeField, const RATE: usize> {
    /// Iterator of field elements appended to the message before absorbing.
    type Padding: IntoIterator<Item = F>;

    /// A human-readable name for the domain (shown, e.g., in `Hash`'s `Debug` output).
    fn name() -> String;

    /// The value placed in the sponge's capacity word when the hasher is initialised.
    fn initial_capacity_element() -> F;

    /// Returns the padding to append to a message of `input_len` elements.
    fn padding(input_len: usize) -> Self::Padding;

    /// Offset into the state at which message words are absorbed, given the state width.
    ///
    /// With the default of `0`, words are absorbed into `state[0..RATE]` and the sponge
    /// keeps its capacity word at `state[RATE % width]`.
    fn layout(_width: usize) -> usize {
        0
    }
}
/// Domain for hashing messages of exactly `L` field elements.
///
/// The capacity word starts at `L * 2^64`, and the message is zero-padded up to the next
/// multiple of `RATE`.
#[derive(Clone, Copy, Debug)]
pub struct ConstantLength<const L: usize>;

impl<F: PrimeField, const RATE: usize, const L: usize> Domain<F, RATE> for ConstantLength<L> {
    type Padding = iter::Take<iter::Repeat<F>>;

    fn name() -> String {
        format!("ConstantLength<{L}>")
    }

    fn initial_capacity_element() -> F {
        // Encode the message length in bits 64.. of a u128.
        let encoded_len = (L as u128) << 64;
        F::from_u128(encoded_len)
    }

    /// Zero padding up to the next multiple of `RATE`.
    ///
    /// # Panics
    ///
    /// Panics if `input_len != L`.
    fn padding(input_len: usize) -> Self::Padding {
        assert_eq!(input_len, L);
        let padded_len = ((L + RATE - 1) / RATE) * RATE;
        iter::repeat(F::ZERO).take(padded_len - L)
    }
}
/// Domain for hashing exactly `L` field elements in iden3's style.
///
/// Padding is identical to `ConstantLength<L>`. The capacity word starts at zero, and
/// message words are absorbed at offset `width - RATE` of the state.
#[derive(Clone, Copy, Debug)]
pub struct ConstantLengthIden3<const L: usize>;

impl<F: PrimeField, const RATE: usize, const L: usize> Domain<F, RATE> for ConstantLengthIden3<L> {
    type Padding = <ConstantLength<L> as Domain<F, RATE>>::Padding;

    fn name() -> String {
        let mut name = format!("ConstantLength<{L}>");
        name.push_str(" in iden3's style");
        name
    }

    fn initial_capacity_element() -> F {
        F::ZERO
    }

    fn padding(input_len: usize) -> Self::Padding {
        // Same length check and zero padding as the plain constant-length domain.
        <ConstantLength<L> as Domain<F, RATE>>::padding(input_len)
    }

    fn layout(width: usize) -> usize {
        // Absorb into the last RATE words. With the sponge's `(RATE + layout) % width`
        // rule, this puts the capacity word at index 0.
        width - RATE
    }
}
/// Domain for hashing variable-length messages in iden3's style.
///
/// The initial capacity element and layout are taken from `ConstantLengthIden3<1>`. The
/// message is zero-padded up to the next multiple of `RATE`.
#[derive(Clone, Copy, Debug)]
pub struct VariableLengthIden3;

impl<F: PrimeField, const RATE: usize> Domain<F, RATE> for VariableLengthIden3 {
    type Padding = <ConstantLength<1> as Domain<F, RATE>>::Padding;

    fn name() -> String {
        String::from("VariableLength in iden3's style")
    }

    fn initial_capacity_element() -> F {
        <ConstantLengthIden3<1> as Domain<F, RATE>>::initial_capacity_element()
    }

    fn padding(input_len: usize) -> Self::Padding {
        // Zeros needed to reach the next multiple of RATE (none if already aligned).
        let pad_len = (RATE - input_len % RATE) % RATE;
        iter::repeat(F::ZERO).take(pad_len)
    }

    fn layout(width: usize) -> usize {
        <ConstantLengthIden3<1> as Domain<F, RATE>>::layout(width)
    }
}
/// A Poseidon hash function, built around a sponge in absorbing mode.
///
/// `D` selects the domain (padding, initial capacity value, absorption layout) and which
/// `hash` method is available.
pub struct Hash<
    F: PrimeField,
    S: Spec<F, T, RATE>,
    D: Domain<F, RATE>,
    const T: usize,
    const RATE: usize,
> {
    // The underlying sponge; the `hash` methods take `self` by value and consume it.
    sponge: Sponge<F, S, Absorbing<F, RATE>, T, RATE>,
    _domain: PhantomData<D>,
}
impl<F: PrimeField, S: Spec<F, T, RATE>, D: Domain<F, RATE>, const T: usize, const RATE: usize>
    fmt::Debug for Hash<F, S, D, T, RATE>
{
    /// Shows the permutation parameters and the domain name; the sponge state is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (width, rate) = (T, RATE);
        let (full_rounds, partial_rounds) = (S::full_rounds(), S::partial_rounds());
        let domain = D::name();

        let mut out = f.debug_struct("Hash");
        out.field("width", &width);
        out.field("rate", &rate);
        out.field("R_F", &full_rounds);
        out.field("R_P", &partial_rounds);
        out.field("domain", &domain);
        out.finish()
    }
}
impl<F: PrimeField, S: Spec<F, T, RATE>, D: Domain<F, RATE>, const T: usize, const RATE: usize>
    Hash<F, S, D, T, RATE>
{
    /// Initializes a new hasher.
    ///
    /// The sponge's capacity word starts at `D::initial_capacity_element()`, and message
    /// words are absorbed at offset `D::layout(T)`.
    pub fn init() -> Self
    where
        F: FromUniformBytes<64> + Ord,
    {
        Hash {
            sponge: Sponge::new(D::initial_capacity_element(), D::layout(T)),
            _domain: PhantomData,
        }
    }

    /// Applies this hasher's Poseidon permutation to `state` in place.
    ///
    /// Uses the same MDS matrix and round constants as the sponge. The hasher's own sponge
    /// state is not touched.
    pub fn permute(&self, state: &mut [F; T]) {
        permute::<F, S, T, RATE>(state, &self.sponge.mds_matrix, &self.sponge.round_constants);
    }
}
impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize, const L: usize>
    Hash<F, S, ConstantLength<L>, T, RATE>
{
    /// Hashes `message` and returns the first squeezed element.
    ///
    /// The message is absorbed followed by `ConstantLength<L>`'s zero padding.
    pub fn hash(mut self, message: [F; L]) -> F {
        let padding = <ConstantLength<L> as Domain<F, RATE>>::padding(L);
        message
            .into_iter()
            .chain(padding)
            .for_each(|word| self.sponge.absorb(word));
        self.sponge.finish_absorbing().squeeze()
    }
}
impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize, const L: usize>
    Hash<F, S, ConstantLengthIden3<L>, T, RATE>
{
    /// Hashes `message` under the domain-separation value `domain`.
    ///
    /// `domain` is added to the sponge's capacity word. The message is then absorbed with
    /// zero padding to a multiple of `RATE`, and the first squeezed element is returned.
    pub fn hash(mut self, message: [F; L], domain: F) -> F {
        self.sponge.update_capacity(domain);

        let padded = message
            .into_iter()
            .chain(<ConstantLength<L> as Domain<F, RATE>>::padding(L));
        for word in padded {
            self.sponge.absorb(word);
        }

        self.sponge.finish_absorbing().squeeze()
    }
}
impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
    Hash<F, S, VariableLengthIden3, T, RATE>
{
    /// Hashes a variable-length `message`, first adding `cap` to the sponge's capacity word.
    ///
    /// The message is absorbed, followed by zero padding up to a multiple of `RATE`, and
    /// the first squeezed element is returned.
    ///
    /// NOTE(review): an empty `message` needs no padding. Nothing is absorbed, and finishing
    /// the sponge then panics on the unfilled block.
    pub fn hash_with_cap(mut self, message: &[F], cap: u128) -> F {
        self.sponge.update_capacity(F::from_u128(cap));

        message.iter().for_each(|&word| self.sponge.absorb(word));
        <VariableLengthIden3 as Domain<F, RATE>>::padding(message.len())
            .for_each(|zero| self.sponge.absorb(zero));

        self.sponge.finish_absorbing().squeeze()
    }
}
#[cfg(test)]
mod tests {
    use ff::PrimeField;

    use super::pasta::Fp;
    use super::{permute, ConstantLength, Hash, P128Pow5T3, P128Pow5T3Compact, Spec};

    type OrchardNullifier = P128Pow5T3<Fp>;

    /// The two-element message every test hashes.
    fn sample_message() -> [Fp; 2] {
        [Fp::from(6), Fp::from(42)]
    }

    /// State of a width-3 `ConstantLength<2>` sponge right before its single permutation.
    ///
    /// The message fills the rate words, and the length encoding `2 << 64` sits in the
    /// capacity word.
    fn absorbed_state(message: [Fp; 2]) -> [Fp; 3] {
        [message[0], message[1], Fp::from_u128(2 << 64)]
    }

    /// Hashing through `Hash` equals running `permute` by hand on the absorbed state.
    #[test]
    fn orchard_spec_equivalence() {
        let message = sample_message();
        let (round_constants, mds, _) = OrchardNullifier::constants();
        let digest = Hash::<_, OrchardNullifier, ConstantLength<2>, 3, 2>::init().hash(message);

        let mut state = absorbed_state(message);
        permute::<_, OrchardNullifier, 3, 2>(&mut state, &mds, &round_constants);
        assert_eq!(state[0], digest);
    }

    /// `Hash::permute` applies the same permutation the hasher uses internally.
    #[test]
    fn hasher_permute_equivalence() {
        let hasher = Hash::<_, OrchardNullifier, ConstantLength<2>, 3, 2>::init();
        let mut state = absorbed_state(sample_message());
        hasher.permute(&mut state);
        assert_eq!(state[0], hasher.hash(sample_message()));
    }

    /// The standard and compact `P128Pow5T3` specs yield the same hash.
    #[test]
    fn spec_equivalence() {
        let message = sample_message();
        let standard = Hash::<_, P128Pow5T3<Fp>, ConstantLength<2>, 3, 2>::init();
        let compact = Hash::<_, P128Pow5T3Compact<Fp>, ConstantLength<2>, 3, 2>::init();
        assert_eq!(standard.hash(message), compact.hash(message));
    }
}