// zkhash/poseidon/poseidon_params.rs

1use ark_ff::PrimeField;
2
3use crate::utils;
4
5#[derive(Clone, Debug)]
6pub struct PoseidonParams<S: PrimeField> {
7    pub(crate) t: usize, // statesize
8    pub(crate) d: usize, // sbox degree
9    pub(crate) rounds_f_beginning: usize,
10    pub(crate) rounds_p: usize,
11    #[allow(dead_code)]
12    pub(crate) rounds_f_end: usize,
13    pub(crate) rounds: usize,
14    pub(crate) mds: Vec<Vec<S>>,
15    pub(crate) round_constants: Vec<Vec<S>>,
16    pub(crate) opt_round_constants: Vec<Vec<S>>, // optimized
17    pub(crate) w_hat: Vec<Vec<S>>,               // optimized
18    pub(crate) v: Vec<Vec<S>>,                   // optimized
19    pub(crate) m_i: Vec<Vec<S>>,                 // optimized
20}
21
22impl<S: PrimeField> PoseidonParams<S> {
23    #[allow(clippy::too_many_arguments)]
24    pub fn new(
25        t: usize,
26        d: usize,
27        rounds_f: usize,
28        rounds_p: usize,
29        mds: &[Vec<S>],
30        round_constants: &[Vec<S>],
31    ) -> Self {
32        assert!(d == 3 || d == 5 || d == 7);
33        assert_eq!(mds.len(), t);
34        assert_eq!(rounds_f % 2, 0);
35        let r = rounds_f / 2;
36        let rounds = rounds_f + rounds_p;
37
38        let (m_i_, v_, w_hat_) = Self::equivalent_matrices(mds, t, rounds_p);
39        let opt_round_constants_ = Self::equivalent_round_constants(round_constants, mds, r, rounds_p);
40
41        PoseidonParams {
42            t,
43            d,
44            rounds_f_beginning: r,
45            rounds_p,
46            rounds_f_end: r,
47            rounds,
48            mds: mds.to_owned(),
49            round_constants: round_constants.to_owned(),
50            opt_round_constants: opt_round_constants_,
51            w_hat: w_hat_,
52            v: v_,
53            m_i: m_i_,
54        }
55    }
56
57    #[allow(clippy::type_complexity)]
58    pub fn equivalent_matrices(
59        mds: &[Vec<S>],
60        t: usize,
61        rounds_p: usize,
62    ) -> (Vec<Vec<S>>, Vec<Vec<S>>, Vec<Vec<S>>) {
63        let mut w_hat = Vec::with_capacity(rounds_p);
64        let mut v = Vec::with_capacity(rounds_p);
65        let mut m_i = vec![vec![S::zero(); t]; t];
66
67        let mds_ = utils::mat_transpose(mds);
68        let mut m_mul = mds_.clone();
69
70        for _ in 0..rounds_p {
71            // calc m_hat, w and v
72            let mut m_hat = vec![vec![S::zero(); t - 1]; t - 1];
73            let mut w = vec![S::zero(); t - 1];
74            let mut v_ = vec![S::zero(); t - 1];
75            v_[..(t - 1)].clone_from_slice(&m_mul[0][1..t]);
76            for row in 1..t {
77                for col in 1..t {
78                    m_hat[row - 1][col - 1] = m_mul[row][col];
79                }
80                w[row - 1] = m_mul[row][0];
81            }
82            // calc_w_hat
83            let m_hat_inv = utils::mat_inverse(&m_hat);
84            let w_hat_ = Self::mat_vec_mul(&m_hat_inv, &w);
85
86            w_hat.push(w_hat_);
87            v.push(v_);
88
89            // update m_i
90            m_i = m_mul.clone();
91            m_i[0][0] = S::one();
92            for i in 1..t {
93                m_i[0][i] = S::zero();
94                m_i[i][0] = S::zero();
95            }
96            m_mul = Self::mat_mat_mul(&mds_, &m_i);
97        }
98
99        (utils::mat_transpose(&m_i), v, w_hat)
100    }
101
102    pub fn equivalent_round_constants(
103        round_constants: &[Vec<S>],
104        mds: &[Vec<S>],
105        rounds_f_beginning: usize,
106        rounds_p: usize,
107    ) -> Vec<Vec<S>> {
108        let mut opt = vec![Vec::new(); rounds_p];
109        let mds_inv = utils::mat_inverse(mds);
110
111        let p_end = rounds_f_beginning + rounds_p - 1;
112        let mut tmp = round_constants[p_end].clone();
113        for i in (0..rounds_p - 1).rev() {
114            let inv_cip = Self::mat_vec_mul(&mds_inv, &tmp);
115            opt[i + 1] = vec![inv_cip[0]];
116            tmp = round_constants[rounds_f_beginning + i].clone();
117            for i in 1..inv_cip.len() {
118                tmp[i].add_assign(&inv_cip[i]);
119            }
120        }
121        opt[0] = tmp;
122
123        opt
124    }
125
126    pub fn mat_vec_mul(mat: &[Vec<S>], input: &[S]) -> Vec<S> {
127        let t = mat.len();
128        debug_assert!(t == input.len());
129        let mut out = vec![S::zero(); t];
130        for row in 0..t {
131            for (col, inp) in input.iter().enumerate() {
132                let mut tmp = mat[row][col];
133                tmp.mul_assign(inp);
134                out[row].add_assign(&tmp);
135            }
136        }
137        out
138    }
139
140    pub fn mat_mat_mul(mat1: &[Vec<S>], mat2: &[Vec<S>]) -> Vec<Vec<S>> {
141        let t = mat1.len();
142        let mut out = vec![vec![S::zero(); t]; t];
143        for row in 0..t {
144            for col1 in 0..t {
145                for (col2, m2) in mat2.iter().enumerate() {
146                    let mut tmp = mat1[row][col2];
147                    tmp.mul_assign(&m2[col1]);
148                    out[row][col1].add_assign(&tmp);
149                }
150            }
151        }
152        out
153    }
154
155}