p3_poseidon2/
external.rs
1use alloc::vec::Vec;
2
3use p3_field::FieldAlgebra;
4use p3_mds::MdsPermutation;
5use p3_symmetric::Permutation;
6use rand::distributions::{Distribution, Standard};
7use rand::Rng;
8
9#[inline(always)]
17fn apply_hl_mat4<FA>(x: &mut [FA; 4])
18where
19 FA: FieldAlgebra,
20{
21 let t0 = x[0].clone() + x[1].clone();
22 let t1 = x[2].clone() + x[3].clone();
23 let t2 = x[1].clone() + x[1].clone() + t1.clone();
24 let t3 = x[3].clone() + x[3].clone() + t0.clone();
25 let t4 = t1.double().double() + t3.clone();
26 let t5 = t0.double().double() + t2.clone();
27 let t6 = t3 + t5.clone();
28 let t7 = t2 + t4.clone();
29 x[0] = t6;
30 x[1] = t5;
31 x[2] = t7;
32 x[3] = t4;
33}
34
35#[inline(always)]
43fn apply_mat4<FA>(x: &mut [FA; 4])
44where
45 FA: FieldAlgebra,
46{
47 let t01 = x[0].clone() + x[1].clone();
48 let t23 = x[2].clone() + x[3].clone();
49 let t0123 = t01.clone() + t23.clone();
50 let t01123 = t0123.clone() + x[1].clone();
51 let t01233 = t0123.clone() + x[3].clone();
52 x[3] = t01233.clone() + x[0].double(); x[1] = t01123.clone() + x[2].double(); x[0] = t01123 + t01; x[2] = t01233 + t23; }
58
59#[derive(Clone, Default)]
63pub struct HLMDSMat4;
64
65impl<FA: FieldAlgebra> Permutation<[FA; 4]> for HLMDSMat4 {
66 #[inline(always)]
67 fn permute_mut(&self, input: &mut [FA; 4]) {
68 apply_hl_mat4(input)
69 }
70}
71impl<FA: FieldAlgebra> MdsPermutation<FA, 4> for HLMDSMat4 {}
72
73#[derive(Clone, Default)]
77pub struct MDSMat4;
78
79impl<FA: FieldAlgebra> Permutation<[FA; 4]> for MDSMat4 {
80 #[inline(always)]
81 fn permute_mut(&self, input: &mut [FA; 4]) {
82 apply_mat4(input)
83 }
84}
85impl<FA: FieldAlgebra> MdsPermutation<FA, 4> for MDSMat4 {}
86
87#[inline(always)]
92pub fn mds_light_permutation<
93 FA: FieldAlgebra,
94 MdsPerm4: MdsPermutation<FA, 4>,
95 const WIDTH: usize,
96>(
97 state: &mut [FA; WIDTH],
98 mdsmat: &MdsPerm4,
99) {
100 match WIDTH {
101 2 => {
102 let sum = state[0].clone() + state[1].clone();
103 state[0] += sum.clone();
104 state[1] += sum;
105 }
106
107 3 => {
108 let sum = state[0].clone() + state[1].clone() + state[2].clone();
109 state[0] += sum.clone();
110 state[1] += sum.clone();
111 state[2] += sum;
112 }
113
114 4 | 8 | 12 | 16 | 20 | 24 => {
115 for chunk in state.chunks_exact_mut(4) {
118 mdsmat.permute_mut(chunk.try_into().unwrap());
119 }
120 let sums: [FA; 4] = core::array::from_fn(|k| {
124 (0..WIDTH)
125 .step_by(4)
126 .map(|j| state[j + k].clone())
127 .sum::<FA>()
128 });
129
130 state
133 .iter_mut()
134 .enumerate()
135 .for_each(|(i, elem)| *elem += sums[i % 4].clone());
136 }
137
138 _ => {
139 panic!("Unsupported width");
140 }
141 }
142}
143
/// Round constants for the external rounds of a Poseidon2 permutation, split into an
/// `initial` half and a `terminal` half.
///
/// The constructors on this type require both halves to contain the same number of rounds.
#[derive(Debug, Clone)]
pub struct ExternalLayerConstants<T, const WIDTH: usize> {
    /// One row of `WIDTH` constants per initial external round.
    initial: Vec<[T; WIDTH]>,
    /// One row of `WIDTH` constants per terminal external round.
    terminal: Vec<[T; WIDTH]>,
}
151
152impl<T, const WIDTH: usize> ExternalLayerConstants<T, WIDTH> {
153 pub fn new(initial: Vec<[T; WIDTH]>, terminal: Vec<[T; WIDTH]>) -> Self {
154 assert_eq!(
155 initial.len(),
156 terminal.len(),
157 "The number of initial and terminal external rounds should be equal."
158 );
159 Self { initial, terminal }
160 }
161
162 pub fn new_from_rng<R: Rng>(external_round_number: usize, rng: &mut R) -> Self
163 where
164 Standard: Distribution<[T; WIDTH]>,
165 {
166 let half_f = external_round_number / 2;
167 assert_eq!(
168 2 * half_f,
169 external_round_number,
170 "The total number of external rounds should be even"
171 );
172 let initial_constants = rng.sample_iter(Standard).take(half_f).collect();
173 let terminal_constants = rng.sample_iter(Standard).take(half_f).collect();
174
175 Self::new(initial_constants, terminal_constants)
176 }
177
178 pub fn new_from_saved_array<U, const N: usize>(
179 [initial, terminal]: [[[U; WIDTH]; N]; 2],
180 conversion_fn: fn([U; WIDTH]) -> [T; WIDTH],
181 ) -> Self
182 where
183 T: Clone,
184 {
185 let initial_consts = initial.map(conversion_fn).to_vec();
186 let terminal_consts = terminal.map(conversion_fn).to_vec();
187 Self::new(initial_consts, terminal_consts)
188 }
189
190 pub fn get_initial_constants(&self) -> &Vec<[T; WIDTH]> {
191 &self.initial
192 }
193
194 pub fn get_terminal_constants(&self) -> &Vec<[T; WIDTH]> {
195 &self.terminal
196 }
197}
198
199pub trait ExternalLayerConstructor<FA, const WIDTH: usize>
201where
202 FA: FieldAlgebra,
203{
204 fn new_from_constants(external_constants: ExternalLayerConstants<FA::F, WIDTH>) -> Self;
207}
208
209pub trait ExternalLayer<FA, const WIDTH: usize, const D: u64>: Sync + Clone
211where
212 FA: FieldAlgebra,
213{
214 fn permute_state_initial(&self, state: &mut [FA; WIDTH]);
219
220 fn permute_state_terminal(&self, state: &mut [FA; WIDTH]);
222}
223
224#[inline]
226pub fn external_terminal_permute_state<
227 FA: FieldAlgebra,
228 CT: Copy, MdsPerm4: MdsPermutation<FA, 4>,
230 const WIDTH: usize,
231>(
232 state: &mut [FA; WIDTH],
233 terminal_external_constants: &[[CT; WIDTH]],
234 add_rc_and_sbox: fn(&mut FA, CT),
235 mat4: &MdsPerm4,
236) {
237 for elem in terminal_external_constants.iter() {
238 state
239 .iter_mut()
240 .zip(elem.iter())
241 .for_each(|(s, &rc)| add_rc_and_sbox(s, rc));
242 mds_light_permutation(state, mat4);
243 }
244}
245
246#[inline]
248pub fn external_initial_permute_state<
249 FA: FieldAlgebra,
250 CT: Copy, MdsPerm4: MdsPermutation<FA, 4>,
252 const WIDTH: usize,
253>(
254 state: &mut [FA; WIDTH],
255 initial_external_constants: &[[CT; WIDTH]],
256 add_rc_and_sbox: fn(&mut FA, CT),
257 mat4: &MdsPerm4,
258) {
259 mds_light_permutation(state, mat4);
260 external_terminal_permute_state(state, initial_external_constants, add_rc_and_sbox, mat4)
263}