// zkhash/poseidon2/poseidon2.rs

1use super::poseidon2_params::Poseidon2Params;
2use crate::merkle_tree::merkle_tree_fp::MerkleTreeHash;
3use ark_ff::PrimeField;
4use std::sync::Arc;
5
6#[derive(Clone, Debug)]
7pub struct Poseidon2<F: PrimeField> {
8    pub(crate) params: Arc<Poseidon2Params<F>>,
9}
10
11impl<F: PrimeField> Poseidon2<F> {
12    pub fn new(params: &Arc<Poseidon2Params<F>>) -> Self {
13        Poseidon2 {
14            params: Arc::clone(params),
15        }
16    }
17
18    pub fn get_t(&self) -> usize {
19        self.params.t
20    }
21
22    pub fn permutation(&self, input: &[F]) -> Vec<F> {
23        let t = self.params.t;
24        assert_eq!(input.len(), t);
25
26        let mut current_state = input.to_owned();
27
28        // Linear layer at beginning
29        self.matmul_external(&mut current_state);
30
31        for r in 0..self.params.rounds_f_beginning {
32            current_state = self.add_rc(&current_state, &self.params.round_constants[r]);
33            current_state = self.sbox(&current_state);
34            self.matmul_external(&mut current_state);
35        }
36
37        let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
38        for r in self.params.rounds_f_beginning..p_end {
39            current_state[0].add_assign(&self.params.round_constants[r][0]);
40            current_state[0] = self.sbox_p(&current_state[0]);
41            self.matmul_internal(&mut current_state, &self.params.mat_internal_diag_m_1);
42        }
43        
44        for r in p_end..self.params.rounds {
45            current_state = self.add_rc(&current_state, &self.params.round_constants[r]);
46            current_state = self.sbox(&current_state);
47            self.matmul_external(&mut current_state);
48        }
49        current_state
50    }
51
52    fn sbox(&self, input: &[F]) -> Vec<F> {
53        input.iter().map(|el| self.sbox_p(el)).collect()
54    }
55
56    fn sbox_p(&self, input: &F) -> F {
57        let mut input2 = *input;
58        input2.square_in_place();
59
60        match self.params.d {
61            3 => {
62                let mut out = input2;
63                out.mul_assign(input);
64                out
65            }
66            5 => {
67                let mut out = input2;
68                out.square_in_place();
69                out.mul_assign(input);
70                out
71            }
72            7 => {
73                let mut out = input2;
74                out.square_in_place();
75                out.mul_assign(&input2);
76                out.mul_assign(input);
77                out
78            }
79            _ => {
80                panic!()
81            }
82        }
83    }
84
85    fn matmul_m4(&self, input: &mut[F]) {
86        let t = self.params.t;
87        let t4 = t / 4;
88        for i in 0..t4 {
89            let start_index = i * 4;
90            let mut t_0 = input[start_index];
91            t_0.add_assign(&input[start_index + 1]);
92            let mut t_1 = input[start_index + 2];
93            t_1.add_assign(&input[start_index + 3]);
94            let mut t_2 = input[start_index + 1];
95            t_2.double_in_place();
96            t_2.add_assign(&t_1);
97            let mut t_3 = input[start_index + 3];
98            t_3.double_in_place();
99            t_3.add_assign(&t_0);
100            let mut t_4 = t_1;
101            t_4.double_in_place();
102            t_4.double_in_place();
103            t_4.add_assign(&t_3);
104            let mut t_5 = t_0;
105            t_5.double_in_place();
106            t_5.double_in_place();
107            t_5.add_assign(&t_2);
108            let mut t_6 = t_3;
109            t_6.add_assign(&t_5);
110            let mut t_7 = t_2;
111            t_7.add_assign(&t_4);
112            input[start_index] = t_6;
113            input[start_index + 1] = t_5;
114            input[start_index + 2] = t_7;
115            input[start_index + 3] = t_4;
116        }
117    }
118
119    fn matmul_external(&self, input: &mut[F]) {
120        let t = self.params.t;
121        match t {
122            2 => {
123                // Matrix circ(2, 1)
124                let mut sum = input[0];
125                sum.add_assign(&input[1]);
126                input[0].add_assign(&sum);
127                input[1].add_assign(&sum);
128            }
129            3 => {
130                // Matrix circ(2, 1, 1)
131                let mut sum = input[0];
132                sum.add_assign(&input[1]);
133                sum.add_assign(&input[2]);
134                input[0].add_assign(&sum);
135                input[1].add_assign(&sum);
136                input[2].add_assign(&sum);
137            }
138            4 => {
139                // Applying cheap 4x4 MDS matrix to each 4-element part of the state
140                self.matmul_m4(input);
141            }
142            8 | 12 | 16 | 20 | 24 => {
143                // Applying cheap 4x4 MDS matrix to each 4-element part of the state
144                self.matmul_m4(input);
145
146                // Applying second cheap matrix for t > 4
147                let t4 = t / 4;
148                let mut stored = [F::zero(); 4];
149                for l in 0..4 {
150                    stored[l] = input[l];
151                    for j in 1..t4 {
152                        stored[l].add_assign(&input[4 * j + l]);
153                    }
154                }
155                for i in 0..input.len() {
156                    input[i].add_assign(&stored[i % 4]);
157                }
158            }
159            _ => {
160                panic!()
161            }
162        }
163    }
164
165    fn matmul_internal(&self, input: &mut[F], mat_internal_diag_m_1: &[F]) {
166        let t = self.params.t;
167
168        match t {
169            2 => {
170                // [2, 1]
171                // [1, 3]
172                let mut sum = input[0];
173                sum.add_assign(&input[1]);
174                input[0].add_assign(&sum);
175                input[1].double_in_place();
176                input[1].add_assign(&sum);
177            }
178            3 => {
179                // [2, 1, 1]
180                // [1, 2, 1]
181                // [1, 1, 3]
182                let mut sum = input[0];
183                sum.add_assign(&input[1]);
184                sum.add_assign(&input[2]);
185                input[0].add_assign(&sum);
186                input[1].add_assign(&sum);
187                input[2].double_in_place();
188                input[2].add_assign(&sum);
189            }
190            4 | 8 | 12 | 16 | 20 | 24 => {
191                // Compute input sum
192                let mut sum = input[0];
193                input
194                    .iter()
195                    .skip(1)
196                    .take(t-1)
197                    .for_each(|el| sum.add_assign(el));
198                // Add sum + diag entry * element to each element
199                for i in 0..input.len() {
200                    input[i].mul_assign(&mat_internal_diag_m_1[i]);
201                    input[i].add_assign(&sum);
202                }
203            }
204            _ => {
205                panic!()
206            }
207        }
208    }
209
210    fn add_rc(&self, input: &[F], rc: &[F]) -> Vec<F> {
211        input
212            .iter()
213            .zip(rc.iter())
214            .map(|(a, b)| {
215                let mut r = *a;
216                r.add_assign(b);
217                r
218            })
219            .collect()
220    }
221}
222
223impl<F: PrimeField> MerkleTreeHash<F> for Poseidon2<F> {
224    fn compress(&self, input: &[&F]) -> F {
225        self.permutation(&[input[0].to_owned(), input[1].to_owned(), F::zero()])[0]
226    }
227}
228
#[cfg(test)]
mod poseidon2_tests_goldilocks {
    use super::*;
    use crate::fields::{goldilocks::FpGoldiLocks, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_goldilocks::{
        POSEIDON2_GOLDILOCKS_12_PARAMS, POSEIDON2_GOLDILOCKS_16_PARAMS,
        POSEIDON2_GOLDILOCKS_20_PARAMS, POSEIDON2_GOLDILOCKS_8_PARAMS,
    };

    type Scalar = FpGoldiLocks;

    // Number of random input pairs checked per instance.
    const TESTRUNS: usize = 5;

    // Draws `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    // Checks that `instance` is deterministic and maps two distinct random
    // inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    // Permutes the state [0, 1, ..., t-1] and compares each output word with
    // the corresponding known-answer hex string.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, hex) in expected.iter().copied().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output index {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_8_PARAMS),
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_12_PARAMS),
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_16_PARAMS),
            Poseidon2::new(&POSEIDON2_GOLDILOCKS_20_PARAMS),
        ];
        for instance in &instances {
            check_consistency(instance);
        }
    }

    #[test]
    fn kats() {
        let poseidon2 = Poseidon2::new(&POSEIDON2_GOLDILOCKS_12_PARAMS);
        check_kat(
            &poseidon2,
            &[
                "0x01eaef96bdf1c0c1",
                "0x1f0d2cc525b2540c",
                "0x6282c1dfe1e0358d",
                "0xe780d721f698e1e6",
                "0x280c0b6f753d833b",
                "0x1b942dd5023156ab",
                "0x43f0df3fcccb8398",
                "0xe8e8190585489025",
                "0x56bdbf72f77ada22",
                "0x7911c32bf9dcd705",
                "0xec467926508fbe67",
                "0x6a50450ddf85a6ed",
            ],
        );
    }
}
298
#[cfg(test)]
mod poseidon2_tests_babybear {
    use super::*;
    use crate::fields::{babybear::FpBabyBear, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_babybear::{
        POSEIDON2_BABYBEAR_16_PARAMS, POSEIDON2_BABYBEAR_24_PARAMS,
    };

    type Scalar = FpBabyBear;

    // Number of random input pairs checked per instance.
    const TESTRUNS: usize = 5;

    // Draws `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    // Checks that `instance` is deterministic and maps two distinct random
    // inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    // Permutes the state [0, 1, ..., t-1] and compares each output word with
    // the corresponding known-answer hex string.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, hex) in expected.iter().copied().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output index {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_BABYBEAR_16_PARAMS),
            Poseidon2::new(&POSEIDON2_BABYBEAR_24_PARAMS),
        ];
        for instance in &instances {
            check_consistency(instance);
        }
    }

    #[test]
    fn kats() {
        let poseidon2 = Poseidon2::new(&POSEIDON2_BABYBEAR_24_PARAMS);
        check_kat(
            &poseidon2,
            &[
                "0x2ed3e23d",
                "0x12921fb0",
                "0x0e659e79",
                "0x61d81dc9",
                "0x32bae33b",
                "0x62486ae3",
                "0x1e681b60",
                "0x24b91325",
                "0x2a2ef5b9",
                "0x50e8593e",
                "0x5bc818ec",
                "0x10691997",
                "0x35a14520",
                "0x2ba6a3c5",
                "0x279d47ec",
                "0x55014e81",
                "0x5953a67f",
                "0x2f403111",
                "0x6b8828ff",
                "0x1801301f",
                "0x2749207a",
                "0x3dc9cf21",
                "0x3c985ba2",
                "0x57a99864",
            ],
        );
    }
}
376
#[cfg(test)]
mod poseidon2_tests_bls12 {
    use super::*;
    use crate::fields::{bls12::FpBLS12, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_bls12::{
        POSEIDON2_BLS_2_PARAMS, POSEIDON2_BLS_3_PARAMS, POSEIDON2_BLS_4_PARAMS,
        POSEIDON2_BLS_8_PARAMS,
    };

    type Scalar = FpBLS12;

    // Number of random input pairs checked per instance.
    const TESTRUNS: usize = 5;

    // Draws `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    // Checks that `instance` is deterministic and maps two distinct random
    // inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    // Permutes the state [0, 1, ..., t-1] and compares each output word with
    // the corresponding known-answer hex string.
    fn check_kat(instance: &Poseidon2<Scalar>, expected: &[&str]) {
        let input: Vec<Scalar> = (0..instance.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = instance.permutation(&input);
        for (i, hex) in expected.iter().copied().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output index {}", i);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_BLS_2_PARAMS),
            Poseidon2::new(&POSEIDON2_BLS_3_PARAMS),
            Poseidon2::new(&POSEIDON2_BLS_4_PARAMS),
            Poseidon2::new(&POSEIDON2_BLS_8_PARAMS),
        ];
        for instance in &instances {
            check_consistency(instance);
        }
    }

    #[test]
    fn kats() {
        check_kat(
            &Poseidon2::new(&POSEIDON2_BLS_2_PARAMS),
            &[
                "0x73c46dd530e248a87b61d19e67fa1b4ed30fc3d09f16531fe189fb945a15ce4e",
                "0x1f0e305ee21c9366d5793b80251405032a3fee32b9dd0b5f4578262891b043b4",
            ],
        );

        check_kat(
            &Poseidon2::new(&POSEIDON2_BLS_3_PARAMS),
            &[
                "0x1b152349b1950b6a8ca75ee4407b6e26ca5cca5650534e56ef3fd45761fbf5f0",
                "0x4c5793c87d51bdc2c08a32108437dc0000bd0275868f09ebc5f36919af5b3891",
                "0x1fc8ed171e67902ca49863159fe5ba6325318843d13976143b8125f08b50dc6b",
            ],
        );

        check_kat(
            &Poseidon2::new(&POSEIDON2_BLS_4_PARAMS),
            &[
                "0x28ff6c4edf9768c08ae26290487e93449cc8bc155fc2fad92a344adceb3ada6d",
                "0x0e56f2b6fad25075aa93560185b70e2b180ed7e269159c507c288b6747a0db2d",
                "0x6d8196f28da6006bb89b3df94600acdc03d0ba7c2b0f3f4409a54c1db6bf30d0",
                "0x07cfb49540ee456cce38b8a7d1a930a57ffc6660737f6589ef184c5e15334e36",
            ],
        );
    }
}
457
#[cfg(test)]
mod poseidon2_tests_bn256 {
    use super::*;
    use crate::{
        fields::{bn256::FpBN256, utils::from_hex, utils::random_scalar},
        poseidon2::poseidon2_instance_bn256::POSEIDON2_BN256_PARAMS,
    };

    type Scalar = FpBN256;

    // Number of random input pairs checked.
    const TESTRUNS: usize = 5;

    // Draws `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    #[test]
    fn consistent_perm() {
        let poseidon2 = Poseidon2::new(&POSEIDON2_BN256_PARAMS);
        let t = poseidon2.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = poseidon2.permutation(&input1);
            let perm2 = poseidon2.permutation(&input1);
            let perm3 = poseidon2.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn kats() {
        let poseidon2 = Poseidon2::new(&POSEIDON2_BN256_PARAMS);
        // Known answers for the input [0, 1, ..., t-1].
        let expected = [
            "0x0bb61d24daca55eebcb1929a82650f328134334da98ea4f847f760054f4a3033",
            "0x303b6f7c86d043bfcbcc80214f26a30277a15d3f74ca654992defe7ff8d03570",
            "0x1ed25194542b12eef8617361c3ba7c52e660b145994427cc86296242cf766ec8",
        ];
        let input: Vec<Scalar> = (0..poseidon2.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = poseidon2.permutation(&input);
        for (i, hex) in expected.iter().copied().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output index {}", i);
        }
    }
}
506
#[cfg(test)]
mod poseidon2_tests_pallas {
    use super::*;
    use crate::fields::{pallas::FpPallas, utils::from_hex, utils::random_scalar};
    use crate::poseidon2::poseidon2_instance_pallas::{
        POSEIDON2_PALLAS_3_PARAMS, POSEIDON2_PALLAS_4_PARAMS, POSEIDON2_PALLAS_8_PARAMS,
    };

    type Scalar = FpPallas;

    // Number of random input pairs checked per instance.
    const TESTRUNS: usize = 5;

    // Draws `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    // Checks that `instance` is deterministic and maps two distinct random
    // inputs to distinct outputs.
    fn check_consistency(instance: &Poseidon2<Scalar>) {
        let t = instance.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Poseidon2::new(&POSEIDON2_PALLAS_3_PARAMS),
            Poseidon2::new(&POSEIDON2_PALLAS_4_PARAMS),
            Poseidon2::new(&POSEIDON2_PALLAS_8_PARAMS),
        ];
        for instance in &instances {
            check_consistency(instance);
        }
    }

    #[test]
    fn kats() {
        let poseidon2 = Poseidon2::new(&POSEIDON2_PALLAS_3_PARAMS);
        // Known answers for the input [0, 1, ..., t-1].
        let expected = [
            "0x1a9b54c7512a914dd778282c44b3513fea7251420b9d95750baae059b2268d7a",
            "0x1c48ea0994a7d7984ea338a54dbf0c8681f5af883fe988d59ba3380c9f7901fc",
            "0x079ddd0a80a3e9414489b526a2770448964766685f4c4842c838f8a23120b401",
        ];
        let input: Vec<Scalar> = (0..poseidon2.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = poseidon2.permutation(&input);
        for (i, hex) in expected.iter().copied().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output index {}", i);
        }
    }
}
566
#[cfg(test)]
mod poseidon2_tests_vesta {
    use super::*;
    use crate::{
        fields::{utils::from_hex, utils::random_scalar, vesta::FpVesta},
        poseidon2::poseidon2_instance_vesta::POSEIDON2_VESTA_PARAMS,
    };

    type Scalar = FpVesta;

    // Number of random input pairs checked.
    const TESTRUNS: usize = 5;

    // Draws `t` random field elements.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| random_scalar()).collect()
    }

    #[test]
    fn consistent_perm() {
        let poseidon2 = Poseidon2::new(&POSEIDON2_VESTA_PARAMS);
        let t = poseidon2.get_t();
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            // Resample until the second input differs from the first.
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = poseidon2.permutation(&input1);
            let perm2 = poseidon2.permutation(&input1);
            let perm3 = poseidon2.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn kats() {
        let poseidon2 = Poseidon2::new(&POSEIDON2_VESTA_PARAMS);
        // Known answers for the input [0, 1, ..., t-1].
        let expected = [
            "0x261ecbdfd62c617b82d297705f18c788fc9831b14a6a2b8f61229bef68ce2792",
            "0x2c76327e0b7653873263158cf8545c282364b183880fcdea93ca8526d518c66f",
            "0x262316c0ce5244838c75873299b59d763ae0849d2dd31bdc95caf7db1c2901bf",
        ];
        let input: Vec<Scalar> = (0..poseidon2.get_t())
            .map(|i| Scalar::from(i as u64))
            .collect();
        let perm = poseidon2.permutation(&input);
        for (i, hex) in expected.iter().copied().enumerate() {
            assert_eq!(perm[i], from_hex(hex), "KAT mismatch at output index {}", i);
        }
    }
}