1use super::poseidon_params::PoseidonParams;
2use crate::merkle_tree::merkle_tree_fp::MerkleTreeHash;
3use ark_ff::PrimeField;
4use std::sync::Arc;
5
6#[derive(Clone, Debug)]
7pub struct Poseidon<S: PrimeField> {
8 pub(crate) params: Arc<PoseidonParams<S>>,
9}
10
11impl<S: PrimeField> Poseidon<S> {
12 pub fn new(params: &Arc<PoseidonParams<S>>) -> Self {
13 Poseidon {
14 params: Arc::clone(params),
15 }
16 }
17
18 pub fn get_t(&self) -> usize {
19 self.params.t
20 }
21
22 pub fn permutation(&self, input: &[S]) -> Vec<S> {
23 let t = self.params.t;
24 assert_eq!(input.len(), t);
25
26 let mut current_state = input.to_owned();
27 for r in 0..self.params.rounds_f_beginning {
28 current_state = self.add_rc(¤t_state, &self.params.round_constants[r]);
29 current_state = self.sbox(¤t_state);
30 current_state = self.matmul(¤t_state, &self.params.mds);
31 }
32 let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
33 current_state = self.add_rc(¤t_state, &self.params.opt_round_constants[0]);
34 current_state = self.matmul(¤t_state, &self.params.m_i);
35
36 for r in self.params.rounds_f_beginning..p_end {
37 current_state[0] = self.sbox_p(¤t_state[0]);
38 if r < p_end - 1 {
39 current_state[0].add_assign(
40 &self.params.opt_round_constants[r + 1 - self.params.rounds_f_beginning][0],
41 );
42 }
43 current_state = self.cheap_matmul(¤t_state, p_end - r - 1);
44 }
45 for r in p_end..self.params.rounds {
46 current_state = self.add_rc(¤t_state, &self.params.round_constants[r]);
47 current_state = self.sbox(¤t_state);
48 current_state = self.matmul(¤t_state, &self.params.mds);
49 }
50 current_state
51 }
52
53 pub fn permutation_not_opt(&self, input: &[S]) -> Vec<S> {
54 let t = self.params.t;
55 assert_eq!(input.len(), t);
56
57 let mut current_state = input.to_owned();
58
59 for r in 0..self.params.rounds_f_beginning {
60 current_state = self.add_rc(¤t_state, &self.params.round_constants[r]);
61 current_state = self.sbox(¤t_state);
62 current_state = self.matmul(¤t_state, &self.params.mds);
63 }
64 let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
65 for r in self.params.rounds_f_beginning..p_end {
66 current_state = self.add_rc(¤t_state, &self.params.round_constants[r]);
67 current_state[0] = self.sbox_p(¤t_state[0]);
68 current_state = self.matmul(¤t_state, &self.params.mds);
69 }
70 for r in p_end..self.params.rounds {
71 current_state = self.add_rc(¤t_state, &self.params.round_constants[r]);
72 current_state = self.sbox(¤t_state);
73 current_state = self.matmul(¤t_state, &self.params.mds);
74 }
75 current_state
76 }
77
78 fn sbox(&self, input: &[S]) -> Vec<S> {
79 input.iter().map(|el| self.sbox_p(el)).collect()
80 }
81
82 fn sbox_p(&self, input: &S) -> S {
83 let mut input2 = *input;
84 input2.square_in_place();
85
86 match self.params.d {
87 3 => {
88 let mut out = input2;
89 out.mul_assign(input);
90 out
91 }
92 5 => {
93 let mut out = input2;
94 out.square_in_place();
95 out.mul_assign(input);
96 out
97 }
98 7 => {
99 let mut out = input2;
100 out.square_in_place();
101 out.mul_assign(&input2);
102 out.mul_assign(input);
103 out
104 }
105 _ => {
106 panic!()
107 }
108 }
109 }
110
111 fn cheap_matmul(&self, input: &[S], r: usize) -> Vec<S> {
112 let v = &self.params.v[r];
113 let w_hat = &self.params.w_hat[r];
114 let t = self.params.t;
115
116 let mut new_state = vec![S::zero(); t];
117 new_state[0] = self.params.mds[0][0];
118 new_state[0].mul_assign(&input[0]);
119 for i in 1..t {
120 let mut tmp = w_hat[i - 1];
121 tmp.mul_assign(&input[i]);
122 new_state[0].add_assign(&tmp);
123 }
124 for i in 1..t {
125 new_state[i] = input[0];
126 new_state[i].mul_assign(&v[i - 1]);
127 new_state[i].add_assign(&input[i]);
128 }
129
130 new_state
131 }
132
133 fn matmul(&self, input: &[S], mat: &[Vec<S>]) -> Vec<S> {
134 let t = mat.len();
135 debug_assert!(t == input.len());
136 let mut out = vec![S::zero(); t];
137 for row in 0..t {
138 for (col, inp) in input.iter().enumerate().take(t) {
139 let mut tmp = mat[row][col];
140 tmp.mul_assign(inp);
141 out[row].add_assign(&tmp);
142 }
143 }
144 out
145 }
146
147 fn add_rc(&self, input: &[S], rc: &[S]) -> Vec<S> {
148 input
149 .iter()
150 .zip(rc.iter())
151 .map(|(a, b)| {
152 let mut r = *a;
153 r.add_assign(b);
154 r
155 })
156 .collect()
157 }
158}
159
160impl<F: PrimeField> MerkleTreeHash<F> for Poseidon<F> {
161 fn compress(&self, input: &[&F]) -> F {
162 self.permutation(&[input[0].to_owned(), input[1].to_owned(), F::zero()])[0]
163 }
164}
165
#[cfg(test)]
mod poseidon_tests_bls12 {
    use super::*;
    use crate::fields::{bls12::FpBLS12, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_bls12::{
        POSEIDON_BLS_2_PARAMS, POSEIDON_BLS_3_PARAMS, POSEIDON_BLS_4_PARAMS, POSEIDON_BLS_8_PARAMS,
    };

    type Scalar = FpBLS12;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All BLS12 Poseidon instances covered by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_BLS_2_PARAMS),
            Poseidon::new(&POSEIDON_BLS_3_PARAMS),
            Poseidon::new(&POSEIDON_BLS_4_PARAMS),
            Poseidon::new(&POSEIDON_BLS_8_PARAMS),
        ]
    }

    /// A state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Asserts that `perm[i]` equals the hex-encoded `expected[i]` for every `i`.
    fn assert_matches_hex(perm: &[Scalar], expected: &[&str]) {
        for (i, hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(*hex));
        }
    }

    #[test]
    fn consistent_perm() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                // Resample until the second input differs from the first.
                let input2 = loop {
                    let candidate = random_state(t);
                    if candidate != input1 {
                        break candidate;
                    }
                };

                let perm1 = poseidon.permutation(&input1);
                let perm2 = poseidon.permutation(&input1);
                let perm3 = poseidon.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn kats() {
        let poseidon_2 = Poseidon::new(&POSEIDON_BLS_2_PARAMS);
        let input_2: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1)];
        assert_matches_hex(
            &poseidon_2.permutation(&input_2),
            &[
                "0x1dc37ce34aeee058292bb73bff9acffce73a8a92f3d6d1daa8b77d9516b5c837",
                "0x534cc8001b9c21da25d62749e136ea3d702651ba129f0d5ed7847cf81bc8b042",
            ],
        );

        let poseidon_3 = Poseidon::new(&POSEIDON_BLS_3_PARAMS);
        let input_3: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        assert_matches_hex(
            &poseidon_3.permutation(&input_3),
            &[
                "0x200e6982ac00df8fa65cef1fde9f21373fdbbfd98f2df1eb5fa04f3302ab0397",
                "0x2233c9a40d91c1f643b700f836a1ac231c3f3a8d438ad1609355e1b7317a47e5",
                "0x2eae6736db3c086ad29938869dedbf969dd9804a58aa228ec467b7d5a08dc765",
            ],
        );
    }

    #[test]
    fn opt_equals_not_opt() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input = random_state(t);
                let optimized = poseidon.permutation(&input);
                let reference = poseidon.permutation_not_opt(&input);
                assert_eq!(optimized, reference);
            }
        }
    }
}
261
#[cfg(test)]
mod poseidon_tests_bn256 {
    use super::*;
    use crate::fields::{bn256::FpBN256, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_bn256::POSEIDON_BN_PARAMS;

    type Scalar = FpBN256;

    /// Number of random inputs tried in the randomized tests.
    const TESTRUNS: usize = 5;

    /// A state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    #[test]
    fn consistent_perm() {
        let poseidon = Poseidon::new(&POSEIDON_BN_PARAMS);
        let t = poseidon.params.t;
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = poseidon.permutation(&input1);
            let perm2 = poseidon.permutation(&input1);
            let perm3 = poseidon.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_BN_PARAMS);
        let input: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        let perm = poseidon.permutation(&input);
        let expected = [
            "0x2677d68d9cfa91f197bf5148b50afac461b6b8340ff119a5217794770baade5f",
            "0x21ae9d716173496b62c76ad7deb4654961f64334441bcf77e17a047155a3239f",
            "0x008f8e7c73ff20b6a141c48cef73215860acc749b14f0a7887f74950215169c6",
        ];
        for (i, hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(*hex));
        }
    }

    #[test]
    fn opt_equals_not_opt() {
        let poseidon = Poseidon::new(&POSEIDON_BN_PARAMS);
        let t = poseidon.params.t;
        for _ in 0..TESTRUNS {
            let input = random_state(t);
            let optimized = poseidon.permutation(&input);
            let reference = poseidon.permutation_not_opt(&input);
            assert_eq!(optimized, reference);
        }
    }
}
326
#[cfg(test)]
mod poseidon_tests_goldilocks {
    use super::*;
    use crate::fields::{goldilocks::FpGoldiLocks, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_goldilocks::{
        POSEIDON_GOLDILOCKS_12_PARAMS, POSEIDON_GOLDILOCKS_16_PARAMS,
        POSEIDON_GOLDILOCKS_20_PARAMS, POSEIDON_GOLDILOCKS_8_PARAMS,
    };

    type Scalar = FpGoldiLocks;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All Goldilocks Poseidon instances covered by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_GOLDILOCKS_8_PARAMS),
            Poseidon::new(&POSEIDON_GOLDILOCKS_12_PARAMS),
            Poseidon::new(&POSEIDON_GOLDILOCKS_16_PARAMS),
            Poseidon::new(&POSEIDON_GOLDILOCKS_20_PARAMS),
        ]
    }

    /// A state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    #[test]
    fn consistent_perm() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                // Resample until the second input differs from the first.
                let input2 = loop {
                    let candidate = random_state(t);
                    if candidate != input1 {
                        break candidate;
                    }
                };

                let perm1 = poseidon.permutation(&input1);
                let perm2 = poseidon.permutation(&input1);
                let perm3 = poseidon.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_GOLDILOCKS_12_PARAMS);
        // Input is [0, 1, ..., t - 1].
        let input: Vec<Scalar> = (0..poseidon.params.t)
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = poseidon.permutation(&input);
        let expected = [
            "0xe9ad770762f48ef5",
            "0xc12796961ddc7859",
            "0xa61b71de9595e016",
            "0xead9e6aa583aafa3",
            "0x93e297beff76e95b",
            "0x53abd3c5c2a0e924",
            "0xf3bc50e655c74f51",
            "0x246cac41b9a45d84",
            "0xcc7f9314b2341f4f",
            "0xf5f071587c83415c",
            "0x09486cf35116fba3",
            "0x9d82aaf136b5c38a",
        ];
        for (i, hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(*hex));
        }
    }

    #[test]
    fn opt_equals_not_opt() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input = random_state(t);
                let optimized = poseidon.permutation(&input);
                let reference = poseidon.permutation_not_opt(&input);
                assert_eq!(optimized, reference);
            }
        }
    }
}
454
#[cfg(test)]
mod poseidon_tests_babybear {
    use super::*;
    use crate::fields::{babybear::FpBabyBear, utils::random_scalar};
    use crate::poseidon::poseidon_instance_babybear::{
        POSEIDON_BABYBEAR_16_PARAMS, POSEIDON_BABYBEAR_24_PARAMS,
    };

    type Scalar = FpBabyBear;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All BabyBear Poseidon instances covered by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_BABYBEAR_16_PARAMS),
            Poseidon::new(&POSEIDON_BABYBEAR_24_PARAMS),
        ]
    }

    /// A state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    #[test]
    fn consistent_perm() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                // Resample until the second input differs from the first.
                let input2 = loop {
                    let candidate = random_state(t);
                    if candidate != input1 {
                        break candidate;
                    }
                };

                let perm1 = poseidon.permutation(&input1);
                let perm2 = poseidon.permutation(&input1);
                let perm3 = poseidon.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn opt_equals_not_opt() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input = random_state(t);
                let optimized = poseidon.permutation(&input);
                let reference = poseidon.permutation_not_opt(&input);
                assert_eq!(optimized, reference);
            }
        }
    }
}
516
#[cfg(test)]
mod poseidon_tests_pallas {
    use super::*;
    use crate::fields::{pallas::FpPallas, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_pallas::{
        POSEIDON_PALLAS_3_PARAMS, POSEIDON_PALLAS_4_PARAMS, POSEIDON_PALLAS_8_PARAMS,
    };

    type Scalar = FpPallas;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All Pallas Poseidon instances covered by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_PALLAS_3_PARAMS),
            Poseidon::new(&POSEIDON_PALLAS_4_PARAMS),
            Poseidon::new(&POSEIDON_PALLAS_8_PARAMS),
        ]
    }

    /// A state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    #[test]
    fn consistent_perm() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                // Resample until the second input differs from the first.
                let input2 = loop {
                    let candidate = random_state(t);
                    if candidate != input1 {
                        break candidate;
                    }
                };

                let perm1 = poseidon.permutation(&input1);
                let perm2 = poseidon.permutation(&input1);
                let perm3 = poseidon.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_PALLAS_3_PARAMS);
        let input: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        let perm = poseidon.permutation(&input);
        let expected = [
            "0x08fd69dd1602112194d1fefd8c2b20242e371879feba6683a4bdeebd6e8f121c",
            "0x2a17023cc2483bf305661df2580c3b29444f8b954de7f2166091592ba7728591",
            "0x1495649c6632dd6202315e468aa08b1392b750dfe0d2b3bbc902e230355e9615",
        ];
        for (i, hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(*hex));
        }
    }

    #[test]
    fn opt_equals_not_opt() {
        for poseidon in instances() {
            let t = poseidon.params.t;
            for _ in 0..TESTRUNS {
                let input = random_state(t);
                let optimized = poseidon.permutation(&input);
                let reference = poseidon.permutation_not_opt(&input);
                assert_eq!(optimized, reference);
            }
        }
    }
}
601
#[cfg(test)]
mod poseidon_tests_vesta {
    use super::*;
    use crate::fields::{utils::from_hex, utils::random_scalar, vesta::FpVesta};
    use crate::poseidon::poseidon_instance_vesta::POSEIDON_VESTA_PARAMS;

    type Scalar = FpVesta;

    /// Number of random inputs tried in the randomized tests.
    const TESTRUNS: usize = 5;

    /// A state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    #[test]
    fn consistent_perm() {
        let poseidon = Poseidon::new(&POSEIDON_VESTA_PARAMS);
        let t = poseidon.params.t;
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = poseidon.permutation(&input1);
            let perm2 = poseidon.permutation(&input1);
            let perm3 = poseidon.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_VESTA_PARAMS);
        let input: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        let perm = poseidon.permutation(&input);
        let expected = [
            "0x32e8b71fc2963b1c2371a5a9e191671079b3e059d9683027b146bd5d34cea133",
            "0x005e6cd1461b0470c03f045e8fba078846bbdbb0992c37fc6f4764ebdb92a1d6",
            "0x162f4406f334d8600c569b3172e75abf00f6c201871d4fff9834cedd0c8aa5d3",
        ];
        for (i, hex) in expected.iter().enumerate() {
            assert_eq!(perm[i], from_hex(*hex));
        }
    }

    #[test]
    fn opt_equals_not_opt() {
        let poseidon = Poseidon::new(&POSEIDON_VESTA_PARAMS);
        let t = poseidon.params.t;
        for _ in 0..TESTRUNS {
            let input = random_state(t);
            let optimized = poseidon.permutation(&input);
            let reference = poseidon.permutation_not_opt(&input);
            assert_eq!(optimized, reference);
        }
    }
}