// zkhash/poseidon/poseidon.rs

use std::ops::Range;
use std::sync::Arc;

use ark_ff::PrimeField;

use super::poseidon_params::PoseidonParams;
use crate::merkle_tree::merkle_tree_fp::MerkleTreeHash;

6#[derive(Clone, Debug)]
7pub struct Poseidon<S: PrimeField> {
8    pub(crate) params: Arc<PoseidonParams<S>>,
9}
10
11impl<S: PrimeField> Poseidon<S> {
12    pub fn new(params: &Arc<PoseidonParams<S>>) -> Self {
13        Poseidon {
14            params: Arc::clone(params),
15        }
16    }
17
18    pub fn get_t(&self) -> usize {
19        self.params.t
20    }
21
22    pub fn permutation(&self, input: &[S]) -> Vec<S> {
23        let t = self.params.t;
24        assert_eq!(input.len(), t);
25
26        let mut current_state = input.to_owned();
27        for r in 0..self.params.rounds_f_beginning {
28            current_state = self.add_rc(&current_state, &self.params.round_constants[r]);
29            current_state = self.sbox(&current_state);
30            current_state = self.matmul(&current_state, &self.params.mds);
31        }
32        let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
33        current_state = self.add_rc(&current_state, &self.params.opt_round_constants[0]);
34        current_state = self.matmul(&current_state, &self.params.m_i);
35
36        for r in self.params.rounds_f_beginning..p_end {
37            current_state[0] = self.sbox_p(&current_state[0]);
38            if r < p_end - 1 {
39                current_state[0].add_assign(
40                    &self.params.opt_round_constants[r + 1 - self.params.rounds_f_beginning][0],
41                );
42            }
43            current_state = self.cheap_matmul(&current_state, p_end - r - 1);
44        }
45        for r in p_end..self.params.rounds {
46            current_state = self.add_rc(&current_state, &self.params.round_constants[r]);
47            current_state = self.sbox(&current_state);
48            current_state = self.matmul(&current_state, &self.params.mds);
49        }
50        current_state
51    }
52
53    pub fn permutation_not_opt(&self, input: &[S]) -> Vec<S> {
54        let t = self.params.t;
55        assert_eq!(input.len(), t);
56
57        let mut current_state = input.to_owned();
58
59        for r in 0..self.params.rounds_f_beginning {
60            current_state = self.add_rc(&current_state, &self.params.round_constants[r]);
61            current_state = self.sbox(&current_state);
62            current_state = self.matmul(&current_state, &self.params.mds);
63        }
64        let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
65        for r in self.params.rounds_f_beginning..p_end {
66            current_state = self.add_rc(&current_state, &self.params.round_constants[r]);
67            current_state[0] = self.sbox_p(&current_state[0]);
68            current_state = self.matmul(&current_state, &self.params.mds);
69        }
70        for r in p_end..self.params.rounds {
71            current_state = self.add_rc(&current_state, &self.params.round_constants[r]);
72            current_state = self.sbox(&current_state);
73            current_state = self.matmul(&current_state, &self.params.mds);
74        }
75        current_state
76    }
77
78    fn sbox(&self, input: &[S]) -> Vec<S> {
79        input.iter().map(|el| self.sbox_p(el)).collect()
80    }
81
82    fn sbox_p(&self, input: &S) -> S {
83        let mut input2 = *input;
84        input2.square_in_place();
85
86        match self.params.d {
87            3 => {
88                let mut out = input2;
89                out.mul_assign(input);
90                out
91            }
92            5 => {
93                let mut out = input2;
94                out.square_in_place();
95                out.mul_assign(input);
96                out
97            }
98            7 => {
99                let mut out = input2;
100                out.square_in_place();
101                out.mul_assign(&input2);
102                out.mul_assign(input);
103                out
104            }
105            _ => {
106                panic!()
107            }
108        }
109    }
110
111    fn cheap_matmul(&self, input: &[S], r: usize) -> Vec<S> {
112        let v = &self.params.v[r];
113        let w_hat = &self.params.w_hat[r];
114        let t = self.params.t;
115
116        let mut new_state = vec![S::zero(); t];
117        new_state[0] = self.params.mds[0][0];
118        new_state[0].mul_assign(&input[0]);
119        for i in 1..t {
120            let mut tmp = w_hat[i - 1];
121            tmp.mul_assign(&input[i]);
122            new_state[0].add_assign(&tmp);
123        }
124        for i in 1..t {
125            new_state[i] = input[0];
126            new_state[i].mul_assign(&v[i - 1]);
127            new_state[i].add_assign(&input[i]);
128        }
129
130        new_state
131    }
132
133    fn matmul(&self, input: &[S], mat: &[Vec<S>]) -> Vec<S> {
134        let t = mat.len();
135        debug_assert!(t == input.len());
136        let mut out = vec![S::zero(); t];
137        for row in 0..t {
138            for (col, inp) in input.iter().enumerate().take(t) {
139                let mut tmp = mat[row][col];
140                tmp.mul_assign(inp);
141                out[row].add_assign(&tmp);
142            }
143        }
144        out
145    }
146
147    fn add_rc(&self, input: &[S], rc: &[S]) -> Vec<S> {
148        input
149            .iter()
150            .zip(rc.iter())
151            .map(|(a, b)| {
152                let mut r = *a;
153                r.add_assign(b);
154                r
155            })
156            .collect()
157    }
158}
159
160impl<F: PrimeField> MerkleTreeHash<F> for Poseidon<F> {
161    fn compress(&self, input: &[&F]) -> F {
162        self.permutation(&[input[0].to_owned(), input[1].to_owned(), F::zero()])[0]
163    }
164}
165
#[cfg(test)]
mod poseidon_tests_bls12 {
    use super::*;
    use crate::fields::{bls12::FpBLS12, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_bls12::{
        POSEIDON_BLS_2_PARAMS, POSEIDON_BLS_3_PARAMS, POSEIDON_BLS_4_PARAMS,
        POSEIDON_BLS_8_PARAMS,
    };

    type Scalar = FpBLS12;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All BLS12 Poseidon instances exercised by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_BLS_2_PARAMS),
            Poseidon::new(&POSEIDON_BLS_3_PARAMS),
            Poseidon::new(&POSEIDON_BLS_4_PARAMS),
            Poseidon::new(&POSEIDON_BLS_8_PARAMS),
        ]
    }

    /// Draws a state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Draws a random state of the same width as `other` that differs from it.
    fn random_state_distinct_from(other: &[Scalar]) -> Vec<Scalar> {
        loop {
            let candidate = random_state(other.len());
            if candidate.as_slice() != other {
                return candidate;
            }
        }
    }

    #[test]
    fn consistent_perm() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                let input2 = random_state_distinct_from(&input1);

                let perm1 = instance.permutation(&input1);
                let perm2 = instance.permutation(&input1);
                let perm3 = instance.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn kats() {
        let poseidon_2 = Poseidon::new(&POSEIDON_BLS_2_PARAMS);
        let input_2: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1)];
        let perm_2 = poseidon_2.permutation(&input_2);
        assert_eq!(
            perm_2[0],
            from_hex("0x1dc37ce34aeee058292bb73bff9acffce73a8a92f3d6d1daa8b77d9516b5c837")
        );
        assert_eq!(
            perm_2[1],
            from_hex("0x534cc8001b9c21da25d62749e136ea3d702651ba129f0d5ed7847cf81bc8b042")
        );

        let poseidon_3 = Poseidon::new(&POSEIDON_BLS_3_PARAMS);
        let input_3: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        let perm_3 = poseidon_3.permutation(&input_3);
        assert_eq!(
            perm_3[0],
            from_hex("0x200e6982ac00df8fa65cef1fde9f21373fdbbfd98f2df1eb5fa04f3302ab0397")
        );
        assert_eq!(
            perm_3[1],
            from_hex("0x2233c9a40d91c1f643b700f836a1ac231c3f3a8d438ad1609355e1b7317a47e5")
        );
        assert_eq!(
            perm_3[2],
            from_hex("0x2eae6736db3c086ad29938869dedbf969dd9804a58aa228ec467b7d5a08dc765")
        );
    }

    #[test]
    fn opt_equals_not_opt() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input = random_state(t);

                let perm1 = instance.permutation(&input);
                let perm2 = instance.permutation_not_opt(&input);
                assert_eq!(perm1, perm2);
            }
        }
    }
}

#[cfg(test)]
mod poseidon_tests_bn256 {
    use super::*;
    use crate::fields::{bn256::FpBN256, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_bn256::POSEIDON_BN_PARAMS;

    type Scalar = FpBN256;

    /// Number of random inputs tried in the randomized tests.
    const TESTRUNS: usize = 5;

    /// Draws a state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Draws a random state of the same width as `other` that differs from it.
    fn random_state_distinct_from(other: &[Scalar]) -> Vec<Scalar> {
        loop {
            let candidate = random_state(other.len());
            if candidate.as_slice() != other {
                return candidate;
            }
        }
    }

    #[test]
    fn consistent_perm() {
        let poseidon = Poseidon::new(&POSEIDON_BN_PARAMS);
        let t = poseidon.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = random_state_distinct_from(&input1);

            let perm1 = poseidon.permutation(&input1);
            let perm2 = poseidon.permutation(&input1);
            let perm3 = poseidon.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_BN_PARAMS);
        let input: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        let perm = poseidon.permutation(&input);
        assert_eq!(
            perm[0],
            from_hex("0x2677d68d9cfa91f197bf5148b50afac461b6b8340ff119a5217794770baade5f")
        );
        assert_eq!(
            perm[1],
            from_hex("0x21ae9d716173496b62c76ad7deb4654961f64334441bcf77e17a047155a3239f")
        );
        assert_eq!(
            perm[2],
            from_hex("0x008f8e7c73ff20b6a141c48cef73215860acc749b14f0a7887f74950215169c6")
        );
    }

    #[test]
    fn opt_equals_not_opt() {
        let poseidon = Poseidon::new(&POSEIDON_BN_PARAMS);
        let t = poseidon.get_t();
        for _ in 0..TESTRUNS {
            let input = random_state(t);

            let perm1 = poseidon.permutation(&input);
            let perm2 = poseidon.permutation_not_opt(&input);
            assert_eq!(perm1, perm2);
        }
    }
}

#[cfg(test)]
mod poseidon_tests_goldilocks {
    use super::*;
    use crate::fields::{goldilocks::FpGoldiLocks, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_goldilocks::{
        POSEIDON_GOLDILOCKS_12_PARAMS, POSEIDON_GOLDILOCKS_16_PARAMS,
        POSEIDON_GOLDILOCKS_20_PARAMS, POSEIDON_GOLDILOCKS_8_PARAMS,
    };

    type Scalar = FpGoldiLocks;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All Goldilocks Poseidon instances exercised by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_GOLDILOCKS_8_PARAMS),
            Poseidon::new(&POSEIDON_GOLDILOCKS_12_PARAMS),
            Poseidon::new(&POSEIDON_GOLDILOCKS_16_PARAMS),
            Poseidon::new(&POSEIDON_GOLDILOCKS_20_PARAMS),
        ]
    }

    /// Draws a state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Draws a random state of the same width as `other` that differs from it.
    fn random_state_distinct_from(other: &[Scalar]) -> Vec<Scalar> {
        loop {
            let candidate = random_state(other.len());
            if candidate.as_slice() != other {
                return candidate;
            }
        }
    }

    #[test]
    fn consistent_perm() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                let input2 = random_state_distinct_from(&input1);

                let perm1 = instance.permutation(&input1);
                let perm2 = instance.permutation(&input1);
                let perm3 = instance.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_GOLDILOCKS_12_PARAMS);
        // Known-answer test on the input [0, 1, ..., t - 1].
        let input: Vec<Scalar> = (0..poseidon.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = poseidon.permutation(&input);

        let expected = [
            "0xe9ad770762f48ef5",
            "0xc12796961ddc7859",
            "0xa61b71de9595e016",
            "0xead9e6aa583aafa3",
            "0x93e297beff76e95b",
            "0x53abd3c5c2a0e924",
            "0xf3bc50e655c74f51",
            "0x246cac41b9a45d84",
            "0xcc7f9314b2341f4f",
            "0xf5f071587c83415c",
            "0x09486cf35116fba3",
            "0x9d82aaf136b5c38a",
        ];
        assert_eq!(perm.len(), expected.len());
        for (i, (got, want)) in perm.iter().zip(expected.iter().copied()).enumerate() {
            assert_eq!(*got, from_hex(want), "output element {} differs", i);
        }
    }

    #[test]
    fn opt_equals_not_opt() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input = random_state(t);

                let perm1 = instance.permutation(&input);
                let perm2 = instance.permutation_not_opt(&input);
                assert_eq!(perm1, perm2);
            }
        }
    }
}

#[cfg(test)]
mod poseidon_tests_babybear {
    use super::*;
    use crate::fields::{babybear::FpBabyBear, utils::random_scalar};
    use crate::poseidon::poseidon_instance_babybear::{
        POSEIDON_BABYBEAR_16_PARAMS, POSEIDON_BABYBEAR_24_PARAMS,
    };

    type Scalar = FpBabyBear;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All BabyBear Poseidon instances exercised by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_BABYBEAR_16_PARAMS),
            Poseidon::new(&POSEIDON_BABYBEAR_24_PARAMS),
        ]
    }

    /// Draws a state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Draws a random state of the same width as `other` that differs from it.
    fn random_state_distinct_from(other: &[Scalar]) -> Vec<Scalar> {
        loop {
            let candidate = random_state(other.len());
            if candidate.as_slice() != other {
                return candidate;
            }
        }
    }

    #[test]
    fn consistent_perm() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                let input2 = random_state_distinct_from(&input1);

                let perm1 = instance.permutation(&input1);
                let perm2 = instance.permutation(&input1);
                let perm3 = instance.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn opt_equals_not_opt() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input = random_state(t);

                let perm1 = instance.permutation(&input);
                let perm2 = instance.permutation_not_opt(&input);
                assert_eq!(perm1, perm2);
            }
        }
    }
}

#[cfg(test)]
mod poseidon_tests_pallas {
    use super::*;
    use crate::fields::{pallas::FpPallas, utils::from_hex, utils::random_scalar};
    use crate::poseidon::poseidon_instance_pallas::{
        POSEIDON_PALLAS_3_PARAMS, POSEIDON_PALLAS_4_PARAMS, POSEIDON_PALLAS_8_PARAMS,
    };

    type Scalar = FpPallas;

    /// Number of random inputs tried per instance in the randomized tests.
    const TESTRUNS: usize = 5;

    /// All Pallas Poseidon instances exercised by the randomized tests.
    fn instances() -> Vec<Poseidon<Scalar>> {
        vec![
            Poseidon::new(&POSEIDON_PALLAS_3_PARAMS),
            Poseidon::new(&POSEIDON_PALLAS_4_PARAMS),
            Poseidon::new(&POSEIDON_PALLAS_8_PARAMS),
        ]
    }

    /// Draws a state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Draws a random state of the same width as `other` that differs from it.
    fn random_state_distinct_from(other: &[Scalar]) -> Vec<Scalar> {
        loop {
            let candidate = random_state(other.len());
            if candidate.as_slice() != other {
                return candidate;
            }
        }
    }

    #[test]
    fn consistent_perm() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input1 = random_state(t);
                let input2 = random_state_distinct_from(&input1);

                let perm1 = instance.permutation(&input1);
                let perm2 = instance.permutation(&input1);
                let perm3 = instance.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_PALLAS_3_PARAMS);
        let input: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        let perm = poseidon.permutation(&input);
        assert_eq!(
            perm[0],
            from_hex("0x08fd69dd1602112194d1fefd8c2b20242e371879feba6683a4bdeebd6e8f121c")
        );
        assert_eq!(
            perm[1],
            from_hex("0x2a17023cc2483bf305661df2580c3b29444f8b954de7f2166091592ba7728591")
        );
        assert_eq!(
            perm[2],
            from_hex("0x1495649c6632dd6202315e468aa08b1392b750dfe0d2b3bbc902e230355e9615")
        );
    }

    #[test]
    fn opt_equals_not_opt() {
        for instance in instances() {
            let t = instance.get_t();
            for _ in 0..TESTRUNS {
                let input = random_state(t);

                let perm1 = instance.permutation(&input);
                let perm2 = instance.permutation_not_opt(&input);
                assert_eq!(perm1, perm2);
            }
        }
    }
}

#[cfg(test)]
mod poseidon_tests_vesta {
    use super::*;
    use crate::fields::{utils::from_hex, utils::random_scalar, vesta::FpVesta};
    use crate::poseidon::poseidon_instance_vesta::POSEIDON_VESTA_PARAMS;

    type Scalar = FpVesta;

    /// Number of random inputs tried in the randomized tests.
    const TESTRUNS: usize = 5;

    /// Draws a state of `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    /// Draws a random state of the same width as `other` that differs from it.
    fn random_state_distinct_from(other: &[Scalar]) -> Vec<Scalar> {
        loop {
            let candidate = random_state(other.len());
            if candidate.as_slice() != other {
                return candidate;
            }
        }
    }

    #[test]
    fn consistent_perm() {
        let poseidon = Poseidon::new(&POSEIDON_VESTA_PARAMS);
        let t = poseidon.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = random_state_distinct_from(&input1);

            let perm1 = poseidon.permutation(&input1);
            let perm2 = poseidon.permutation(&input1);
            let perm3 = poseidon.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn kats() {
        let poseidon = Poseidon::new(&POSEIDON_VESTA_PARAMS);
        let input: Vec<Scalar> = vec![Scalar::from(0), Scalar::from(1), Scalar::from(2)];
        let perm = poseidon.permutation(&input);
        assert_eq!(
            perm[0],
            from_hex("0x32e8b71fc2963b1c2371a5a9e191671079b3e059d9683027b146bd5d34cea133")
        );
        assert_eq!(
            perm[1],
            from_hex("0x005e6cd1461b0470c03f045e8fba078846bbdbb0992c37fc6f4764ebdb92a1d6")
        );
        assert_eq!(
            perm[2],
            from_hex("0x162f4406f334d8600c569b3172e75abf00f6c201871d4fff9834cedd0c8aa5d3")
        );
    }

    #[test]
    fn opt_equals_not_opt() {
        let poseidon = Poseidon::new(&POSEIDON_VESTA_PARAMS);
        let t = poseidon.get_t();
        for _ in 0..TESTRUNS {
            let input = random_state(t);

            let perm1 = poseidon.permutation(&input);
            let perm2 = poseidon.permutation_not_opt(&input);
            assert_eq!(perm1, perm2);
        }
    }
}