p3_poseidon/
lib.rs
1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8
9use p3_field::{FieldAlgebra, PrimeField};
10use p3_mds::MdsPermutation;
11use p3_symmetric::{CryptographicPermutation, Permutation};
12use rand::distributions::Standard;
13use rand::prelude::Distribution;
14use rand::Rng;
15
/// Parameters for a Poseidon permutation over a state of `WIDTH` elements,
/// using `ALPHA` as the S-box exponent.
///
/// The permutation runs `half_num_full_rounds` full rounds, then
/// `num_partial_rounds` partial rounds, then `half_num_full_rounds` full
/// rounds again.
#[derive(Clone, Debug)]
pub struct Poseidon<F, Mds, const WIDTH: usize, const ALPHA: u64> {
    /// Number of full rounds run on each side of the partial rounds.
    half_num_full_rounds: usize,
    /// Number of partial rounds run between the two groups of full rounds.
    num_partial_rounds: usize,
    /// Round constants stored flat: round `r` uses the `WIDTH` entries
    /// starting at index `r * WIDTH`.
    constants: Vec<F>,
    /// Linear layer applied to the state at the end of every round.
    mds: Mds,
}
24
25impl<F, Mds, const WIDTH: usize, const ALPHA: u64> Poseidon<F, Mds, WIDTH, ALPHA>
26where
27 F: PrimeField,
28{
29 pub fn new(
34 half_num_full_rounds: usize,
35 num_partial_rounds: usize,
36 constants: Vec<F>,
37 mds: Mds,
38 ) -> Self {
39 let num_rounds = 2 * half_num_full_rounds + num_partial_rounds;
40 assert_eq!(constants.len(), WIDTH * num_rounds);
41 Self {
42 half_num_full_rounds,
43 num_partial_rounds,
44 constants,
45 mds,
46 }
47 }
48
49 pub fn new_from_rng<R: Rng>(
50 half_num_full_rounds: usize,
51 num_partial_rounds: usize,
52 mds: Mds,
53 rng: &mut R,
54 ) -> Self
55 where
56 Standard: Distribution<F>,
57 {
58 let num_rounds = 2 * half_num_full_rounds + num_partial_rounds;
59 let num_constants = WIDTH * num_rounds;
60 let constants = rng
61 .sample_iter(Standard)
62 .take(num_constants)
63 .collect::<Vec<_>>();
64 Self {
65 half_num_full_rounds,
66 num_partial_rounds,
67 constants,
68 mds,
69 }
70 }
71
72 fn half_full_rounds<FA>(&self, state: &mut [FA; WIDTH], round_ctr: &mut usize)
73 where
74 FA: FieldAlgebra<F = F>,
75 Mds: MdsPermutation<FA, WIDTH>,
76 {
77 for _ in 0..self.half_num_full_rounds {
78 self.constant_layer(state, *round_ctr);
79 Self::full_sbox_layer(state);
80 self.mds.permute_mut(state);
81 *round_ctr += 1;
82 }
83 }
84
85 fn partial_rounds<FA>(&self, state: &mut [FA; WIDTH], round_ctr: &mut usize)
86 where
87 FA: FieldAlgebra<F = F>,
88 Mds: MdsPermutation<FA, WIDTH>,
89 {
90 for _ in 0..self.num_partial_rounds {
91 self.constant_layer(state, *round_ctr);
92 Self::partial_sbox_layer(state);
93 self.mds.permute_mut(state);
94 *round_ctr += 1;
95 }
96 }
97
98 fn full_sbox_layer<FA>(state: &mut [FA; WIDTH])
99 where
100 FA: FieldAlgebra<F = F>,
101 {
102 for x in state.iter_mut() {
103 *x = x.exp_const_u64::<ALPHA>();
104 }
105 }
106
107 fn partial_sbox_layer<FA>(state: &mut [FA; WIDTH])
108 where
109 FA: FieldAlgebra<F = F>,
110 {
111 state[0] = state[0].exp_const_u64::<ALPHA>();
112 }
113
114 fn constant_layer<FA>(&self, state: &mut [FA; WIDTH], round: usize)
115 where
116 FA: FieldAlgebra<F = F>,
117 {
118 for (i, x) in state.iter_mut().enumerate() {
119 *x += FA::from_f(self.constants[round * WIDTH + i]);
120 }
121 }
122}
123
124impl<FA, Mds, const WIDTH: usize, const ALPHA: u64> Permutation<[FA; WIDTH]>
125 for Poseidon<FA::F, Mds, WIDTH, ALPHA>
126where
127 FA: FieldAlgebra,
128 FA::F: PrimeField,
129 Mds: MdsPermutation<FA, WIDTH>,
130{
131 fn permute_mut(&self, state: &mut [FA; WIDTH]) {
132 let mut round_ctr = 0;
133 self.half_full_rounds(state, &mut round_ctr);
134 self.partial_rounds(state, &mut round_ctr);
135 self.half_full_rounds(state, &mut round_ctr);
136 }
137}
138
139impl<FA, Mds, const WIDTH: usize, const ALPHA: u64> CryptographicPermutation<[FA; WIDTH]>
140 for Poseidon<FA::F, Mds, WIDTH, ALPHA>
141where
142 FA: FieldAlgebra,
143 FA::F: PrimeField,
144 Mds: MdsPermutation<FA, WIDTH>,
145{
146}