// poseidon_primitives/poseidon/primitives.rs

1//! The Poseidon algebraic hash function.
2
3use std::convert::TryInto;
4use std::fmt;
5use std::iter;
6use std::marker::PhantomData;
7
8use ff::{FromUniformBytes, PrimeField};
9
10//pub(crate) mod fp;
11//pub(crate) mod fq;
12pub(crate) mod grain;
13pub(crate) mod mds;
14
15mod fields;
16#[macro_use]
17mod binops;
18
19#[cfg(test)]
20pub(crate) mod bn256;
21#[cfg(test)]
22pub(crate) mod pasta;
23
24//#[cfg(test)]
25//pub(crate) mod test_vectors;
26
27mod p128pow5t3;
28mod p128pow5t3_compact;
29
30pub use p128pow5t3::P128Pow5T3;
31#[allow(unused_imports)]
32pub(crate) use p128pow5t3::P128Pow5T3Constants;
33pub use p128pow5t3_compact::P128Pow5T3Compact;
34
35use grain::SboxType;
36
/// Permutation state: one field element per word of the `WIDTH`-word state.
pub(crate) type State<Elem, const WIDTH: usize> = [Elem; WIDTH];

/// Sponge rate buffer: each of the `RATE` slots is either occupied (`Some`) or
/// still empty (`None`).
pub(crate) type SpongeRate<Elem, const RATE: usize> = [Option<Elem>; RATE];

/// A square `WIDTH × WIDTH` matrix of field elements; holds both the MDS matrix
/// and its inverse.
pub(crate) type Mds<Elem, const WIDTH: usize> = [[Elem; WIDTH]; WIDTH];
45
46/// A specification for a Poseidon permutation.
47pub trait Spec<F: PrimeField, const T: usize, const RATE: usize>: fmt::Debug {
48    /// The number of full rounds for this specification.
49    ///
50    /// This must be an even number.
51    fn full_rounds() -> usize;
52
53    /// The number of partial rounds for this specification.
54    fn partial_rounds() -> usize;
55
56    /// The S-box for this specification.
57    fn sbox(val: F) -> F;
58
59    /// Side-loaded index of the first correct and secure MDS that will be generated by
60    /// the reference implementation.
61    ///
62    /// This is used by the default implementation of [`Spec::constants`]. If you are
63    /// hard-coding the constants, you may leave this unimplemented.
64    fn secure_mds() -> usize;
65
66    /// Generates `(round_constants, mds, mds^-1)` corresponding to this specification.
67    fn constants() -> (Vec<[F; T]>, Mds<F, T>, Mds<F, T>)
68    where
69        F: FromUniformBytes<64> + Ord,
70    {
71        let r_f = Self::full_rounds();
72        let r_p = Self::partial_rounds();
73
74        let mut grain = grain::Grain::new(SboxType::Pow, T as u16, r_f as u16, r_p as u16);
75
76        let round_constants = (0..(r_f + r_p))
77            .map(|_| {
78                let mut rc_row = [F::ZERO; T];
79                for (rc, value) in rc_row
80                    .iter_mut()
81                    .zip((0..T).map(|_| grain.next_field_element()))
82                {
83                    *rc = value;
84                }
85                rc_row
86            })
87            .collect();
88
89        let (mds, mds_inv) = mds::generate_mds::<F, T>(&mut grain, Self::secure_mds());
90
91        (round_constants, mds, mds_inv)
92    }
93}
94
95/// Runs the Poseidon permutation on the given state.
96pub(crate) fn permute<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
97    state: &mut State<F, T>,
98    mds: &Mds<F, T>,
99    round_constants: &[[F; T]],
100) {
101    let r_f = S::full_rounds() / 2;
102    let r_p = S::partial_rounds();
103
104    let apply_mds = |state: &mut State<F, T>| {
105        let mut new_state = [F::ZERO; T];
106        // Matrix multiplication
107        #[allow(clippy::needless_range_loop)]
108        for i in 0..T {
109            for j in 0..T {
110                new_state[i] += mds[i][j] * state[j];
111            }
112        }
113        *state = new_state;
114    };
115
116    let full_round = |state: &mut State<F, T>, rcs: &[F; T]| {
117        for (word, rc) in state.iter_mut().zip(rcs.iter()) {
118            *word = S::sbox(*word + rc);
119        }
120        apply_mds(state);
121    };
122
123    let part_round = |state: &mut State<F, T>, rcs: &[F; T]| {
124        for (word, rc) in state.iter_mut().zip(rcs.iter()) {
125            *word += rc;
126        }
127        // In a partial round, the S-box is only applied to the first state word.
128        state[0] = S::sbox(state[0]);
129        apply_mds(state);
130    };
131
132    iter::empty()
133        .chain(iter::repeat(&full_round as &dyn Fn(&mut State<F, T>, &[F; T])).take(r_f))
134        .chain(iter::repeat(&part_round as &dyn Fn(&mut State<F, T>, &[F; T])).take(r_p))
135        .chain(iter::repeat(&full_round as &dyn Fn(&mut State<F, T>, &[F; T])).take(r_f))
136        .zip(round_constants.iter())
137        .fold(state, |state, (round, rcs)| {
138            round(state, rcs);
139            state
140        });
141}
142
143fn poseidon_sponge<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
144    state: &mut State<F, T>,
145    input: Option<(&Absorbing<F, RATE>, usize)>,
146    mds_matrix: &Mds<F, T>,
147    round_constants: &[[F; T]],
148) -> Squeezing<F, RATE> {
149    if let Some((Absorbing(input), layout_offset)) = input {
150        assert!(layout_offset <= T - RATE);
151        // `Iterator::zip` short-circuits when one iterator completes, so this will only
152        // mutate the rate portion of the state.
153        for (word, value) in state.iter_mut().skip(layout_offset).zip(input.iter()) {
154            *word += value.expect("poseidon_sponge is called with a padded input");
155        }
156    }
157
158    permute::<F, S, T, RATE>(state, mds_matrix, round_constants);
159
160    let mut output = [None; RATE];
161    for (word, value) in output.iter_mut().zip(state.iter()) {
162        *word = Some(*value);
163    }
164    Squeezing(output)
165}
166
167mod private {
168    pub trait SealedSpongeMode {}
169    impl<F, const RATE: usize> SealedSpongeMode for super::Absorbing<F, RATE> {}
170    impl<F, const RATE: usize> SealedSpongeMode for super::Squeezing<F, RATE> {}
171}
172
173/// The state of the `Sponge`.
174pub trait SpongeMode: private::SealedSpongeMode {}
175
176/// The absorbing state of the `Sponge`.
177#[derive(Debug)]
178pub struct Absorbing<F, const RATE: usize>(pub(crate) SpongeRate<F, RATE>);
179
180/// The squeezing state of the `Sponge`.
181#[derive(Debug)]
182pub struct Squeezing<F, const RATE: usize>(pub(crate) SpongeRate<F, RATE>);
183
184impl<F, const RATE: usize> SpongeMode for Absorbing<F, RATE> {}
185impl<F, const RATE: usize> SpongeMode for Squeezing<F, RATE> {}
186
187impl<F: fmt::Debug, const RATE: usize> Absorbing<F, RATE> {
188    pub(crate) fn init_with(val: F) -> Self {
189        Self(
190            iter::once(Some(val))
191                .chain((1..RATE).map(|_| None))
192                .collect::<Vec<_>>()
193                .try_into()
194                .unwrap(),
195        )
196    }
197}
198
199/// A Poseidon sponge.
200pub(crate) struct Sponge<
201    F: PrimeField,
202    S: Spec<F, T, RATE>,
203    M: SpongeMode,
204    const T: usize,
205    const RATE: usize,
206> {
207    mode: M,
208    state: State<F, T>,
209    mds_matrix: Mds<F, T>,
210    round_constants: Vec<[F; T]>,
211    layout: usize,
212    _marker: PhantomData<S>,
213}
214
215impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
216    Sponge<F, S, Absorbing<F, RATE>, T, RATE>
217{
218    /// Constructs a new sponge for the given Poseidon specification.
219    pub(crate) fn new(initial_capacity_element: F, layout: usize) -> Self
220    where
221        F: FromUniformBytes<64> + Ord,
222    {
223        let (round_constants, mds_matrix, _) = S::constants();
224
225        let mode = Absorbing([None; RATE]);
226        let mut state = [F::ZERO; T];
227        state[(RATE + layout) % T] = initial_capacity_element;
228
229        Sponge {
230            mode,
231            state,
232            mds_matrix,
233            round_constants,
234            layout,
235            _marker: PhantomData,
236        }
237    }
238
239    /// add the capacity into current position of output
240    pub(crate) fn update_capacity(&mut self, capacity_element: F) {
241        self.state[(RATE + self.layout) % T] += capacity_element;
242    }
243
244    /// Absorbs an element into the sponge.
245    pub(crate) fn absorb(&mut self, value: F) {
246        for entry in self.mode.0.iter_mut() {
247            if entry.is_none() {
248                *entry = Some(value);
249                return;
250            }
251        }
252
253        // We've already absorbed as many elements as we can
254        let _ = poseidon_sponge::<F, S, T, RATE>(
255            &mut self.state,
256            Some((&self.mode, self.layout)),
257            &self.mds_matrix,
258            &self.round_constants,
259        );
260        self.mode = Absorbing::init_with(value);
261    }
262
263    /// Transitions the sponge into its squeezing state.
264    pub(crate) fn finish_absorbing(mut self) -> Sponge<F, S, Squeezing<F, RATE>, T, RATE> {
265        let mode = poseidon_sponge::<F, S, T, RATE>(
266            &mut self.state,
267            Some((&self.mode, self.layout)),
268            &self.mds_matrix,
269            &self.round_constants,
270        );
271
272        Sponge {
273            mode,
274            state: self.state,
275            mds_matrix: self.mds_matrix,
276            round_constants: self.round_constants,
277            layout: self.layout,
278            _marker: PhantomData,
279        }
280    }
281}
282
283impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
284    Sponge<F, S, Squeezing<F, RATE>, T, RATE>
285{
286    /// Squeezes an element from the sponge.
287    pub(crate) fn squeeze(&mut self) -> F {
288        loop {
289            for entry in self.mode.0.iter_mut() {
290                if let Some(e) = entry.take() {
291                    return e;
292                }
293            }
294
295            // We've already squeezed out all available elements
296            self.mode = poseidon_sponge::<F, S, T, RATE>(
297                &mut self.state,
298                None,
299                &self.mds_matrix,
300                &self.round_constants,
301            );
302        }
303    }
304}
305
306/// A domain in which a Poseidon hash function is being used.
307pub trait Domain<F: PrimeField, const RATE: usize> {
308    /// Iterator that outputs padding field elements.
309    type Padding: IntoIterator<Item = F>;
310
311    /// The name of this domain, for debug formatting purposes.
312    fn name() -> String;
313
314    /// The initial capacity element, encoding this domain.
315    fn initial_capacity_element() -> F;
316
317    /// Returns the padding to be appended to the input.
318    fn padding(input_len: usize) -> Self::Padding;
319
320    /// Set the position of inputs in state: how many fields
321    /// of offset the first input should be put in, for iden3,
322    /// inputs are right aligned in the state array
323    fn layout(_width: usize) -> usize {
324        0
325    }
326}
327
328/// A Poseidon hash function used with constant input length.
329///
330/// Domain specified in [ePrint 2019/458 section 4.2](https://eprint.iacr.org/2019/458.pdf).
331#[derive(Clone, Copy, Debug)]
332pub struct ConstantLength<const L: usize>;
333
334impl<F: PrimeField, const RATE: usize, const L: usize> Domain<F, RATE> for ConstantLength<L> {
335    type Padding = iter::Take<iter::Repeat<F>>;
336
337    fn name() -> String {
338        format!("ConstantLength<{L}>")
339    }
340
341    fn initial_capacity_element() -> F {
342        // Capacity value is $length \cdot 2^64 + (o-1)$ where o is the output length.
343        // We hard-code an output length of 1.
344        F::from_u128((L as u128) << 64)
345    }
346
347    fn padding(input_len: usize) -> Self::Padding {
348        assert_eq!(input_len, L);
349        // For constant-input-length hashing, we pad the input with zeroes to a multiple
350        // of RATE. On its own this would not be sponge-compliant padding, but the
351        // Poseidon authors encode the constant length into the capacity element, ensuring
352        // that inputs of different lengths do not share the same permutation.
353        let k = L.div_ceil(RATE);
354        iter::repeat(F::ZERO).take(k * RATE - L)
355    }
356}
357
358/// A Poseidon hash function used with constant input length, this is iden3's specifications
359#[derive(Clone, Copy, Debug)]
360pub struct ConstantLengthIden3<const L: usize>;
361
362impl<F: PrimeField, const RATE: usize, const L: usize> Domain<F, RATE> for ConstantLengthIden3<L> {
363    type Padding = <ConstantLength<L> as Domain<F, RATE>>::Padding;
364
365    fn name() -> String {
366        format!("ConstantLength<{L}> in iden3's style")
367    }
368
369    // iden3's scheme do not set any capacity mark
370    fn initial_capacity_element() -> F {
371        F::ZERO
372    }
373
374    fn padding(input_len: usize) -> Self::Padding {
375        <ConstantLength<L> as Domain<F, RATE>>::padding(input_len)
376    }
377
378    fn layout(width: usize) -> usize {
379        width - RATE
380    }
381}
382
383/// A Poseidon hash function used with variable input length, this is iden3's specifications
384#[derive(Clone, Copy, Debug)]
385pub struct VariableLengthIden3;
386
387impl<F: PrimeField, const RATE: usize> Domain<F, RATE> for VariableLengthIden3 {
388    type Padding = <ConstantLength<1> as Domain<F, RATE>>::Padding;
389
390    fn name() -> String {
391        "VariableLength in iden3's style".to_string()
392    }
393
394    // iden3's scheme do not set any capacity mark
395    fn initial_capacity_element() -> F {
396        <ConstantLengthIden3<1> as Domain<F, RATE>>::initial_capacity_element()
397    }
398
399    fn padding(input_len: usize) -> Self::Padding {
400        let k = input_len % RATE;
401        iter::repeat(F::ZERO).take(if k == 0 { 0 } else { RATE - k })
402    }
403
404    fn layout(width: usize) -> usize {
405        <ConstantLengthIden3<1> as Domain<F, RATE>>::layout(width)
406    }
407}
408
409/// A Poseidon hash function, built around a sponge.
410pub struct Hash<
411    F: PrimeField,
412    S: Spec<F, T, RATE>,
413    D: Domain<F, RATE>,
414    const T: usize,
415    const RATE: usize,
416> {
417    sponge: Sponge<F, S, Absorbing<F, RATE>, T, RATE>,
418    _domain: PhantomData<D>,
419}
420
421impl<F: PrimeField, S: Spec<F, T, RATE>, D: Domain<F, RATE>, const T: usize, const RATE: usize>
422    fmt::Debug for Hash<F, S, D, T, RATE>
423{
424    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
425        f.debug_struct("Hash")
426            .field("width", &T)
427            .field("rate", &RATE)
428            .field("R_F", &S::full_rounds())
429            .field("R_P", &S::partial_rounds())
430            .field("domain", &D::name())
431            .finish()
432    }
433}
434
435impl<F: PrimeField, S: Spec<F, T, RATE>, D: Domain<F, RATE>, const T: usize, const RATE: usize>
436    Hash<F, S, D, T, RATE>
437{
438    /// Initializes a new hasher.
439    pub fn init() -> Self
440    where
441        F: FromUniformBytes<64> + Ord,
442    {
443        Hash {
444            sponge: Sponge::new(D::initial_capacity_element(), D::layout(T)),
445            _domain: PhantomData,
446        }
447    }
448
449    /// help permute a state
450    pub fn permute(&self, state: &mut [F; T]) {
451        permute::<F, S, T, RATE>(state, &self.sponge.mds_matrix, &self.sponge.round_constants);
452    }
453}
454
455impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize, const L: usize>
456    Hash<F, S, ConstantLength<L>, T, RATE>
457{
458    /// Hashes the given input.
459    pub fn hash(mut self, message: [F; L]) -> F {
460        for value in message
461            .into_iter()
462            .chain(<ConstantLength<L> as Domain<F, RATE>>::padding(L))
463        {
464            self.sponge.absorb(value);
465        }
466        self.sponge.finish_absorbing().squeeze()
467    }
468}
469
470impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize, const L: usize>
471    Hash<F, S, ConstantLengthIden3<L>, T, RATE>
472{
473    /// Hashes the given input.
474    pub fn hash(mut self, message: [F; L], domain: F) -> F {
475        // notice iden3 domain has no initial capacity element so the domain is updated here
476        self.sponge.update_capacity(domain);
477        for value in message
478            .into_iter()
479            .chain(<ConstantLength<L> as Domain<F, RATE>>::padding(L))
480        {
481            self.sponge.absorb(value);
482        }
483        self.sponge.finish_absorbing().squeeze()
484    }
485}
486
487impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
488    Hash<F, S, VariableLengthIden3, T, RATE>
489{
490    /// Hashes the given input.
491    pub fn hash_with_cap(mut self, message: &[F], cap: u128) -> F {
492        self.sponge.update_capacity(F::from_u128(cap));
493        for value in message {
494            self.sponge.absorb(*value);
495        }
496
497        for pad in <VariableLengthIden3 as Domain<F, RATE>>::padding(message.len()) {
498            self.sponge.absorb(pad);
499        }
500
501        self.sponge.finish_absorbing().squeeze()
502    }
503}
504
#[cfg(test)]
mod tests {
    use ff::PrimeField;

    use super::pasta::Fp;

    use super::{permute, ConstantLength, Hash, P128Pow5T3, P128Pow5T3Compact, Spec};
    type OrchardNullifier = P128Pow5T3<Fp>;

    /// Capacity word `ConstantLength<2>` seeds the sponge with: `2 << 64`.
    fn constant_length_2_capacity() -> Fp {
        Fp::from_u128(2 << 64)
    }

    #[test]
    fn orchard_spec_equivalence() {
        let message = [Fp::from(6), Fp::from(42)];
        let (round_constants, mds, _) = OrchardNullifier::constants();

        let digest = Hash::<_, OrchardNullifier, ConstantLength<2>, 3, 2>::init().hash(message);

        // With exactly one full block, hashing is a single permutation of
        // [m0, m1, capacity], and the digest is the first word of the result.
        let mut expected = [message[0], message[1], constant_length_2_capacity()];
        permute::<_, OrchardNullifier, 3, 2>(&mut expected, &mds, &round_constants);
        assert_eq!(expected[0], digest);
    }

    #[test]
    fn hasher_permute_equivalence() {
        let message = [Fp::from(6), Fp::from(42)];
        let hasher = Hash::<_, OrchardNullifier, ConstantLength<2>, 3, 2>::init();

        // Running the hasher's own permutation on the hand-built initial state must
        // agree with the full hash.
        let mut state = [message[0], message[1], constant_length_2_capacity()];
        hasher.permute(&mut state);

        assert_eq!(state[0], hasher.hash(message));
    }

    #[test]
    fn spec_equivalence() {
        let message = [Fp::from(6), Fp::from(42)];

        let standard =
            Hash::<_, P128Pow5T3<Fp>, ConstantLength<2>, 3, 2>::init().hash(message);
        let compact =
            Hash::<_, P128Pow5T3Compact<Fp>, ConstantLength<2>, 3, 2>::init().hash(message);

        assert_eq!(standard, compact);
    }
}