1use std::convert::TryInto;
4use std::fmt;
5use std::iter;
6use std::marker::PhantomData;
7
8use ff::{FromUniformBytes, PrimeField};
9
10pub(crate) mod grain;
13pub(crate) mod mds;
14
15mod fields;
16#[macro_use]
17mod binops;
18
19#[cfg(test)]
20pub(crate) mod bn256;
21#[cfg(test)]
22pub(crate) mod pasta;
23
24mod p128pow5t3;
28mod p128pow5t3_compact;
29
30pub use p128pow5t3::P128Pow5T3;
31#[allow(unused_imports)]
32pub(crate) use p128pow5t3::P128Pow5T3Constants;
33pub use p128pow5t3_compact::P128Pow5T3Compact;
34
35use grain::SboxType;
36
/// The internal state of the Poseidon permutation: `T` field elements.
pub(crate) type State<F, const T: usize> = [F; T];

/// The rate portion of the sponge: `RATE` slots. While absorbing, a slot is `None`
/// until it is filled; while squeezing, a slot becomes `None` once its output is taken.
pub(crate) type SpongeRate<F, const RATE: usize> = [Option<F>; RATE];

/// A `T × T` matrix of field elements (used for the MDS matrix and its inverse).
pub(crate) type Mds<F, const T: usize> = [[F; T]; T];
45
46pub trait Spec<F: PrimeField, const T: usize, const RATE: usize>: fmt::Debug {
48 fn full_rounds() -> usize;
52
53 fn partial_rounds() -> usize;
55
56 fn sbox(val: F) -> F;
58
59 fn secure_mds() -> usize;
65
66 fn constants() -> (Vec<[F; T]>, Mds<F, T>, Mds<F, T>)
68 where
69 F: FromUniformBytes<64> + Ord,
70 {
71 let r_f = Self::full_rounds();
72 let r_p = Self::partial_rounds();
73
74 let mut grain = grain::Grain::new(SboxType::Pow, T as u16, r_f as u16, r_p as u16);
75
76 let round_constants = (0..(r_f + r_p))
77 .map(|_| {
78 let mut rc_row = [F::ZERO; T];
79 for (rc, value) in rc_row
80 .iter_mut()
81 .zip((0..T).map(|_| grain.next_field_element()))
82 {
83 *rc = value;
84 }
85 rc_row
86 })
87 .collect();
88
89 let (mds, mds_inv) = mds::generate_mds::<F, T>(&mut grain, Self::secure_mds());
90
91 (round_constants, mds, mds_inv)
92 }
93}
94
95pub(crate) fn permute<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
97 state: &mut State<F, T>,
98 mds: &Mds<F, T>,
99 round_constants: &[[F; T]],
100) {
101 let r_f = S::full_rounds() / 2;
102 let r_p = S::partial_rounds();
103
104 let apply_mds = |state: &mut State<F, T>| {
105 let mut new_state = [F::ZERO; T];
106 #[allow(clippy::needless_range_loop)]
108 for i in 0..T {
109 for j in 0..T {
110 new_state[i] += mds[i][j] * state[j];
111 }
112 }
113 *state = new_state;
114 };
115
116 let full_round = |state: &mut State<F, T>, rcs: &[F; T]| {
117 for (word, rc) in state.iter_mut().zip(rcs.iter()) {
118 *word = S::sbox(*word + rc);
119 }
120 apply_mds(state);
121 };
122
123 let part_round = |state: &mut State<F, T>, rcs: &[F; T]| {
124 for (word, rc) in state.iter_mut().zip(rcs.iter()) {
125 *word += rc;
126 }
127 state[0] = S::sbox(state[0]);
129 apply_mds(state);
130 };
131
132 iter::empty()
133 .chain(iter::repeat(&full_round as &dyn Fn(&mut State<F, T>, &[F; T])).take(r_f))
134 .chain(iter::repeat(&part_round as &dyn Fn(&mut State<F, T>, &[F; T])).take(r_p))
135 .chain(iter::repeat(&full_round as &dyn Fn(&mut State<F, T>, &[F; T])).take(r_f))
136 .zip(round_constants.iter())
137 .fold(state, |state, (round, rcs)| {
138 round(state, rcs);
139 state
140 });
141}
142
143fn poseidon_sponge<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
144 state: &mut State<F, T>,
145 input: Option<(&Absorbing<F, RATE>, usize)>,
146 mds_matrix: &Mds<F, T>,
147 round_constants: &[[F; T]],
148) -> Squeezing<F, RATE> {
149 if let Some((Absorbing(input), layout_offset)) = input {
150 assert!(layout_offset <= T - RATE);
151 for (word, value) in state.iter_mut().skip(layout_offset).zip(input.iter()) {
154 *word += value.expect("poseidon_sponge is called with a padded input");
155 }
156 }
157
158 permute::<F, S, T, RATE>(state, mds_matrix, round_constants);
159
160 let mut output = [None; RATE];
161 for (word, value) in output.iter_mut().zip(state.iter()) {
162 *word = Some(*value);
163 }
164 Squeezing(output)
165}
166
mod private {
    // Sealing supertrait. Because this module is private, `SealedSpongeMode` (and thus
    // `SpongeMode`) cannot be implemented outside this module; only the two modes
    // below qualify.
    pub trait SealedSpongeMode {}
    impl<F, const RATE: usize> SealedSpongeMode for super::Absorbing<F, RATE> {}
    impl<F, const RATE: usize> SealedSpongeMode for super::Squeezing<F, RATE> {}
}

/// The mode a sponge is in: either `Absorbing` or `Squeezing`. This is a sealed trait.
pub trait SpongeMode: private::SealedSpongeMode {}

/// Absorbing-mode rate buffer. Slots are filled left to right, and `None` marks an
/// empty slot.
#[derive(Debug)]
pub struct Absorbing<F, const RATE: usize>(pub(crate) SpongeRate<F, RATE>);

/// Squeezing-mode rate buffer. `Some` slots are outputs not yet returned by `squeeze`.
#[derive(Debug)]
pub struct Squeezing<F, const RATE: usize>(pub(crate) SpongeRate<F, RATE>);

impl<F, const RATE: usize> SpongeMode for Absorbing<F, RATE> {}
impl<F, const RATE: usize> SpongeMode for Squeezing<F, RATE> {}
186
187impl<F: fmt::Debug, const RATE: usize> Absorbing<F, RATE> {
188 pub(crate) fn init_with(val: F) -> Self {
189 Self(
190 iter::once(Some(val))
191 .chain((1..RATE).map(|_| None))
192 .collect::<Vec<_>>()
193 .try_into()
194 .unwrap(),
195 )
196 }
197}
198
/// A Poseidon sponge over `F` using spec `S`, currently in mode `M` (absorbing or
/// squeezing), with state width `T` and rate `RATE`.
pub(crate) struct Sponge<
    F: PrimeField,
    S: Spec<F, T, RATE>,
    M: SpongeMode,
    const T: usize,
    const RATE: usize,
> {
    // Rate buffer for the current mode.
    mode: M,
    // The full permutation state.
    state: State<F, T>,
    // MDS matrix obtained from `S::constants()`.
    mds_matrix: Mds<F, T>,
    // Per-round constants obtained from `S::constants()`.
    round_constants: Vec<[F; T]>,
    // State index at which absorbed blocks are added. The capacity element sits at
    // `(RATE + layout) % T`.
    layout: usize,
    _marker: PhantomData<S>,
}
214
215impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
216 Sponge<F, S, Absorbing<F, RATE>, T, RATE>
217{
218 pub(crate) fn new(initial_capacity_element: F, layout: usize) -> Self
220 where
221 F: FromUniformBytes<64> + Ord,
222 {
223 let (round_constants, mds_matrix, _) = S::constants();
224
225 let mode = Absorbing([None; RATE]);
226 let mut state = [F::ZERO; T];
227 state[(RATE + layout) % T] = initial_capacity_element;
228
229 Sponge {
230 mode,
231 state,
232 mds_matrix,
233 round_constants,
234 layout,
235 _marker: PhantomData,
236 }
237 }
238
239 pub(crate) fn update_capacity(&mut self, capacity_element: F) {
241 self.state[(RATE + self.layout) % T] += capacity_element;
242 }
243
244 pub(crate) fn absorb(&mut self, value: F) {
246 for entry in self.mode.0.iter_mut() {
247 if entry.is_none() {
248 *entry = Some(value);
249 return;
250 }
251 }
252
253 let _ = poseidon_sponge::<F, S, T, RATE>(
255 &mut self.state,
256 Some((&self.mode, self.layout)),
257 &self.mds_matrix,
258 &self.round_constants,
259 );
260 self.mode = Absorbing::init_with(value);
261 }
262
263 pub(crate) fn finish_absorbing(mut self) -> Sponge<F, S, Squeezing<F, RATE>, T, RATE> {
265 let mode = poseidon_sponge::<F, S, T, RATE>(
266 &mut self.state,
267 Some((&self.mode, self.layout)),
268 &self.mds_matrix,
269 &self.round_constants,
270 );
271
272 Sponge {
273 mode,
274 state: self.state,
275 mds_matrix: self.mds_matrix,
276 round_constants: self.round_constants,
277 layout: self.layout,
278 _marker: PhantomData,
279 }
280 }
281}
282
283impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
284 Sponge<F, S, Squeezing<F, RATE>, T, RATE>
285{
286 pub(crate) fn squeeze(&mut self) -> F {
288 loop {
289 for entry in self.mode.0.iter_mut() {
290 if let Some(e) = entry.take() {
291 return e;
292 }
293 }
294
295 self.mode = poseidon_sponge::<F, S, T, RATE>(
297 &mut self.state,
298 None,
299 &self.mds_matrix,
300 &self.round_constants,
301 );
302 }
303 }
304}
305
/// A hashing domain. It determines the initial capacity element, the padding appended
/// to messages, and where absorbed input is placed in the state.
pub trait Domain<F: PrimeField, const RATE: usize> {
    /// Iterator over the field elements appended to a message as padding.
    type Padding: IntoIterator<Item = F>;

    /// Human-readable name of the domain (shown in `Hash`'s `Debug` output).
    fn name() -> String;

    /// The value the sponge's capacity element is initialized to.
    fn initial_capacity_element() -> F;

    /// Returns the padding to append to a message of `input_len` elements.
    fn padding(input_len: usize) -> Self::Padding;

    /// State index (for a state of `_width` words) at which absorbed input is added.
    /// The default of `0` puts input in the leading words, with the capacity element at
    /// index `RATE % width`.
    fn layout(_width: usize) -> usize {
        0
    }
}
327
328#[derive(Clone, Copy, Debug)]
332pub struct ConstantLength<const L: usize>;
333
334impl<F: PrimeField, const RATE: usize, const L: usize> Domain<F, RATE> for ConstantLength<L> {
335 type Padding = iter::Take<iter::Repeat<F>>;
336
337 fn name() -> String {
338 format!("ConstantLength<{L}>")
339 }
340
341 fn initial_capacity_element() -> F {
342 F::from_u128((L as u128) << 64)
345 }
346
347 fn padding(input_len: usize) -> Self::Padding {
348 assert_eq!(input_len, L);
349 let k = L.div_ceil(RATE);
354 iter::repeat(F::ZERO).take(k * RATE - L)
355 }
356}
357
358#[derive(Clone, Copy, Debug)]
360pub struct ConstantLengthIden3<const L: usize>;
361
362impl<F: PrimeField, const RATE: usize, const L: usize> Domain<F, RATE> for ConstantLengthIden3<L> {
363 type Padding = <ConstantLength<L> as Domain<F, RATE>>::Padding;
364
365 fn name() -> String {
366 format!("ConstantLength<{L}> in iden3's style")
367 }
368
369 fn initial_capacity_element() -> F {
371 F::ZERO
372 }
373
374 fn padding(input_len: usize) -> Self::Padding {
375 <ConstantLength<L> as Domain<F, RATE>>::padding(input_len)
376 }
377
378 fn layout(width: usize) -> usize {
379 width - RATE
380 }
381}
382
383#[derive(Clone, Copy, Debug)]
385pub struct VariableLengthIden3;
386
387impl<F: PrimeField, const RATE: usize> Domain<F, RATE> for VariableLengthIden3 {
388 type Padding = <ConstantLength<1> as Domain<F, RATE>>::Padding;
389
390 fn name() -> String {
391 "VariableLength in iden3's style".to_string()
392 }
393
394 fn initial_capacity_element() -> F {
396 <ConstantLengthIden3<1> as Domain<F, RATE>>::initial_capacity_element()
397 }
398
399 fn padding(input_len: usize) -> Self::Padding {
400 let k = input_len % RATE;
401 iter::repeat(F::ZERO).take(if k == 0 { 0 } else { RATE - k })
402 }
403
404 fn layout(width: usize) -> usize {
405 <ConstantLengthIden3<1> as Domain<F, RATE>>::layout(width)
406 }
407}
408
/// A Poseidon hash function using spec `S` and domain `D`, with state width `T` and
/// rate `RATE`.
pub struct Hash<
    F: PrimeField,
    S: Spec<F, T, RATE>,
    D: Domain<F, RATE>,
    const T: usize,
    const RATE: usize,
> {
    // Absorbing-mode sponge, created in `init` from `D`'s capacity element and layout.
    sponge: Sponge<F, S, Absorbing<F, RATE>, T, RATE>,
    _domain: PhantomData<D>,
}
420
421impl<F: PrimeField, S: Spec<F, T, RATE>, D: Domain<F, RATE>, const T: usize, const RATE: usize>
422 fmt::Debug for Hash<F, S, D, T, RATE>
423{
424 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
425 f.debug_struct("Hash")
426 .field("width", &T)
427 .field("rate", &RATE)
428 .field("R_F", &S::full_rounds())
429 .field("R_P", &S::partial_rounds())
430 .field("domain", &D::name())
431 .finish()
432 }
433}
434
435impl<F: PrimeField, S: Spec<F, T, RATE>, D: Domain<F, RATE>, const T: usize, const RATE: usize>
436 Hash<F, S, D, T, RATE>
437{
438 pub fn init() -> Self
440 where
441 F: FromUniformBytes<64> + Ord,
442 {
443 Hash {
444 sponge: Sponge::new(D::initial_capacity_element(), D::layout(T)),
445 _domain: PhantomData,
446 }
447 }
448
449 pub fn permute(&self, state: &mut [F; T]) {
451 permute::<F, S, T, RATE>(state, &self.sponge.mds_matrix, &self.sponge.round_constants);
452 }
453}
454
455impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize, const L: usize>
456 Hash<F, S, ConstantLength<L>, T, RATE>
457{
458 pub fn hash(mut self, message: [F; L]) -> F {
460 for value in message
461 .into_iter()
462 .chain(<ConstantLength<L> as Domain<F, RATE>>::padding(L))
463 {
464 self.sponge.absorb(value);
465 }
466 self.sponge.finish_absorbing().squeeze()
467 }
468}
469
470impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize, const L: usize>
471 Hash<F, S, ConstantLengthIden3<L>, T, RATE>
472{
473 pub fn hash(mut self, message: [F; L], domain: F) -> F {
475 self.sponge.update_capacity(domain);
477 for value in message
478 .into_iter()
479 .chain(<ConstantLength<L> as Domain<F, RATE>>::padding(L))
480 {
481 self.sponge.absorb(value);
482 }
483 self.sponge.finish_absorbing().squeeze()
484 }
485}
486
487impl<F: PrimeField, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>
488 Hash<F, S, VariableLengthIden3, T, RATE>
489{
490 pub fn hash_with_cap(mut self, message: &[F], cap: u128) -> F {
492 self.sponge.update_capacity(F::from_u128(cap));
493 for value in message {
494 self.sponge.absorb(*value);
495 }
496
497 for pad in <VariableLengthIden3 as Domain<F, RATE>>::padding(message.len()) {
498 self.sponge.absorb(pad);
499 }
500
501 self.sponge.finish_absorbing().squeeze()
502 }
503}
504
#[cfg(test)]
mod tests {
    // `PrimeField` brings `Fp::from_u128` into scope.
    use ff::PrimeField;

    use super::pasta::Fp;

    // `Spec` brings `OrchardNullifier::constants()` into scope.
    use super::{permute, ConstantLength, Hash, P128Pow5T3, P128Pow5T3Compact, Spec};
    type OrchardNullifier = P128Pow5T3<Fp>;

    /// Hashing a 2-element message with `ConstantLength<2>` must equal word 0 of a single
    /// raw permutation of the hand-built initial state.
    #[test]
    fn orchard_spec_equivalence() {
        let message = [Fp::from(6), Fp::from(42)];

        let (round_constants, mds, _) = OrchardNullifier::constants();

        let hasher = Hash::<_, OrchardNullifier, ConstantLength<2>, 3, 2>::init();
        let result = hasher.hash(message);

        // With L == RATE == 2 there is no padding. The state is the message followed by
        // the ConstantLength capacity element `L << 64`, at index RATE.
        let mut state = [message[0], message[1], Fp::from_u128(2 << 64)];
        permute::<_, OrchardNullifier, 3, 2>(&mut state, &mds, &round_constants);
        assert_eq!(state[0], result);
    }

    /// Same check as above, but using `Hash::permute` with the hasher's own constants.
    #[test]
    fn hasher_permute_equivalence() {
        let message = [Fp::from(6), Fp::from(42)];
        let hasher = Hash::<_, OrchardNullifier, ConstantLength<2>, 3, 2>::init();
        // Message words followed by the `ConstantLength<2>` capacity element `2 << 64`.
        let mut state = [Fp::from(6), Fp::from(42), Fp::from_u128(2 << 64)];

        hasher.permute(&mut state);

        let result = hasher.hash(message);
        assert_eq!(state[0], result);
    }

    /// `P128Pow5T3` and `P128Pow5T3Compact` must produce identical hashes.
    #[test]
    fn spec_equivalence() {
        let message = [Fp::from(6), Fp::from(42)];
        let hasher1 = Hash::<_, P128Pow5T3<Fp>, ConstantLength<2>, 3, 2>::init();
        let hasher2 = Hash::<_, P128Pow5T3Compact<Fp>, ConstantLength<2>, 3, 2>::init();

        let result1 = hasher1.hash(message);
        let result2 = hasher2.hash(message);
        assert_eq!(result1, result2);
    }
}