zkhash/poseidon2/
poseidon2_params.rs
1use ark_ff::PrimeField;
2
3use crate::utils;
4
5#[derive(Clone, Debug)]
6pub struct Poseidon2Params<F: PrimeField> {
7 pub(crate) t: usize, pub(crate) d: usize, pub(crate) rounds_f_beginning: usize,
10 pub(crate) rounds_p: usize,
11 #[allow(dead_code)]
12 pub(crate) rounds_f_end: usize,
13 pub(crate) rounds: usize,
14 pub(crate) mat_internal_diag_m_1: Vec<F>,
15 pub(crate) _mat_internal: Vec<Vec<F>>,
16 pub(crate) round_constants: Vec<Vec<F>>,
17}
18
19impl<F: PrimeField> Poseidon2Params<F> {
20 #[allow(clippy::too_many_arguments)]
21
22 pub const INIT_SHAKE: &'static str = "Poseidon2";
23
24 pub fn new(
25 t: usize,
26 d: usize,
27 rounds_f: usize,
28 rounds_p: usize,
29 mat_internal_diag_m_1: &[F],
30 mat_internal: &[Vec<F>],
31 round_constants: &[Vec<F>],
32 ) -> Self {
33 assert!(d == 3 || d == 5 || d == 7 || d == 11);
34 assert_eq!(rounds_f % 2, 0);
35 let r = rounds_f / 2;
36 let rounds = rounds_f + rounds_p;
37
38 Poseidon2Params {
39 t,
40 d,
41 rounds_f_beginning: r,
42 rounds_p,
43 rounds_f_end: r,
44 rounds,
45 mat_internal_diag_m_1: mat_internal_diag_m_1.to_owned(),
46 _mat_internal: mat_internal.to_owned(),
47 round_constants: round_constants.to_owned(),
48 }
49 }
50
51 pub fn equivalent_round_constants(
53 round_constants: &[Vec<F>],
54 mat_internal: &[Vec<F>],
55 rounds_f_beginning: usize,
56 rounds_p: usize,
57 ) -> Vec<Vec<F>> {
58 let mut opt = vec![Vec::new(); rounds_p + 1];
59 let mat_internal_inv = utils::mat_inverse(mat_internal);
60
61 let p_end = rounds_f_beginning + rounds_p - 1;
62 let mut tmp = round_constants[p_end].clone();
63 for i in (0..rounds_p - 1).rev() {
64 let inv_cip = Self::mat_vec_mul(&mat_internal_inv, &tmp);
65 opt[i + 1] = vec![inv_cip[0]];
66 tmp = round_constants[rounds_f_beginning + i].clone();
67 for i in 1..inv_cip.len() {
68 tmp[i].add_assign(&inv_cip[i]);
69 }
70 }
71 opt[0] = tmp;
72 opt[rounds_p] = vec![F::zero(); opt[0].len()]; opt
75 }
76
77 pub fn mat_vec_mul(mat: &[Vec<F>], input: &[F]) -> Vec<F> {
78 let t = mat.len();
79 debug_assert!(t == input.len());
80 let mut out = vec![F::zero(); t];
81 for row in 0..t {
82 for (col, inp) in input.iter().enumerate() {
83 let mut tmp = mat[row][col];
84 tmp.mul_assign(inp);
85 out[row].add_assign(&tmp);
86 }
87 }
88 out
89 }
90
91}