1use super::poseidon2_params::Poseidon2Params;
2use crate::merkle_tree::merkle_tree_fp::MerkleTreeHash;
3use ark_ff::PrimeField;
4use std::sync::Arc;
5
6#[derive(Clone, Debug)]
7pub struct Poseidon2<F: PrimeField> {
8 pub(crate) params: Arc<Poseidon2Params<F>>,
9}
10
11impl<F: PrimeField> Poseidon2<F> {
12 pub fn new(params: &Arc<Poseidon2Params<F>>) -> Self {
13 Poseidon2 {
14 params: Arc::clone(params),
15 }
16 }
17
18 pub fn get_t(&self) -> usize {
19 self.params.t
20 }
21
22 pub fn permutation(&self, input: &[F]) -> Vec<F> {
23 let t = self.params.t;
24 assert_eq!(input.len(), t);
25
26 let mut current_state = input.to_owned();
27
28 self.matmul_external(&mut current_state);
30
31 for r in 0..self.params.rounds_f_beginning {
32 current_state = self.add_rc(¤t_state, &self.params.round_constants[r]);
33 current_state = self.sbox(¤t_state);
34 self.matmul_external(&mut current_state);
35 }
36
37 let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
38 for r in self.params.rounds_f_beginning..p_end {
39 current_state[0].add_assign(&self.params.round_constants[r][0]);
40 current_state[0] = self.sbox_p(¤t_state[0]);
41 self.matmul_internal(&mut current_state, &self.params.mat_internal_diag_m_1);
42 }
43
44 for r in p_end..self.params.rounds {
45 current_state = self.add_rc(¤t_state, &self.params.round_constants[r]);
46 current_state = self.sbox(¤t_state);
47 self.matmul_external(&mut current_state);
48 }
49 current_state
50 }
51
52 fn sbox(&self, input: &[F]) -> Vec<F> {
53 input.iter().map(|el| self.sbox_p(el)).collect()
54 }
55
56 fn sbox_p(&self, input: &F) -> F {
57 let mut input2 = *input;
58 input2.square_in_place();
59
60 match self.params.d {
61 3 => {
62 let mut out = input2;
63 out.mul_assign(input);
64 out
65 }
66 5 => {
67 let mut out = input2;
68 out.square_in_place();
69 out.mul_assign(input);
70 out
71 }
72 7 => {
73 let mut out = input2;
74 out.square_in_place();
75 out.mul_assign(&input2);
76 out.mul_assign(input);
77 out
78 }
79 _ => {
80 panic!()
81 }
82 }
83 }
84
85 fn matmul_m4(&self, input: &mut[F]) {
86 let t = self.params.t;
87 let t4 = t / 4;
88 for i in 0..t4 {
89 let start_index = i * 4;
90 let mut t_0 = input[start_index];
91 t_0.add_assign(&input[start_index + 1]);
92 let mut t_1 = input[start_index + 2];
93 t_1.add_assign(&input[start_index + 3]);
94 let mut t_2 = input[start_index + 1];
95 t_2.double_in_place();
96 t_2.add_assign(&t_1);
97 let mut t_3 = input[start_index + 3];
98 t_3.double_in_place();
99 t_3.add_assign(&t_0);
100 let mut t_4 = t_1;
101 t_4.double_in_place();
102 t_4.double_in_place();
103 t_4.add_assign(&t_3);
104 let mut t_5 = t_0;
105 t_5.double_in_place();
106 t_5.double_in_place();
107 t_5.add_assign(&t_2);
108 let mut t_6 = t_3;
109 t_6.add_assign(&t_5);
110 let mut t_7 = t_2;
111 t_7.add_assign(&t_4);
112 input[start_index] = t_6;
113 input[start_index + 1] = t_5;
114 input[start_index + 2] = t_7;
115 input[start_index + 3] = t_4;
116 }
117 }
118
119 fn matmul_external(&self, input: &mut[F]) {
120 let t = self.params.t;
121 match t {
122 2 => {
123 let mut sum = input[0];
125 sum.add_assign(&input[1]);
126 input[0].add_assign(&sum);
127 input[1].add_assign(&sum);
128 }
129 3 => {
130 let mut sum = input[0];
132 sum.add_assign(&input[1]);
133 sum.add_assign(&input[2]);
134 input[0].add_assign(&sum);
135 input[1].add_assign(&sum);
136 input[2].add_assign(&sum);
137 }
138 4 => {
139 self.matmul_m4(input);
141 }
142 8 | 12 | 16 | 20 | 24 => {
143 self.matmul_m4(input);
145
146 let t4 = t / 4;
148 let mut stored = [F::zero(); 4];
149 for l in 0..4 {
150 stored[l] = input[l];
151 for j in 1..t4 {
152 stored[l].add_assign(&input[4 * j + l]);
153 }
154 }
155 for i in 0..input.len() {
156 input[i].add_assign(&stored[i % 4]);
157 }
158 }
159 _ => {
160 panic!()
161 }
162 }
163 }
164
165 fn matmul_internal(&self, input: &mut[F], mat_internal_diag_m_1: &[F]) {
166 let t = self.params.t;
167
168 match t {
169 2 => {
170 let mut sum = input[0];
173 sum.add_assign(&input[1]);
174 input[0].add_assign(&sum);
175 input[1].double_in_place();
176 input[1].add_assign(&sum);
177 }
178 3 => {
179 let mut sum = input[0];
183 sum.add_assign(&input[1]);
184 sum.add_assign(&input[2]);
185 input[0].add_assign(&sum);
186 input[1].add_assign(&sum);
187 input[2].double_in_place();
188 input[2].add_assign(&sum);
189 }
190 4 | 8 | 12 | 16 | 20 | 24 => {
191 let mut sum = input[0];
193 input
194 .iter()
195 .skip(1)
196 .take(t-1)
197 .for_each(|el| sum.add_assign(el));
198 for i in 0..input.len() {
200 input[i].mul_assign(&mat_internal_diag_m_1[i]);
201 input[i].add_assign(&sum);
202 }
203 }
204 _ => {
205 panic!()
206 }
207 }
208 }
209
210 fn add_rc(&self, input: &[F], rc: &[F]) -> Vec<F> {
211 input
212 .iter()
213 .zip(rc.iter())
214 .map(|(a, b)| {
215 let mut r = *a;
216 r.add_assign(b);
217 r
218 })
219 .collect()
220 }
221}
222
223impl<F: PrimeField> MerkleTreeHash<F> for Poseidon2<F> {
224 fn compress(&self, input: &[&F]) -> F {
225 self.permutation(&[input[0].to_owned(), input[1].to_owned(), F::zero()])[0]
226 }
227}
228
#[allow(unused_imports)]
#[cfg(test)]
mod poseidon2_tests_goldilocks {
    use super::*;
    use crate::fields::{goldilocks::FpGoldiLocks, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_goldilocks::{
        POSEIDON2_GOLDILOCKS_12_PARAMS, POSEIDON2_GOLDILOCKS_16_PARAMS,
        POSEIDON2_GOLDILOCKS_20_PARAMS, POSEIDON2_GOLDILOCKS_8_PARAMS,
    };

    type Scalar = FpGoldiLocks;

    /// Number of random input pairs tried per instance.
    const TESTRUNS: usize = 5;

    /// Returns `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Checks, `TESTRUNS` times, that `instance` is deterministic and maps two
    /// distinct random inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    /// Permutes `[0, 1, ..., t - 1]` and compares each output lane with the
    /// corresponding hex-encoded expected value.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, &hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output lane {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_8_PARAMS),
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_12_PARAMS),
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_16_PARAMS),
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_20_PARAMS),
        ];
        instances.iter().for_each(check_consistency);
    }

    #[test]
    fn kats() {
        check_kat(
            &Poseidon2::new(&POSEIDON2_GOLDILOCKS_12_PARAMS),
            &[
                "0x01eaef96bdf1c0c1",
                "0x1f0d2cc525b2540c",
                "0x6282c1dfe1e0358d",
                "0xe780d721f698e1e6",
                "0x280c0b6f753d833b",
                "0x1b942dd5023156ab",
                "0x43f0df3fcccb8398",
                "0xe8e8190585489025",
                "0x56bdbf72f77ada22",
                "0x7911c32bf9dcd705",
                "0xec467926508fbe67",
                "0x6a50450ddf85a6ed",
            ],
        );
    }
}
298
#[allow(unused_imports)]
#[cfg(test)]
mod poseidon2_tests_babybear {
    use super::*;
    use crate::fields::{babybear::FpBabyBear, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_babybear::{
        POSEIDON2_BABYBEAR_16_PARAMS, POSEIDON2_BABYBEAR_24_PARAMS,
    };

    type Scalar = FpBabyBear;

    /// Number of random input pairs tried per instance.
    const TESTRUNS: usize = 5;

    /// Returns `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Checks, `TESTRUNS` times, that `instance` is deterministic and maps two
    /// distinct random inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    /// Permutes `[0, 1, ..., t - 1]` and compares each output lane with the
    /// corresponding hex-encoded expected value.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, &hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output lane {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_BABYBEAR_16_PARAMS),
            Poseidon2::new(&POSEIDON2_BABYBEAR_24_PARAMS),
        ];
        instances.iter().for_each(check_consistency);
    }

    #[test]
    fn kats() {
        check_kat(
            &Poseidon2::new(&POSEIDON2_BABYBEAR_24_PARAMS),
            &[
                "0x2ed3e23d",
                "0x12921fb0",
                "0x0e659e79",
                "0x61d81dc9",
                "0x32bae33b",
                "0x62486ae3",
                "0x1e681b60",
                "0x24b91325",
                "0x2a2ef5b9",
                "0x50e8593e",
                "0x5bc818ec",
                "0x10691997",
                "0x35a14520",
                "0x2ba6a3c5",
                "0x279d47ec",
                "0x55014e81",
                "0x5953a67f",
                "0x2f403111",
                "0x6b8828ff",
                "0x1801301f",
                "0x2749207a",
                "0x3dc9cf21",
                "0x3c985ba2",
                "0x57a99864",
            ],
        );
    }
}
376
#[allow(unused_imports)]
#[cfg(test)]
mod poseidon2_tests_bls12 {
    use super::*;
    use crate::fields::{bls12::FpBLS12, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_bls12::{
        POSEIDON2_BLS_2_PARAMS, POSEIDON2_BLS_3_PARAMS, POSEIDON2_BLS_4_PARAMS,
        POSEIDON2_BLS_8_PARAMS,
    };

    type Scalar = FpBLS12;

    /// Number of random input pairs tried per instance.
    const TESTRUNS: usize = 5;

    /// Returns `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Checks, `TESTRUNS` times, that `instance` is deterministic and maps two
    /// distinct random inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    /// Permutes `[0, 1, ..., t - 1]` and compares each output lane with the
    /// corresponding hex-encoded expected value.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, &hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output lane {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_BLS_2_PARAMS),
            Poseidon2::new(&POSEIDON2_BLS_3_PARAMS),
            Poseidon2::new(&POSEIDON2_BLS_4_PARAMS),
            Poseidon2::new(&POSEIDON2_BLS_8_PARAMS),
        ];
        instances.iter().for_each(check_consistency);
    }

    #[test]
    fn kats() {
        check_kat(
            &Poseidon2::new(&POSEIDON2_BLS_2_PARAMS),
            &[
                "0x73c46dd530e248a87b61d19e67fa1b4ed30fc3d09f16531fe189fb945a15ce4e",
                "0x1f0e305ee21c9366d5793b80251405032a3fee32b9dd0b5f4578262891b043b4",
            ],
        );

        check_kat(
            &Poseidon2::new(&POSEIDON2_BLS_3_PARAMS),
            &[
                "0x1b152349b1950b6a8ca75ee4407b6e26ca5cca5650534e56ef3fd45761fbf5f0",
                "0x4c5793c87d51bdc2c08a32108437dc0000bd0275868f09ebc5f36919af5b3891",
                "0x1fc8ed171e67902ca49863159fe5ba6325318843d13976143b8125f08b50dc6b",
            ],
        );

        check_kat(
            &Poseidon2::new(&POSEIDON2_BLS_4_PARAMS),
            &[
                "0x28ff6c4edf9768c08ae26290487e93449cc8bc155fc2fad92a344adceb3ada6d",
                "0x0e56f2b6fad25075aa93560185b70e2b180ed7e269159c507c288b6747a0db2d",
                "0x6d8196f28da6006bb89b3df94600acdc03d0ba7c2b0f3f4409a54c1db6bf30d0",
                "0x07cfb49540ee456cce38b8a7d1a930a57ffc6660737f6589ef184c5e15334e36",
            ],
        );
    }
}
457
#[allow(unused_imports)]
#[cfg(test)]
mod poseidon2_tests_bn256 {
    use super::*;
    use crate::fields::{bn256::FpBN256, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_bn256::POSEIDON2_BN256_PARAMS;

    type Scalar = FpBN256;

    /// Number of random input pairs tried.
    const TESTRUNS: usize = 5;

    /// Returns `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Checks, `TESTRUNS` times, that `instance` is deterministic and maps two
    /// distinct random inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    /// Permutes `[0, 1, ..., t - 1]` and compares each output lane with the
    /// corresponding hex-encoded expected value.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, &hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output lane {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        check_consistency(&Poseidon2::new(&POSEIDON2_BN256_PARAMS));
    }

    #[test]
    fn kats() {
        check_kat(
            &Poseidon2::new(&POSEIDON2_BN256_PARAMS),
            &[
                "0x0bb61d24daca55eebcb1929a82650f328134334da98ea4f847f760054f4a3033",
                "0x303b6f7c86d043bfcbcc80214f26a30277a15d3f74ca654992defe7ff8d03570",
                "0x1ed25194542b12eef8617361c3ba7c52e660b145994427cc86296242cf766ec8",
            ],
        );
    }
}
506
#[allow(unused_imports)]
#[cfg(test)]
mod poseidon2_tests_pallas {
    use super::*;
    use crate::fields::{pallas::FpPallas, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_pallas::{
        POSEIDON2_PALLAS_3_PARAMS, POSEIDON2_PALLAS_4_PARAMS, POSEIDON2_PALLAS_8_PARAMS,
    };

    type Scalar = FpPallas;

    /// Number of random input pairs tried per instance.
    const TESTRUNS: usize = 5;

    /// Returns `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Checks, `TESTRUNS` times, that `instance` is deterministic and maps two
    /// distinct random inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    /// Permutes `[0, 1, ..., t - 1]` and compares each output lane with the
    /// corresponding hex-encoded expected value.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, &hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output lane {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_PALLAS_3_PARAMS),
            Poseidon2::new(&POSEIDON2_PALLAS_4_PARAMS),
            Poseidon2::new(&POSEIDON2_PALLAS_8_PARAMS),
        ];
        instances.iter().for_each(check_consistency);
    }

    #[test]
    fn kats() {
        check_kat(
            &Poseidon2::new(&POSEIDON2_PALLAS_3_PARAMS),
            &[
                "0x1a9b54c7512a914dd778282c44b3513fea7251420b9d95750baae059b2268d7a",
                "0x1c48ea0994a7d7984ea338a54dbf0c8681f5af883fe988d59ba3380c9f7901fc",
                "0x079ddd0a80a3e9414489b526a2770448964766685f4c4842c838f8a23120b401",
            ],
        );
    }
}
566
#[allow(unused_imports)]
#[cfg(test)]
mod poseidon2_tests_vesta {
    use super::*;
    use crate::fields::{utils::from_hex, utils::random_scalar, vesta::FpVesta};
    use crate::poseidon2::poseidon2_instance_vesta::POSEIDON2_VESTA_PARAMS;

    type Scalar = FpVesta;

    /// Number of random input pairs tried.
    const TESTRUNS: usize = 5;

    /// Returns `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Checks, `TESTRUNS` times, that `instance` is deterministic and maps two
    /// distinct random inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    /// Permutes `[0, 1, ..., t - 1]` and compares each output lane with the
    /// corresponding hex-encoded expected value.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, &hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output lane {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        check_consistency(&Poseidon2::new(&POSEIDON2_VESTA_PARAMS));
    }

    #[test]
    fn kats() {
        check_kat(
            &Poseidon2::new(&POSEIDON2_VESTA_PARAMS),
            &[
                "0x261ecbdfd62c617b82d297705f18c788fc9831b14a6a2b8f61229bef68ce2792",
                "0x2c76327e0b7653873263158cf8545c282364b183880fcdea93ca8526d518c66f",
                "0x262316c0ce5244838c75873299b59d763ae0849d2dd31bdc95caf7db1c2901bf",
            ],
        );
    }
}