1use crate::merkle_tree::merkle_tree_fp::MerkleTreeHash;
2
3use super::neptune_params::NeptuneParams;
4use ark_ff::PrimeField;
5use std::sync::Arc;
6
7#[derive(Clone, Debug)]
8pub struct Neptune<S: PrimeField> {
9 pub(crate) params: Arc<NeptuneParams<S>>,
10}
11
12impl<S: PrimeField> Neptune<S> {
13 pub fn new(params: &Arc<NeptuneParams<S>>) -> Self {
14 Neptune {
15 params: Arc::clone(params),
16 }
17 }
18
19 pub fn get_t(&self) -> usize {
20 self.params.t
21 }
22
23 fn external_round(&self, input: &[S], r: usize) -> Vec<S> {
24 let output = self.external_sbox(input);
25 let output = self.external_matmul(&output);
26 self.add_rc(&output, &self.params.round_constants[r])
27 }
28
29 fn internal_round(&self, input: &[S], r: usize) -> Vec<S> {
30 let output = self.internal_sbox(input);
31 let output = self.internal_matmul(&output);
32 self.add_rc(&output, &self.params.round_constants[r])
33 }
34
35 fn add_rc(&self, input: &[S], rc: &[S]) -> Vec<S> {
36 input
37 .iter()
38 .zip(rc.iter())
39 .map(|(a, b)| {
40 let mut r = *a;
41 r.add_assign(b);
42 r
43 })
44 .collect()
45 }
46
47 fn sbox_d(&self, input: &S) -> S {
48 let mut input2 = *input;
49 input2.square_in_place();
50
51 match self.params.d {
52 3 => {
53 let mut out = input2;
54 out.mul_assign(input);
55 out
56 }
57 5 => {
58 let mut out = input2;
59 out.square_in_place();
60 out.mul_assign(input);
61 out
62 }
63 7 => {
64 let mut out = input2;
65 out.square_in_place();
66 out.mul_assign(&input2);
67 out.mul_assign(input);
68 out
69 }
70 _ => {
71 panic!();
72 }
73 }
74 }
75
76 fn external_sbox_prime(&self, x1: &S, x2: &S) -> (S, S) {
77 let mut zi = x1.to_owned();
78 zi.sub_assign(x2);
79 let mut zib = zi;
80 zib.square_in_place();
81 let mut sum = x1.to_owned();
85 sum.add_assign(x2);
86 let mut y1 = sum.to_owned();
87 let mut y2 = sum.to_owned();
88 y1.add_assign(x1);
89 y2.add_assign(x2);
90 y2.add_assign(x2);
91 let mut tmp1 = zib.to_owned();
96 tmp1.double_in_place();
97 let mut tmp2 = tmp1.to_owned();
98 tmp1.add_assign(&zib);
99 tmp2.double_in_place();
100 y1.add_assign(&tmp1);
103 y2.add_assign(&tmp2);
104
105 let mut tmp = zi.to_owned();
107 tmp.sub_assign(x2);
108 tmp.sub_assign(&zib);
110 tmp.add_assign(&self.params.abc[2]);
111 tmp.square_in_place();
112 y1.add_assign(&tmp);
114 y2.add_assign(&tmp);
115
116 (y1, y2)
117 }
118
119 fn external_sbox(&self, input: &[S]) -> Vec<S> {
120 let t = input.len();
121 let t_ = t >> 1;
122 let mut output = vec![S::zero(); t];
123 for i in 0..t_ {
124 let out = self.external_sbox_prime(&input[2 * i], &input[2 * i + 1]);
125 output[2 * i] = out.0;
126 output[2 * i + 1] = out.1;
127 }
128 output
129 }
130
131 fn internal_sbox(&self, input: &[S]) -> Vec<S> {
132 let mut output = input.to_owned();
133 output[0] = self.sbox_d(&input[0]);
134 output
135 }
136
137 fn external_matmul_4(input: &[S]) -> Vec<S> {
138 let mut output = input.to_owned();
139 output.swap(1, 3);
140
141 let mut sum1 = input[0].to_owned();
142 sum1.add_assign(&input[2]);
143 let mut sum2 = input[1].to_owned();
144 sum2.add_assign(&input[3]);
145
146 output[0].add_assign(&sum1);
147 output[1].add_assign(&sum2);
148 output[2].add_assign(&sum1);
149 output[3].add_assign(&sum2);
150
151 output
152 }
153
154 fn external_matmul_8(input: &[S]) -> Vec<S> {
155 let mut output = input.to_owned();
157 output.swap(1, 7);
158 output.swap(3, 5);
159
160 let mut sum1 = input[0].to_owned();
161 let mut sum2 = input[1].to_owned();
162
163 input
164 .iter()
165 .step_by(2)
166 .skip(1)
167 .for_each(|el| sum1.add_assign(el));
168 input
169 .iter()
170 .skip(1)
171 .step_by(2)
172 .skip(1)
173 .for_each(|el| sum2.add_assign(el));
174
175 let mut output_rot = output.to_owned();
176 output_rot.rotate_left(2);
177
178 for ((i, el), rot) in output.iter_mut().enumerate().zip(output_rot.iter()) {
179 el.double_in_place();
180 el.add_assign(rot);
181 if i & 1 == 0 {
182 el.add_assign(&sum1);
183 } else {
184 el.add_assign(&sum2);
185 }
186 }
187
188 output.swap(3, 7);
189 output
190 }
191
192 fn external_matmul(&self, input: &[S]) -> Vec<S> {
193 let t = self.params.t;
194
195 if t == 4 {
196 return Self::external_matmul_4(input);
197 } else if t == 8 {
198 return Self::external_matmul_8(input);
199 }
200
201 let mut out = vec![S::zero(); t];
202 let t_ = t >> 1;
203 for row in 0..t_ {
204 for col in 0..t_ {
205 let mut tmp_e = self.params.m_e[2 * row][2 * col];
207 tmp_e.mul_assign(&input[2 * col]);
208 out[2 * row].add_assign(&tmp_e);
209
210 let mut tmp_o = self.params.m_e[2 * row + 1][2 * col + 1];
212 tmp_o.mul_assign(&input[2 * col + 1]);
213 out[2 * row + 1].add_assign(&tmp_o);
214 }
215 }
216 out
217 }
218
219 fn internal_matmul(&self, input: &[S]) -> Vec<S> {
220 let mut out = input.to_owned();
221
222 let mut sum = input[0];
223 input.iter().skip(1).for_each(|el| sum.add_assign(el));
224
225 for (o, mu) in out.iter_mut().zip(self.params.mu.iter()) {
226 o.mul_assign(mu);
227 o.add_assign(&sum);
229 }
230 out
231 }
232
233 pub fn permutation(&self, input: &[S]) -> Vec<S> {
234 let t = self.params.t;
235 assert_eq!(input.len(), t);
236
237 let mut current_state = self.external_matmul(input);
239
240 for r in 0..self.params.rounds_f_beginning {
241 current_state = self.external_round(¤t_state, r);
242 }
243 let p_end = self.params.rounds_f_beginning + self.params.rounds_p;
244 for r in self.params.rounds_f_beginning..p_end {
245 current_state = self.internal_round(¤t_state, r);
246 }
247 for r in p_end..self.params.rounds {
248 current_state = self.external_round(¤t_state, r);
249 }
250
251 current_state
252 }
253}
254
255impl<S: PrimeField> MerkleTreeHash<S> for Neptune<S> {
256 fn compress(&self, input: &[&S]) -> S {
257 self.permutation(&[
258 input[0].to_owned(),
259 input[1].to_owned(),
260 S::zero(),
261 S::zero(),
262 ])[0]
263 }
264}
265
#[cfg(test)]
mod neptune_tests_bls12 {
    use super::*;
    use crate::fields::{bls12::FpBLS12, utils};
    use crate::neptune::neptune_instances::{NEPTUNE_BLS_4_PARAMS, NEPTUNE_BLS_8_PARAMS};

    type Scalar = FpBLS12;

    static TESTRUNS: usize = 5;

    /// Draws a fresh random state of width `t`.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Reference dense matrix-vector product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let t = mat.len();
        debug_assert!(t == input.len());
        (0..t)
            .map(|row| {
                let mut acc = Scalar::from(0);
                for (col, x) in input.iter().enumerate().take(t) {
                    let mut prod: Scalar = mat[row][col];
                    prod *= x;
                    acc += prod;
                }
                acc
            })
            .collect()
    }

    /// Builds the dense internal matrix: all ones, with `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let t = neptune_params.t;
        (0..t)
            .map(|i| {
                let mut row = vec![Scalar::from(1); t];
                row[i] = neptune_params.mu[i];
                row[i] += Scalar::from(1);
                row
            })
            .collect()
    }

    /// Checks the sparsity pattern of `m_e`, then compares both optimized
    /// linear layers against a dense reference product on random inputs.
    fn matmul_equalities(t: usize) {
        let neptune_params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&neptune_params);
        let t = neptune.params.t;
        let me = &neptune_params.m_e;

        // Entries with (row + col) even must be non-zero; all others must be zero.
        for (row, matrow) in me.iter().take(t).enumerate() {
            for (col, matrowcol) in matrow.iter().take(t).enumerate() {
                let same_parity = (row + col) % 2 == 0;
                if same_parity {
                    assert!(*matrowcol != Scalar::from(0));
                } else {
                    assert_eq!(*matrowcol, Scalar::from(0));
                }
            }
        }

        let mi = build_mi(&neptune_params);
        for _ in 0..TESTRUNS {
            let input = random_state(t);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, me));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &mi));
        }
    }

    /// Generates one `#[test]` that runs `matmul_equalities` at a given width.
    macro_rules! matmul_equalities_case {
        ($name:ident, $t:expr) => {
            #[test]
            fn $name() {
                matmul_equalities($t);
            }
        };
    }

    matmul_equalities_case!(matmul_equalities_4, 4);
    matmul_equalities_case!(matmul_equalities_6, 6);
    matmul_equalities_case!(matmul_equalities_8, 8);
    matmul_equalities_case!(matmul_equalities_10, 10);
    matmul_equalities_case!(matmul_equalities_60, 60);

    /// Checks that the permutation is deterministic and that two distinct
    /// inputs give distinct outputs.
    fn assert_consistent(instance: &Neptune<Scalar>) {
        let t = instance.params.t;
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Neptune::new(&NEPTUNE_BLS_4_PARAMS),
            Neptune::new(&NEPTUNE_BLS_8_PARAMS),
        ];
        for instance in &instances {
            assert_consistent(instance);
        }
    }
}
385
#[cfg(test)]
mod neptune_tests_bn256 {
    use super::*;
    use crate::fields::{bn256::FpBN256, utils};
    use crate::neptune::neptune_instances::NEPTUNE_BN_PARAMS;

    type Scalar = FpBN256;

    static TESTRUNS: usize = 5;

    /// Draws a fresh random state of width `t`.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Reference dense matrix-vector product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let t = mat.len();
        debug_assert!(t == input.len());
        (0..t)
            .map(|row| {
                let mut acc = Scalar::from(0);
                for (col, x) in input.iter().enumerate().take(t) {
                    let mut prod = mat[row][col];
                    prod *= x;
                    acc += prod;
                }
                acc
            })
            .collect()
    }

    /// Builds the dense internal matrix: all ones, with `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let t = neptune_params.t;
        (0..t)
            .map(|i| {
                let mut row = vec![Scalar::from(1); t];
                row[i] = neptune_params.mu[i];
                row[i] += Scalar::from(1);
                row
            })
            .collect()
    }

    /// Checks the sparsity pattern of `m_e`, then compares both optimized
    /// linear layers against a dense reference product on random inputs.
    fn matmul_equalities(t: usize) {
        let neptune_params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&neptune_params);
        let t = neptune.params.t;
        let me = &neptune_params.m_e;

        // Entries with (row + col) even must be non-zero; all others must be zero.
        for (row, matrow) in me.iter().take(t).enumerate() {
            for (col, matrowcol) in matrow.iter().take(t).enumerate() {
                let same_parity = (row + col) % 2 == 0;
                if same_parity {
                    assert!(*matrowcol != Scalar::from(0));
                } else {
                    assert_eq!(*matrowcol, Scalar::from(0));
                }
            }
        }

        let mi = build_mi(&neptune_params);
        for _ in 0..TESTRUNS {
            let input = random_state(t);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, me));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &mi));
        }
    }

    /// Generates one `#[test]` that runs `matmul_equalities` at a given width.
    macro_rules! matmul_equalities_case {
        ($name:ident, $t:expr) => {
            #[test]
            fn $name() {
                matmul_equalities($t);
            }
        };
    }

    matmul_equalities_case!(matmul_equalities_4, 4);
    matmul_equalities_case!(matmul_equalities_6, 6);
    matmul_equalities_case!(matmul_equalities_8, 8);
    matmul_equalities_case!(matmul_equalities_10, 10);
    matmul_equalities_case!(matmul_equalities_60, 60);

    /// Checks that the permutation is deterministic and that two distinct
    /// inputs give distinct outputs.
    fn assert_consistent(instance: &Neptune<Scalar>) {
        let t = instance.params.t;
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn consistent_perm() {
        let neptune = Neptune::new(&NEPTUNE_BN_PARAMS);
        assert_consistent(&neptune);
    }
}
499
#[cfg(test)]
mod neptune_tests_goldilocks {
    use super::*;
    use crate::fields::{goldilocks::FpGoldiLocks, utils};
    use crate::neptune::neptune_instances::{
        NEPTUNE_GOLDILOCKS_12_PARAMS, NEPTUNE_GOLDILOCKS_16_PARAMS, NEPTUNE_GOLDILOCKS_20_PARAMS,
        NEPTUNE_GOLDILOCKS_8_PARAMS,
    };

    type Scalar = FpGoldiLocks;

    static TESTRUNS: usize = 5;

    /// Draws a fresh random state of width `t`.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Reference dense matrix-vector product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let t = mat.len();
        debug_assert!(t == input.len());
        (0..t)
            .map(|row| {
                let mut acc = Scalar::from(0);
                for (col, x) in input.iter().enumerate().take(t) {
                    let mut prod = mat[row][col];
                    prod *= x;
                    acc += prod;
                }
                acc
            })
            .collect()
    }

    /// Builds the dense internal matrix: all ones, with `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let t = neptune_params.t;
        (0..t)
            .map(|i| {
                let mut row = vec![Scalar::from(1); t];
                row[i] = neptune_params.mu[i];
                row[i] += Scalar::from(1);
                row
            })
            .collect()
    }

    /// Checks the sparsity pattern of `m_e`, then compares both optimized
    /// linear layers against a dense reference product on random inputs.
    fn matmul_equalities(t: usize) {
        let neptune_params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&neptune_params);
        let t = neptune.params.t;
        let me = &neptune_params.m_e;

        // Entries with (row + col) even must be non-zero; all others must be zero.
        for (row, matrow) in me.iter().take(t).enumerate() {
            for (col, matrowcol) in matrow.iter().take(t).enumerate() {
                let same_parity = (row + col) % 2 == 0;
                if same_parity {
                    assert!(*matrowcol != Scalar::from(0));
                } else {
                    assert_eq!(*matrowcol, Scalar::from(0));
                }
            }
        }

        let mi = build_mi(&neptune_params);
        for _ in 0..TESTRUNS {
            let input = random_state(t);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, me));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &mi));
        }
    }

    /// Generates one `#[test]` that runs `matmul_equalities` at a given width.
    macro_rules! matmul_equalities_case {
        ($name:ident, $t:expr) => {
            #[test]
            fn $name() {
                matmul_equalities($t);
            }
        };
    }

    matmul_equalities_case!(matmul_equalities_4, 4);
    matmul_equalities_case!(matmul_equalities_6, 6);
    matmul_equalities_case!(matmul_equalities_8, 8);
    matmul_equalities_case!(matmul_equalities_10, 10);
    matmul_equalities_case!(matmul_equalities_60, 60);

    /// Checks that the permutation is deterministic and that two distinct
    /// inputs give distinct outputs.
    fn assert_consistent(instance: &Neptune<Scalar>) {
        let t = instance.params.t;
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Neptune::new(&NEPTUNE_GOLDILOCKS_8_PARAMS),
            Neptune::new(&NEPTUNE_GOLDILOCKS_12_PARAMS),
            Neptune::new(&NEPTUNE_GOLDILOCKS_16_PARAMS),
            Neptune::new(&NEPTUNE_GOLDILOCKS_20_PARAMS),
        ];
        for instance in &instances {
            assert_consistent(instance);
        }
    }
}
623
#[cfg(test)]
mod neptune_tests_babybear {
    use super::*;
    use crate::fields::{babybear::FpBabyBear, utils};
    use crate::neptune::neptune_instances::{
        NEPTUNE_BABYBEAR_16_PARAMS, NEPTUNE_BABYBEAR_24_PARAMS,
    };

    type Scalar = FpBabyBear;

    static TESTRUNS: usize = 5;

    /// Draws a fresh random state of width `t`.
    fn random_state(t: usize) -> Vec<Scalar> {
        (0..t).map(|_| utils::random_scalar()).collect()
    }

    /// Reference dense matrix-vector product `mat * input`.
    fn matmul(input: &[Scalar], mat: &[Vec<Scalar>]) -> Vec<Scalar> {
        let t = mat.len();
        debug_assert!(t == input.len());
        (0..t)
            .map(|row| {
                let mut acc = Scalar::from(0);
                for (col, x) in input.iter().enumerate().take(t) {
                    let mut prod = mat[row][col];
                    prod *= x;
                    acc += prod;
                }
                acc
            })
            .collect()
    }

    /// Builds the dense internal matrix: all ones, with `mu[i] + 1` on the diagonal.
    fn build_mi(neptune_params: &Arc<NeptuneParams<Scalar>>) -> Vec<Vec<Scalar>> {
        let t = neptune_params.t;
        (0..t)
            .map(|i| {
                let mut row = vec![Scalar::from(1); t];
                row[i] = neptune_params.mu[i];
                row[i] += Scalar::from(1);
                row
            })
            .collect()
    }

    /// Checks the sparsity pattern of `m_e`, then compares both optimized
    /// linear layers against a dense reference product on random inputs.
    fn matmul_equalities(t: usize) {
        let neptune_params = Arc::new(NeptuneParams::<Scalar>::new(t, 3, 2, 1));
        let neptune = Neptune::new(&neptune_params);
        let t = neptune.params.t;
        let me = &neptune_params.m_e;

        // Entries with (row + col) even must be non-zero; all others must be zero.
        for (row, matrow) in me.iter().take(t).enumerate() {
            for (col, matrowcol) in matrow.iter().take(t).enumerate() {
                let same_parity = (row + col) % 2 == 0;
                if same_parity {
                    assert!(*matrowcol != Scalar::from(0));
                } else {
                    assert_eq!(*matrowcol, Scalar::from(0));
                }
            }
        }

        let mi = build_mi(&neptune_params);
        for _ in 0..TESTRUNS {
            let input = random_state(t);
            assert_eq!(neptune.external_matmul(&input), matmul(&input, me));
            assert_eq!(neptune.internal_matmul(&input), matmul(&input, &mi));
        }
    }

    /// Generates one `#[test]` that runs `matmul_equalities` at a given width.
    macro_rules! matmul_equalities_case {
        ($name:ident, $t:expr) => {
            #[test]
            fn $name() {
                matmul_equalities($t);
            }
        };
    }

    matmul_equalities_case!(matmul_equalities_4, 4);
    matmul_equalities_case!(matmul_equalities_6, 6);
    matmul_equalities_case!(matmul_equalities_8, 8);
    matmul_equalities_case!(matmul_equalities_10, 10);
    matmul_equalities_case!(matmul_equalities_60, 60);

    /// Checks that the permutation is deterministic and that two distinct
    /// inputs give distinct outputs.
    fn assert_consistent(instance: &Neptune<Scalar>) {
        let t = instance.params.t;
        for _ in 0..TESTRUNS {
            let input1 = random_state(t);
            let input2 = loop {
                let candidate = random_state(t);
                if candidate != input1 {
                    break candidate;
                }
            };

            let perm1 = instance.permutation(&input1);
            let perm2 = instance.permutation(&input1);
            let perm3 = instance.permutation(&input2);
            assert_eq!(perm1, perm2);
            assert_ne!(perm1, perm3);
        }
    }

    #[test]
    fn consistent_perm() {
        let instances = [
            Neptune::new(&NEPTUNE_BABYBEAR_16_PARAMS),
            Neptune::new(&NEPTUNE_BABYBEAR_24_PARAMS),
        ];
        for instance in &instances {
            assert_consistent(instance);
        }
    }
}