// halo2_base/poseidon/hasher/mds.rs

// Index-based loops mirror the textbook matrix formulas used throughout this file.
#![allow(clippy::needless_range_loop)]
use getset::Getters;

use crate::ff::PrimeField;

/// The type used to hold an MDS matrix: a `T x T` array stored as an array of
/// rows, so `m[i][j]` is the entry in row `i`, column `j`.
pub(crate) type Mds<F, const T: usize> = [[F; T]; T];

9/// `MDSMatrices` holds the MDS matrix as well as transition matrix which is
10/// also called `pre_sparse_mds` and sparse matrices that enables us to reduce
11/// number of multiplications in apply MDS step
12#[derive(Debug, Clone, Getters)]
13pub struct MDSMatrices<F: PrimeField, const T: usize, const RATE: usize> {
14    /// MDS matrix
15    #[getset(get = "pub")]
16    pub(crate) mds: MDSMatrix<F, T, RATE>,
17    /// Transition matrix
18    #[getset(get = "pub")]
19    pub(crate) pre_sparse_mds: MDSMatrix<F, T, RATE>,
20    /// Sparse matrices
21    #[getset(get = "pub")]
22    pub(crate) sparse_matrices: Vec<SparseMDSMatrix<F, T, RATE>>,
23}
24
25/// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear
26/// layer of partial rounds instead of the original MDS
27#[derive(Debug, Clone, Getters)]
28pub struct SparseMDSMatrix<F: PrimeField, const T: usize, const RATE: usize> {
29    /// row
30    #[getset(get = "pub")]
31    pub(crate) row: [F; T],
32    /// column transpose
33    #[getset(get = "pub")]
34    pub(crate) col_hat: [F; RATE],
35}
36
37/// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon
38#[derive(Clone, Debug)]
39pub struct MDSMatrix<F, const T: usize, const RATE: usize>(pub(crate) Mds<F, T>);
40
41impl<F, const T: usize, const RATE: usize> AsRef<Mds<F, T>> for MDSMatrix<F, T, RATE> {
42    fn as_ref(&self) -> &Mds<F, T> {
43        &self.0
44    }
45}
46
47impl<F: PrimeField, const T: usize, const RATE: usize> MDSMatrix<F, T, RATE> {
48    pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] {
49        let mut res = [F::ZERO; T];
50        for i in 0..T {
51            for j in 0..T {
52                res[i] += self.0[i][j] * v[j];
53            }
54        }
55        res
56    }
57
58    pub(crate) fn identity() -> Mds<F, T> {
59        let mut mds = [[F::ZERO; T]; T];
60        for i in 0..T {
61            mds[i][i] = F::ONE;
62        }
63        mds
64    }
65
66    /// Multiplies two MDS matrices. Used in sparse matrix calculations
67    pub(crate) fn mul(&self, other: &Self) -> Self {
68        let mut res = [[F::ZERO; T]; T];
69        for i in 0..T {
70            for j in 0..T {
71                for k in 0..T {
72                    res[i][j] += self.0[i][k] * other.0[k][j];
73                }
74            }
75        }
76        Self(res)
77    }
78
79    pub(crate) fn transpose(&self) -> Self {
80        let mut res = [[F::ZERO; T]; T];
81        for i in 0..T {
82            for j in 0..T {
83                res[i][j] = self.0[j][i];
84            }
85        }
86        Self(res)
87    }
88
89    pub(crate) fn determinant<const N: usize>(m: [[F; N]; N]) -> F {
90        let mut res = F::ONE;
91        let mut m = m;
92        for i in 0..N {
93            let mut pivot = i;
94            while m[pivot][i] == F::ZERO {
95                pivot += 1;
96                assert!(pivot < N, "matrix is not invertible");
97            }
98            if pivot != i {
99                res = -res;
100                m.swap(pivot, i);
101            }
102            res *= m[i][i];
103            let inv = m[i][i].invert().unwrap();
104            for j in i + 1..N {
105                let factor = m[j][i] * inv;
106                for k in i + 1..N {
107                    m[j][k] -= m[i][k] * factor;
108                }
109            }
110        }
111        res
112    }
113
114    /// See Section B in Supplementary Material https://eprint.iacr.org/2019/458.pdf
115    /// Factorises an MDS matrix `M` into `M'` and `M''` where `M = M' *  M''`.
116    /// Resulted `M''` matrices are the sparse ones while `M'` will contribute
117    /// to the accumulator of the process
118    pub(crate) fn factorise(&self) -> (Self, SparseMDSMatrix<F, T, RATE>) {
119        assert_eq!(RATE + 1, T);
120        // Given `(t-1 * t-1)` MDS matrix called `hat` constructs the `t * t` matrix in
121        // form `[[1 | 0], [0 | m]]`, ie `hat` is the right bottom sub-matrix
122        let prime = |hat: Mds<F, RATE>| -> Self {
123            let mut prime = Self::identity();
124            for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) {
125                for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) {
126                    *el_prime = *el_hat;
127                }
128            }
129            Self(prime)
130        };
131
132        // Given `(t-1)` sized `w_hat` vector constructs the matrix in form
133        // `[[m_0_0 | m_0_i], [w_hat | identity]]`
134        let prime_prime = |w_hat: [F; RATE]| -> Mds<F, T> {
135            let mut prime_prime = Self::identity();
136            prime_prime[0] = self.0[0];
137            for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) {
138                row[0] = *w
139            }
140            prime_prime
141        };
142
143        let w = self.0.iter().skip(1).map(|row| row[0]).collect::<Vec<_>>();
144        // m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0
145        let mut m_hat = [[F::ZERO; RATE]; RATE];
146        for i in 0..RATE {
147            for j in 0..RATE {
148                m_hat[i][j] = self.0[i + 1][j + 1];
149            }
150        }
151        // w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult
152        // we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule
153        let mut w_hat = [F::ZERO; RATE];
154        let det = Self::determinant(m_hat);
155        let det_inv = Option::<F>::from(det.invert()).expect("matrix is not invertible");
156        for j in 0..RATE {
157            let mut m_hat_j = m_hat;
158            for i in 0..RATE {
159                m_hat_j[i][j] = w[i];
160            }
161            w_hat[j] = Self::determinant(m_hat_j) * det_inv;
162        }
163        let m_prime = prime(m_hat);
164        let m_prime_prime = prime_prime(w_hat);
165        // row = first row of m_prime_prime.transpose() = first column of m_prime_prime
166        let row: [F; T] =
167            m_prime_prime.iter().map(|row| row[0]).collect::<Vec<_>>().try_into().unwrap();
168        // col_hat = first column of m_prime_prime.transpose() without first element = first row of m_prime_prime without first element
169        let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap();
170        (m_prime, SparseMDSMatrix { row, col_hat })
171    }
172}