// zkhash/neptune/neptune.rs

1use crate::merkle_tree::merkle_tree_fp::MerkleTreeHash;
2
3use super::neptune_params::NeptuneParams;
4use ark_ff::PrimeField;
5use std::sync::Arc;
6
7#[derive(Clone, Debug)]
8pub struct Neptune<S: PrimeField> {
9    pub(crate) params: Arc<NeptuneParams<S>>,
10}
11
12impl<S: PrimeField> Neptune<S> {
13    pub fn new(params: &Arc<NeptuneParams<S>>) -> Self {
14        Neptune {
15            params: Arc::clone(params),
16        }
17    }
18
19    pub fn get_t(&self) -> usize {
20        self.params.t
21    }
22
23    fn external_round(&self, input: &[S], r: usize) -> Vec<S> {
24        let output = self.external_sbox(input);
25        let output = self.external_matmul(&output);
26        self.add_rc(&output, &self.params.round_constants[r])
27    }
28
29    fn internal_round(&self, input: &[S], r: usize) -> Vec<S> {
30        let output = self.internal_sbox(input);
31        let output = self.internal_matmul(&output);
32        self.add_rc(&output, &self.params.round_constants[r])
33    }
34
35    fn add_rc(&self, input: &[S], rc: &[S]) -> Vec<S> {
36        input
37            .iter()
38            .zip(rc.iter())
39            .map(|(a, b)| {
40                let mut r = *a;
41                r.add_assign(b);
42                r
43            })
44            .collect()
45    }
46
47    fn sbox_d(&self, input: &S) -> S {
48        let mut input2 = *input;
49        input2.square_in_place();
50
51        match self.params.d {
52            3 => {
53                let mut out = input2;
54                out.mul_assign(input);
55                out
56            }
57            5 => {
58                let mut out = input2;
59                out.square_in_place();
60                out.mul_assign(input);
61                out
62            }
63            7 => {
64                let mut out = input2;
65                out.square_in_place();
66                out.mul_assign(&input2);
67                out.mul_assign(input);
68                out
69            }
70            _ => {
71                panic!();
72            }
73        }
74    }
75
76    fn external_sbox_prime(&self, x1: &S, x2: &S) -> (S, S) {
77        let mut zi = x1.to_owned();
78        zi.sub_assign(x2);
79        let mut zib = zi;
80        zib.square_in_place();
81        // zib.mul_assign(&self.params.abc[1]); // beta = 1
82
83        // first terms
84        let mut sum = x1.to_owned();
85        sum.add_assign(x2);
86        let mut y1 = sum.to_owned();
87        let mut y2 = sum.to_owned();
88        y1.add_assign(x1);
89        y2.add_assign(x2);
90        y2.add_assign(x2);
91        // y1.mul_assign(&self.params.a_[0]); // alpha = 1
92        // y2.mul_assign(&self.params.a_[0]); // alpha = 1
93
94        // middle terms
95        let mut tmp1 = zib.to_owned();
96        tmp1.double_in_place();
97        let mut tmp2 = tmp1.to_owned();
98        tmp1.add_assign(&zib);
99        tmp2.double_in_place();
100        // tmp1.mul_assign(&self.params.a_[1]); // done with additions, since alpha = beta = 1
101        // tmp2.mul_assign(&self.params.a_[2]); // done with additions, since alpha = beta = 1
102        y1.add_assign(&tmp1);
103        y2.add_assign(&tmp2);
104
105        // third terms
106        let mut tmp = zi.to_owned();
107        tmp.sub_assign(x2);
108        // tmp.mul_assign(&self.params.abc[0]); // alpha = 1
109        tmp.sub_assign(&zib);
110        tmp.add_assign(&self.params.abc[2]);
111        tmp.square_in_place();
112        // tmp.mul_assign(&self.params.abc[1]); // beta = 1
113        y1.add_assign(&tmp);
114        y2.add_assign(&tmp);
115
116        (y1, y2)
117    }
118
119    fn external_sbox(&self, input: &[S]) -> Vec<S> {
120        let t = input.len();
121        let t_ = t >> 1;
122        let mut output = vec![S::zero(); t];
123        for i in 0..t_ {
124            let out = self.external_sbox_prime(&input[2 * i], &input[2 * i + 1]);
125            output[2 * i] = out.0;
126            output[2 * i + 1] = out.1;
127        }
128        output
129    }
130
131    fn internal_sbox(&self, input: &[S]) -> Vec<S> {
132        let mut output = input.to_owned();
133        output[0] = self.sbox_d(&input[0]);
134        output
135    }
136
137    fn external_matmul_4(input: &[S]) -> Vec<S> {
138        let mut output = input.to_owned();
139        output.swap(1, 3);
140
141        let mut sum1 = input[0].to_owned();
142        sum1.add_assign(&input[2]);
143        let mut sum2 = input[1].to_owned();
144        sum2.add_assign(&input[3]);
145
146        output[0].add_assign(&sum1);
147        output[1].add_assign(&sum2);
148        output[2].add_assign(&sum1);
149        output[3].add_assign(&sum2);
150
151        output
152    }
153
154    fn external_matmul_8(input: &[S]) -> Vec<S> {
155        // multiplication by circ(3 2 1 1) is equal to state + state + rot(state) + sum(state)
156        let mut output = input.to_owned();
157        output.swap(1, 7);
158        output.swap(3, 5);
159
160        let mut sum1 = input[0].to_owned();
161        let mut sum2 = input[1].to_owned();
162
163        input
164            .iter()
165            .step_by(2)
166            .skip(1)
167            .for_each(|el| sum1.add_assign(el));
168        input
169            .iter()
170            .skip(1)
171            .step_by(2)
172            .skip(1)
173            .for_each(|el| sum2.add_assign(el));
174
175        let mut output_rot = output.to_owned();
176        output_rot.rotate_left(2);
177
178        for ((i, el), rot) in output.iter_mut().enumerate().zip(output_rot.iter()) {
179            el.double_in_place();
180            el.add_assign(rot);
181            if i & 1 == 0 {
182                el.add_assign(&sum1);
183            } else {
184                el.add_assign(&sum2);
185            }
186        }
187
188        output.swap(3, 7);
189        output
190    }
191
192    fn external_matmul(&self, input: &[S]) -> Vec<S> {
193        let t = self.params.t;
194
195        if t == 4 {
196            return Self::external_matmul_4(input);
197        } else if t == 8 {
198            return Self::external_matmul_8(input);
199        }
200
201        let mut out = vec![S::zero(); t];
202        let t_ = t >> 1;
203        for row in 0..t_ {
204            for col in 0..t_ {
205                // even rows
206                let mut tmp_e = self.params.m_e[2 * row][2 * col];
207                tmp_e.mul_assign(&input[2 * col]);
208                out[2 * row].add_assign(&tmp_e);
209
210                // odd rows
211                let mut tmp_o = self.params.m_e[2 * row + 1][2 * col + 1];
212                tmp_o.mul_assign(&input[2 * col + 1]);
213                out[2 * row + 1].add_assign(&tmp_o);
214            }
215        }
216        out
217    }
218
219    fn internal_matmul(&self, input: &[S]) -> Vec<S> {
220        let mut out = input.to_owned();
221
222        let mut sum = input[0];
223        input.iter().skip(1).for_each(|el| sum.add_assign(el));
224
225        for (o, mu) in out.iter_mut().zip(self.params.mu.iter()) {
226            o.mul_assign(mu);
227            // o.sub_assign(input[row]); // Already done in parameter creation
228            o.add_assign(&sum);
229        }
230        out
231    }
232
233    pub fn permutation(&self, input: &[S]) -> Vec<S> {
234        let t = self.params.t;
235        assert_eq!(input.len(), t);
236
237        // inital matmul
238        let mut current_state = self.external_matmul(input);
239
240        for r in 0..self.params.rounds_f_beginning {
241            current_state = self.external_round(&current_state, r);
242        }
243        let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
244        for r in self.params.rounds_f_beginning..p_end {
245            current_state = self.internal_round(&current_state, r);
246        }
247        for r in p_end..self.params.rounds {
248            current_state = self.external_round(&current_state, r);
249        }
250
251        current_state
252    }
253}
254
255impl<S: PrimeField> MerkleTreeHash<S> for Neptune<S> {
256    fn compress(&self, input: &[&S]) -> S {
257        self.permutation(&[
258            input[0].to_owned(),
259            input[1].to_owned(),
260            S::zero(),
261            S::zero(),
262        ])[0]
263    }
264}
265
#[cfg(test)]
mod neptune_tests_bls12 {
    //! Consistency tests for Neptune over the BLS12 scalar field.
    use super::*;
    use crate::fields::{bls12::FpBLS12, utils};
    use crate::neptune::neptune_instances::{NEPTUNE_BLS_4_PARAMS, NEPTUNE_BLS_8_PARAMS};

    type Scalar = FpBLS12;

    /// Number of random inputs each randomized test draws.
    const TESTRUNS: usize = 5;

    /// Draws a state of `t` random field elements.
    fn random_input(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Dense reference product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let dim = mat.len();
        debug_assert!(dim == input.len());
        (0..dim)
            .map(|r| {
                input
                    .iter()
                    .take(dim)
                    .enumerate()
                    .fold(Scalar::from(0), |mut acc, (c, x)| {
                        let mut prod: Scalar = mat[r][c];
                        prod *= x;
                        acc += prod;
                        acc
                    })
            })
            .collect()
    }

    /// Dense internal matrix: ones everywhere, `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let dim = neptune_params.t;
        (0..dim)
            .map(|r| {
                (0..dim)
                    .map(|c| {
                        if r == c {
                            // mu is stored with 1 subtracted; add it back.
                            let mut diag = neptune_params.mu[r];
                            diag += Scalar::from(1);
                            diag
                        } else {
                            Scalar::from(1)
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Checks the sparsity of `m_e` and compares the optimised linear layers against the
    /// dense reference product for width `t`.
    fn matmul_equalities(t: usize) {
        let params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&params);
        let width = neptune.params.t;

        // External matrix: entry (r, c) is non-zero exactly when r + c is even.
        let m_e = &params.m_e;
        for (r, row) in m_e.iter().take(width).enumerate() {
            for (c, entry) in row.iter().take(width).enumerate() {
                if (r + c) % 2 == 0 {
                    assert_ne!(*entry, Scalar::from(0));
                } else {
                    assert_eq!(*entry, Scalar::from(0));
                }
            }
        }

        let m_i = build_mi(&params);
        for _ in 0..TESTRUNS {
            let input = random_input(width);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, m_e));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &m_i));
        }
    }

    #[test]
    fn matmul_equalities_4() {
        matmul_equalities(4);
    }

    #[test]
    fn matmul_equalities_6() {
        matmul_equalities(6);
    }

    #[test]
    fn matmul_equalities_8() {
        matmul_equalities(8);
    }

    #[test]
    fn matmul_equalities_10() {
        matmul_equalities(10);
    }

    #[test]
    fn matmul_equalities_60() {
        matmul_equalities(60);
    }

    /// The permutation is deterministic and maps distinct inputs to distinct outputs.
    #[test]
    fn consistent_perm() {
        let instances = [
            Neptune::new(&NEPTUNE_BLS_4_PARAMS),
            Neptune::new(&NEPTUNE_BLS_8_PARAMS),
        ];
        for instance in &instances {
            let t = instance.params.t;
            for _ in 0..TESTRUNS {
                let input1 = random_input(t);
                // Redraw until the second input differs from the first.
                let input2 = loop {
                    let candidate = random_input(t);
                    if candidate != input1 {
                        break candidate;
                    }
                };

                let perm1 = instance.permutation(&input1);
                let perm2 = instance.permutation(&input1);
                let perm3 = instance.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }
}
385
#[cfg(test)]
mod neptune_tests_bn256 {
    //! Consistency tests for Neptune over the BN256 scalar field.
    use super::*;
    use crate::fields::{bn256::FpBN256, utils};
    use crate::neptune::neptune_instances::NEPTUNE_BN_PARAMS;

    type Scalar = FpBN256;

    /// Number of random inputs each randomized test draws.
    const TESTRUNS: usize = 5;

    /// Draws a state of `t` random field elements.
    fn random_input(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Dense reference product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let dim = mat.len();
        debug_assert!(dim == input.len());
        (0..dim)
            .map(|r| {
                input
                    .iter()
                    .take(dim)
                    .enumerate()
                    .fold(Scalar::from(0), |mut acc, (c, x)| {
                        let mut prod = mat[r][c];
                        prod *= x;
                        acc += prod;
                        acc
                    })
            })
            .collect()
    }

    /// Dense internal matrix: ones everywhere, `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let dim = neptune_params.t;
        (0..dim)
            .map(|r| {
                (0..dim)
                    .map(|c| {
                        if r == c {
                            // mu is stored with 1 subtracted; add it back.
                            let mut diag = neptune_params.mu[r];
                            diag += Scalar::from(1);
                            diag
                        } else {
                            Scalar::from(1)
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Checks the sparsity of `m_e` and compares the optimised linear layers against the
    /// dense reference product for width `t`.
    fn matmul_equalities(t: usize) {
        let params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&params);
        let width = neptune.params.t;

        // External matrix: entry (r, c) is non-zero exactly when r + c is even.
        let m_e = &params.m_e;
        for (r, row) in m_e.iter().take(width).enumerate() {
            for (c, entry) in row.iter().take(width).enumerate() {
                if (r + c) % 2 == 0 {
                    assert_ne!(*entry, Scalar::from(0));
                } else {
                    assert_eq!(*entry, Scalar::from(0));
                }
            }
        }

        let m_i = build_mi(&params);
        for _ in 0..TESTRUNS {
            let input = random_input(width);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, m_e));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &m_i));
        }
    }

    #[test]
    fn matmul_equalities_4() {
        matmul_equalities(4);
    }

    #[test]
    fn matmul_equalities_6() {
        matmul_equalities(6);
    }

    #[test]
    fn matmul_equalities_8() {
        matmul_equalities(8);
    }

    #[test]
    fn matmul_equalities_10() {
        matmul_equalities(10);
    }

    #[test]
    fn matmul_equalities_60() {
        matmul_equalities(60);
    }

    /// The permutation is deterministic and maps distinct inputs to distinct outputs.
    #[test]
    fn consistent_perm() {
        let neptune = Neptune::new(&NEPTUNE_BN_PARAMS);
        let t = neptune.params.t;
        for _ in 0..TESTRUNS {
            let input1 = random_input(t);
            // Redraw until the second input differs from the first.
            let input2 = loop {
                let candidate = random_input(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = neptune.permutation(&input1);
            let perm2 = neptune.permutation(&input1);
            let perm3 = neptune.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }
}
499
#[cfg(test)]
mod neptune_tests_goldilocks {
    //! Consistency tests for Neptune over the Goldilocks field.
    use super::*;
    use crate::fields::{goldilocks::FpGoldiLocks, utils};
    use crate::neptune::neptune_instances::{
        NEPTUNE_GOLDILOCKS_12_PARAMS, NEPTUNE_GOLDILOCKS_16_PARAMS, NEPTUNE_GOLDILOCKS_20_PARAMS,
        NEPTUNE_GOLDILOCKS_8_PARAMS,
    };

    type Scalar = FpGoldiLocks;

    /// Number of random inputs each randomized test draws.
    const TESTRUNS: usize = 5;

    /// Draws a state of `t` random field elements.
    fn random_input(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Dense reference product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let dim = mat.len();
        debug_assert!(dim == input.len());
        (0..dim)
            .map(|r| {
                input
                    .iter()
                    .take(dim)
                    .enumerate()
                    .fold(Scalar::from(0), |mut acc, (c, x)| {
                        let mut prod = mat[r][c];
                        prod *= x;
                        acc += prod;
                        acc
                    })
            })
            .collect()
    }

    /// Dense internal matrix: ones everywhere, `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let dim = neptune_params.t;
        (0..dim)
            .map(|r| {
                (0..dim)
                    .map(|c| {
                        if r == c {
                            // mu is stored with 1 subtracted; add it back.
                            let mut diag = neptune_params.mu[r];
                            diag += Scalar::from(1);
                            diag
                        } else {
                            Scalar::from(1)
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Checks the sparsity of `m_e` and compares the optimised linear layers against the
    /// dense reference product for width `t`.
    fn matmul_equalities(t: usize) {
        let params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&params);
        let width = neptune.params.t;

        // External matrix: entry (r, c) is non-zero exactly when r + c is even.
        let m_e = &params.m_e;
        for (r, row) in m_e.iter().take(width).enumerate() {
            for (c, entry) in row.iter().take(width).enumerate() {
                if (r + c) % 2 == 0 {
                    assert_ne!(*entry, Scalar::from(0));
                } else {
                    assert_eq!(*entry, Scalar::from(0));
                }
            }
        }

        let m_i = build_mi(&params);
        for _ in 0..TESTRUNS {
            let input = random_input(width);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, m_e));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &m_i));
        }
    }

    #[test]
    fn matmul_equalities_4() {
        matmul_equalities(4);
    }

    #[test]
    fn matmul_equalities_6() {
        matmul_equalities(6);
    }

    #[test]
    fn matmul_equalities_8() {
        matmul_equalities(8);
    }

    #[test]
    fn matmul_equalities_10() {
        matmul_equalities(10);
    }

    #[test]
    fn matmul_equalities_60() {
        matmul_equalities(60);
    }

    /// The permutation is deterministic and maps distinct inputs to distinct outputs.
    #[test]
    fn consistent_perm() {
        let instances = [
            Neptune::new(&NEPTUNE_GOLDILOCKS_8_PARAMS),
            Neptune::new(&NEPTUNE_GOLDILOCKS_12_PARAMS),
            Neptune::new(&NEPTUNE_GOLDILOCKS_16_PARAMS),
            Neptune::new(&NEPTUNE_GOLDILOCKS_20_PARAMS),
        ];
        for instance in &instances {
            let t = instance.params.t;
            for _ in 0..TESTRUNS {
                let input1 = random_input(t);
                // Redraw until the second input differs from the first.
                let input2 = loop {
                    let candidate = random_input(t);
                    if candidate != input1 {
                        break candidate;
                    }
                };

                let perm1 = instance.permutation(&input1);
                let perm2 = instance.permutation(&input1);
                let perm3 = instance.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }
}
623
#[cfg(test)]
mod neptune_tests_babybear {
    //! Consistency tests for Neptune over the BabyBear field.
    use super::*;
    use crate::fields::{babybear::FpBabyBear, utils};
    use crate::neptune::neptune_instances::{
        NEPTUNE_BABYBEAR_16_PARAMS, NEPTUNE_BABYBEAR_24_PARAMS,
    };

    type Scalar = FpBabyBear;

    /// Number of random inputs each randomized test draws.
    const TESTRUNS: usize = 5;

    /// Draws a state of `t` random field elements.
    fn random_input(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Dense reference product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let dim = mat.len();
        debug_assert!(dim == input.len());
        (0..dim)
            .map(|r| {
                input
                    .iter()
                    .take(dim)
                    .enumerate()
                    .fold(Scalar::from(0), |mut acc, (c, x)| {
                        let mut prod = mat[r][c];
                        prod *= x;
                        acc += prod;
                        acc
                    })
            })
            .collect()
    }

    /// Dense internal matrix: ones everywhere, `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let dim = neptune_params.t;
        (0..dim)
            .map(|r| {
                (0..dim)
                    .map(|c| {
                        if r == c {
                            // mu is stored with 1 subtracted; add it back.
                            let mut diag = neptune_params.mu[r];
                            diag += Scalar::from(1);
                            diag
                        } else {
                            Scalar::from(1)
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Checks the sparsity of `m_e` and compares the optimised linear layers against the
    /// dense reference product for width `t`.
    fn matmul_equalities(t: usize) {
        let params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&params);
        let width = neptune.params.t;

        // External matrix: entry (r, c) is non-zero exactly when r + c is even.
        let m_e = &params.m_e;
        for (r, row) in m_e.iter().take(width).enumerate() {
            for (c, entry) in row.iter().take(width).enumerate() {
                if (r + c) % 2 == 0 {
                    assert_ne!(*entry, Scalar::from(0));
                } else {
                    assert_eq!(*entry, Scalar::from(0));
                }
            }
        }

        let m_i = build_mi(&params);
        for _ in 0..TESTRUNS {
            let input = random_input(width);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, m_e));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &m_i));
        }
    }

    #[test]
    fn matmul_equalities_4() {
        matmul_equalities(4);
    }

    #[test]
    fn matmul_equalities_6() {
        matmul_equalities(6);
    }

    #[test]
    fn matmul_equalities_8() {
        matmul_equalities(8);
    }

    #[test]
    fn matmul_equalities_10() {
        matmul_equalities(10);
    }

    #[test]
    fn matmul_equalities_60() {
        matmul_equalities(60);
    }

    /// The permutation is deterministic and maps distinct inputs to distinct outputs.
    #[test]
    fn consistent_perm() {
        let instances = [
            Neptune::new(&NEPTUNE_BABYBEAR_16_PARAMS),
            Neptune::new(&NEPTUNE_BABYBEAR_24_PARAMS),
        ];
        for instance in &instances {
            let t = instance.params.t;
            for _ in 0..TESTRUNS {
                let input1 = random_input(t);
                // Redraw until the second input differs from the first.
                let input2 = loop {
                    let candidate = random_input(t);
                    if candidate != input1 {
                        break candidate;
                    }
                };

                let perm1 = instance.permutation(&input1);
                let perm2 = instance.permutation(&input1);
                let perm3 = instance.permutation(&input2);
                assert_eq!(perm1, perm2);
                assert_ne!(perm1, perm3);
            }
        }
    }
}