halo2_base/poseidon/hasher/mds.rs
1#![allow(clippy::needless_range_loop)]
2use getset::Getters;
3
4use crate::ff::PrimeField;
5
/// A square `T × T` matrix over `F`, stored row-major as a fixed-size nested array
/// (`m[row][col]`).
pub(crate) type Mds<F, const T: usize> = [[F; T]; T];
8
9#[derive(Debug, Clone, Getters)]
13pub struct MDSMatrices<F: PrimeField, const T: usize, const RATE: usize> {
14 #[getset(get = "pub")]
16 pub(crate) mds: MDSMatrix<F, T, RATE>,
17 #[getset(get = "pub")]
19 pub(crate) pre_sparse_mds: MDSMatrix<F, T, RATE>,
20 #[getset(get = "pub")]
22 pub(crate) sparse_matrices: Vec<SparseMDSMatrix<F, T, RATE>>,
23}
24
25#[derive(Debug, Clone, Getters)]
28pub struct SparseMDSMatrix<F: PrimeField, const T: usize, const RATE: usize> {
29 #[getset(get = "pub")]
31 pub(crate) row: [F; T],
32 #[getset(get = "pub")]
34 pub(crate) col_hat: [F; RATE],
35}
36
37#[derive(Clone, Debug)]
39pub struct MDSMatrix<F, const T: usize, const RATE: usize>(pub(crate) Mds<F, T>);
40
41impl<F, const T: usize, const RATE: usize> AsRef<Mds<F, T>> for MDSMatrix<F, T, RATE> {
42 fn as_ref(&self) -> &Mds<F, T> {
43 &self.0
44 }
45}
46
47impl<F: PrimeField, const T: usize, const RATE: usize> MDSMatrix<F, T, RATE> {
48 pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] {
49 let mut res = [F::ZERO; T];
50 for i in 0..T {
51 for j in 0..T {
52 res[i] += self.0[i][j] * v[j];
53 }
54 }
55 res
56 }
57
58 pub(crate) fn identity() -> Mds<F, T> {
59 let mut mds = [[F::ZERO; T]; T];
60 for i in 0..T {
61 mds[i][i] = F::ONE;
62 }
63 mds
64 }
65
66 pub(crate) fn mul(&self, other: &Self) -> Self {
68 let mut res = [[F::ZERO; T]; T];
69 for i in 0..T {
70 for j in 0..T {
71 for k in 0..T {
72 res[i][j] += self.0[i][k] * other.0[k][j];
73 }
74 }
75 }
76 Self(res)
77 }
78
79 pub(crate) fn transpose(&self) -> Self {
80 let mut res = [[F::ZERO; T]; T];
81 for i in 0..T {
82 for j in 0..T {
83 res[i][j] = self.0[j][i];
84 }
85 }
86 Self(res)
87 }
88
89 pub(crate) fn determinant<const N: usize>(m: [[F; N]; N]) -> F {
90 let mut res = F::ONE;
91 let mut m = m;
92 for i in 0..N {
93 let mut pivot = i;
94 while m[pivot][i] == F::ZERO {
95 pivot += 1;
96 assert!(pivot < N, "matrix is not invertible");
97 }
98 if pivot != i {
99 res = -res;
100 m.swap(pivot, i);
101 }
102 res *= m[i][i];
103 let inv = m[i][i].invert().unwrap();
104 for j in i + 1..N {
105 let factor = m[j][i] * inv;
106 for k in i + 1..N {
107 m[j][k] -= m[i][k] * factor;
108 }
109 }
110 }
111 res
112 }
113
114 pub(crate) fn factorise(&self) -> (Self, SparseMDSMatrix<F, T, RATE>) {
119 assert_eq!(RATE + 1, T);
120 let prime = |hat: Mds<F, RATE>| -> Self {
123 let mut prime = Self::identity();
124 for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) {
125 for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) {
126 *el_prime = *el_hat;
127 }
128 }
129 Self(prime)
130 };
131
132 let prime_prime = |w_hat: [F; RATE]| -> Mds<F, T> {
135 let mut prime_prime = Self::identity();
136 prime_prime[0] = self.0[0];
137 for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) {
138 row[0] = *w
139 }
140 prime_prime
141 };
142
143 let w = self.0.iter().skip(1).map(|row| row[0]).collect::<Vec<_>>();
144 let mut m_hat = [[F::ZERO; RATE]; RATE];
146 for i in 0..RATE {
147 for j in 0..RATE {
148 m_hat[i][j] = self.0[i + 1][j + 1];
149 }
150 }
151 let mut w_hat = [F::ZERO; RATE];
154 let det = Self::determinant(m_hat);
155 let det_inv = Option::<F>::from(det.invert()).expect("matrix is not invertible");
156 for j in 0..RATE {
157 let mut m_hat_j = m_hat;
158 for i in 0..RATE {
159 m_hat_j[i][j] = w[i];
160 }
161 w_hat[j] = Self::determinant(m_hat_j) * det_inv;
162 }
163 let m_prime = prime(m_hat);
164 let m_prime_prime = prime_prime(w_hat);
165 let row: [F; T] =
167 m_prime_prime.iter().map(|row| row[0]).collect::<Vec<_>>().try_into().unwrap();
168 let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap();
170 (m_prime, SparseMDSMatrix { row, col_hat })
171 }
172}