1use core::ops::Mul;
17
18use p3_field::{Field, FieldAlgebra, PrimeField32};
19use p3_monty_31::{
20 GenericPoseidon2LinearLayersMonty31, InternalLayerBaseParameters, InternalLayerParameters,
21 MontyField31, Poseidon2ExternalLayerMonty31, Poseidon2InternalLayerMonty31,
22};
23use p3_poseidon2::Poseidon2;
24
25use crate::{KoalaBear, KoalaBearParameters};
26
27pub type Poseidon2InternalLayerKoalaBear<const WIDTH: usize> =
28 Poseidon2InternalLayerMonty31<KoalaBearParameters, WIDTH, KoalaBearInternalLayerParameters>;
29
30pub type Poseidon2ExternalLayerKoalaBear<const WIDTH: usize> =
31 Poseidon2ExternalLayerMonty31<KoalaBearParameters, WIDTH>;
32
/// S-box degree handed to [`Poseidon2`] as its final const parameter in
/// [`Poseidon2KoalaBear`].
///
/// NOTE(review): a power map `x -> x^D` is only a permutation when `gcd(D, p - 1) = 1`;
/// presumably 3 is the smallest such `D` for the KoalaBear prime — confirm before changing.
const KOALABEAR_S_BOX_DEGREE: u64 = 3;
38
39pub type Poseidon2KoalaBear<const WIDTH: usize> = Poseidon2<
44 <KoalaBear as Field>::Packing,
45 Poseidon2ExternalLayerKoalaBear<WIDTH>,
46 Poseidon2InternalLayerKoalaBear<WIDTH>,
47 WIDTH,
48 KOALABEAR_S_BOX_DEGREE,
49>;
50
51pub type GenericPoseidon2LinearLayersKoalaBear =
57 GenericPoseidon2LinearLayersMonty31<KoalaBearParameters, KoalaBearInternalLayerParameters>;
58
59const INTERNAL_DIAG_MONTY_16: [KoalaBear; 16] = KoalaBear::new_array([
67 KoalaBear::ORDER_U32 - 2,
68 1,
69 2,
70 (KoalaBear::ORDER_U32 + 1) >> 1,
71 3,
72 4,
73 (KoalaBear::ORDER_U32 - 1) >> 1,
74 KoalaBear::ORDER_U32 - 3,
75 KoalaBear::ORDER_U32 - 4,
76 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 8),
77 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 3),
78 KoalaBear::ORDER_U32 - 127,
79 (KoalaBear::ORDER_U32 - 1) >> 8,
80 (KoalaBear::ORDER_U32 - 1) >> 3,
81 (KoalaBear::ORDER_U32 - 1) >> 4,
82 127,
83]);
84
85const INTERNAL_DIAG_MONTY_24: [KoalaBear; 24] = KoalaBear::new_array([
88 KoalaBear::ORDER_U32 - 2,
89 1,
90 2,
91 (KoalaBear::ORDER_U32 + 1) >> 1,
92 3,
93 4,
94 (KoalaBear::ORDER_U32 - 1) >> 1,
95 KoalaBear::ORDER_U32 - 3,
96 KoalaBear::ORDER_U32 - 4,
97 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 8),
98 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 2),
99 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 3),
100 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 4),
101 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 5),
102 KoalaBear::ORDER_U32 - ((KoalaBear::ORDER_U32 - 1) >> 6),
103 KoalaBear::ORDER_U32 - 127,
104 (KoalaBear::ORDER_U32 - 1) >> 8,
105 (KoalaBear::ORDER_U32 - 1) >> 3,
106 (KoalaBear::ORDER_U32 - 1) >> 4,
107 (KoalaBear::ORDER_U32 - 1) >> 5,
108 (KoalaBear::ORDER_U32 - 1) >> 6,
109 (KoalaBear::ORDER_U32 - 1) >> 7,
110 (KoalaBear::ORDER_U32 - 1) >> 9,
111 127,
112]);
113
/// Zero-sized marker type carrying the KoalaBear-specific parameters of the Poseidon2
/// internal layer.
///
/// It implements `InternalLayerBaseParameters` / `InternalLayerParameters` for state
/// widths 16 and 24, providing the diagonal constants and the matrix-multiplication
/// routines for each width.
#[derive(Debug, Clone, Default)]
pub struct KoalaBearInternalLayerParameters;
117
118impl InternalLayerBaseParameters<KoalaBearParameters, 16> for KoalaBearInternalLayerParameters {
119 type ArrayLike = [MontyField31<KoalaBearParameters>; 15];
120
121 const INTERNAL_DIAG_MONTY: [MontyField31<KoalaBearParameters>; 16] = INTERNAL_DIAG_MONTY_16;
122
123 fn internal_layer_mat_mul(
126 state: &mut [MontyField31<KoalaBearParameters>; 16],
127 sum: MontyField31<KoalaBearParameters>,
128 ) {
129 state[1] += sum;
132 state[2] = state[2].double() + sum;
133 state[3] = state[3].halve() + sum;
134 state[4] = sum + state[4].double() + state[4];
135 state[5] = sum + state[5].double().double();
136 state[6] = sum - state[6].halve();
137 state[7] = sum - (state[7].double() + state[7]);
138 state[8] = sum - state[8].double().double();
139 state[9] = state[9].mul_2exp_neg_n(8);
140 state[9] += sum;
141 state[10] = state[10].mul_2exp_neg_n(3);
142 state[10] += sum;
143 state[11] = state[11].mul_2exp_neg_n(24);
144 state[11] += sum;
145 state[12] = state[12].mul_2exp_neg_n(8);
146 state[12] = sum - state[12];
147 state[13] = state[13].mul_2exp_neg_n(3);
148 state[13] = sum - state[13];
149 state[14] = state[14].mul_2exp_neg_n(4);
150 state[14] = sum - state[14];
151 state[15] = state[15].mul_2exp_neg_n(24);
152 state[15] = sum - state[15];
153 }
154
155 fn generic_internal_linear_layer<FA>(state: &mut [FA; 16])
156 where
157 FA: FieldAlgebra + Mul<KoalaBear, Output = FA>,
158 {
159 let part_sum: FA = state[1..].iter().cloned().sum();
160 let full_sum = part_sum.clone() + state[0].clone();
161
162 state[0] = part_sum - state[0].clone();
164 state[1] = full_sum.clone() + state[1].clone();
165 state[2] = full_sum.clone() + state[2].double();
166
167 state
171 .iter_mut()
172 .zip(INTERNAL_DIAG_MONTY_16)
173 .skip(3)
174 .for_each(|(val, diag_elem)| {
175 *val = full_sum.clone() + val.clone() * diag_elem;
176 });
177 }
178}
179
180impl InternalLayerBaseParameters<KoalaBearParameters, 24> for KoalaBearInternalLayerParameters {
181 type ArrayLike = [MontyField31<KoalaBearParameters>; 23];
182
183 const INTERNAL_DIAG_MONTY: [MontyField31<KoalaBearParameters>; 24] = INTERNAL_DIAG_MONTY_24;
184
185 fn internal_layer_mat_mul(
188 state: &mut [MontyField31<KoalaBearParameters>; 24],
189 sum: MontyField31<KoalaBearParameters>,
190 ) {
191 state[1] += sum;
194 state[2] = state[2].double() + sum;
195 state[3] = state[3].halve() + sum;
196 state[4] = sum + state[4].double() + state[4];
197 state[5] = sum + state[5].double().double();
198 state[6] = sum - state[6].halve();
199 state[7] = sum - (state[7].double() + state[7]);
200 state[8] = sum - state[8].double().double();
201 state[9] = state[9].mul_2exp_neg_n(8);
202 state[9] += sum;
203 state[10] = state[10].mul_2exp_neg_n(2);
204 state[10] += sum;
205 state[11] = state[11].mul_2exp_neg_n(3);
206 state[11] += sum;
207 state[12] = state[12].mul_2exp_neg_n(4);
208 state[12] += sum;
209 state[13] = state[13].mul_2exp_neg_n(5);
210 state[13] += sum;
211 state[14] = state[14].mul_2exp_neg_n(6);
212 state[14] += sum;
213 state[15] = state[15].mul_2exp_neg_n(24);
214 state[15] += sum;
215 state[16] = state[16].mul_2exp_neg_n(8);
216 state[16] = sum - state[16];
217 state[17] = state[17].mul_2exp_neg_n(3);
218 state[17] = sum - state[17];
219 state[18] = state[18].mul_2exp_neg_n(4);
220 state[18] = sum - state[18];
221 state[19] = state[19].mul_2exp_neg_n(5);
222 state[19] = sum - state[19];
223 state[20] = state[20].mul_2exp_neg_n(6);
224 state[20] = sum - state[20];
225 state[21] = state[21].mul_2exp_neg_n(7);
226 state[21] = sum - state[21];
227 state[22] = state[22].mul_2exp_neg_n(9);
228 state[22] = sum - state[22];
229 state[23] = state[23].mul_2exp_neg_n(24);
230 state[23] = sum - state[23];
231 }
232
233 fn generic_internal_linear_layer<FA>(state: &mut [FA; 24])
234 where
235 FA: FieldAlgebra + core::ops::Mul<KoalaBear, Output = FA>,
236 {
237 let part_sum: FA = state[1..].iter().cloned().sum();
238 let full_sum = part_sum.clone() + state[0].clone();
239
240 state[0] = part_sum - state[0].clone();
242 state[1] = full_sum.clone() + state[1].clone();
243 state[2] = full_sum.clone() + state[2].double();
244
245 state
249 .iter_mut()
250 .zip(INTERNAL_DIAG_MONTY_24)
251 .skip(3)
252 .for_each(|(val, diag_elem)| {
253 *val = full_sum.clone() + val.clone() * diag_elem;
254 });
255 }
256}
257
258impl InternalLayerParameters<KoalaBearParameters, 16> for KoalaBearInternalLayerParameters {}
259impl InternalLayerParameters<KoalaBearParameters, 24> for KoalaBearInternalLayerParameters {}
260
#[cfg(test)]
mod tests {
    use p3_field::FieldAlgebra;
    use p3_symmetric::Permutation;
    use rand::{Rng, SeedableRng};
    use rand_xoshiro::Xoroshiro128Plus;

    use super::*;

    type F = KoalaBear;

    /// Known-answer test for the width-16 permutation.
    ///
    /// The permutation is built by `new_from_rng_128` from a `Xoroshiro128Plus` seeded
    /// with 1, so the whole test is deterministic; the output on a fixed input is
    /// compared against a stored expected state.
    #[test]
    fn test_poseidon2_width_16_random() {
        let mut input: [F; 16] = [
            894848333, 1437655012, 1200606629, 1690012884, 71131202, 1749206695, 1717947831,
            120589055, 19776022, 42382981, 1831865506, 724844064, 171220207, 1299207443, 227047920,
            1783754913,
        ]
        .map(F::from_canonical_u32);

        let expected: [F; 16] = [
            652590279, 1200629963, 1013089423, 1840372851, 19101828, 561050015, 1714865585,
            994637181, 498949829, 729884572, 1957973925, 263012103, 535029297, 2121808603,
            964663675, 1473622080,
        ]
        .map(F::from_canonical_u32);

        let mut rng = Xoroshiro128Plus::seed_from_u64(1);
        let perm = Poseidon2KoalaBear::new_from_rng_128(&mut rng);

        perm.permute_mut(&mut input);
        assert_eq!(input, expected);
    }

    /// Known-answer test for the width-24 permutation.
    ///
    /// Same construction as the width-16 test: the permutation comes from an RNG
    /// seeded with 1, and a fixed input is checked against a stored expected output.
    #[test]
    fn test_poseidon2_width_24_random() {
        let mut input: [F; 24] = [
            886409618, 1327899896, 1902407911, 591953491, 648428576, 1844789031, 1198336108,
            355597330, 1799586834, 59617783, 790334801, 1968791836, 559272107, 31054313,
            1042221543, 474748436, 135686258, 263665994, 1962340735, 1741539604, 2026927696,
            449439011, 1131357108, 50869465,
        ]
        .map(F::from_canonical_u32);

        let expected: [F; 24] = [
            3825456, 486989921, 613714063, 282152282, 1027154688, 1171655681, 879344953,
            1090688809, 1960721991, 1604199242, 1329947150, 1535171244, 781646521, 1156559780,
            1875690339, 368140677, 457503063, 304208551, 1919757655, 835116474, 1293372648,
            1254825008, 810923913, 1773631109,
        ]
        .map(F::from_canonical_u32);

        let mut rng = Xoroshiro128Plus::seed_from_u64(1);
        let perm = Poseidon2KoalaBear::new_from_rng_128(&mut rng);

        perm.permute_mut(&mut input);
        assert_eq!(input, expected);
    }

    /// The specialised `internal_layer_mat_mul` and the generic
    /// `generic_internal_linear_layer` must agree for width 16.
    ///
    /// The input is drawn from a seeded RNG so that a failure can be reproduced.
    #[test]
    fn test_generic_internal_linear_layer_16() {
        let mut rng = Xoroshiro128Plus::seed_from_u64(1);
        let mut input1: [F; 16] = rng.gen();
        let mut input2 = input1;

        let part_sum: F = input1[1..].iter().cloned().sum();
        let full_sum = part_sum + input1[0];

        // `internal_layer_mat_mul` leaves entry 0 alone, so apply its update by hand.
        input1[0] = part_sum - input1[0];

        KoalaBearInternalLayerParameters::internal_layer_mat_mul(&mut input1, full_sum);
        KoalaBearInternalLayerParameters::generic_internal_linear_layer(&mut input2);

        assert_eq!(input1, input2);
    }

    /// The specialised `internal_layer_mat_mul` and the generic
    /// `generic_internal_linear_layer` must agree for width 24.
    ///
    /// The input is drawn from a seeded RNG so that a failure can be reproduced.
    #[test]
    fn test_generic_internal_linear_layer_24() {
        let mut rng = Xoroshiro128Plus::seed_from_u64(1);
        let mut input1: [F; 24] = rng.gen();
        let mut input2 = input1;

        let part_sum: F = input1[1..].iter().cloned().sum();
        let full_sum = part_sum + input1[0];

        // `internal_layer_mat_mul` leaves entry 0 alone, so apply its update by hand.
        input1[0] = part_sum - input1[0];

        KoalaBearInternalLayerParameters::internal_layer_mat_mul(&mut input1, full_sum);
        KoalaBearInternalLayerParameters::generic_internal_linear_layer(&mut input2);

        assert_eq!(input1, input2);
    }
}