// poseidon_primitives/poseidon/primitives/mds.rs

1use ff::FromUniformBytes;
2
3use super::{grain::Grain, Mds};
4
5pub(super) fn generate_mds<F: Ord + FromUniformBytes<64>, const T: usize>(
6    grain: &mut Grain<F>,
7    mut select: usize,
8) -> (Mds<F, T>, Mds<F, T>) {
9    let (xs, ys, mds) = loop {
10        // Generate two [F; T] arrays of unique field elements.
11        let (xs, ys) = loop {
12            let mut vals: Vec<_> = (0..2 * T)
13                .map(|_| grain.next_field_element_without_rejection())
14                .collect();
15
16            // Check that we have unique field elements.
17            let mut unique = vals.clone();
18            unique.sort_unstable();
19            unique.dedup();
20            if vals.len() == unique.len() {
21                let rhs = vals.split_off(T);
22                break (vals, rhs);
23            }
24        };
25
26        // We need to ensure that the MDS is secure. Instead of checking the MDS against
27        // the relevant algorithms directly, we witness a fixed number of MDS matrices
28        // that we need to sample from the given Grain state before obtaining a secure
29        // matrix. This can be determined out-of-band via the reference implementation in
30        // Sage.
31        if select != 0 {
32            select -= 1;
33            continue;
34        }
35
36        // Generate a Cauchy matrix, with elements a_ij in the form:
37        //     a_ij = 1/(x_i + y_j); x_i + y_j != 0
38        //
39        // It would be much easier to use the alternate definition:
40        //     a_ij = 1/(x_i - y_j); x_i - y_j != 0
41        //
42        // These are clearly equivalent on `y <- -y`, but it is easier to work with the
43        // negative formulation, because ensuring that xs ∪ ys is unique implies that
44        // x_i - y_j != 0 by construction (whereas the positive case does not hold). It
45        // also makes computation of the matrix inverse simpler below (the theorem used
46        // was formulated for the negative definition).
47        //
48        // However, the Poseidon paper and reference impl use the positive formulation,
49        // and we want to rely on the reference impl for MDS security, so we use the same
50        // formulation.
51        let mut mds = [[F::ZERO; T]; T];
52        #[allow(clippy::needless_range_loop)]
53        for i in 0..T {
54            for j in 0..T {
55                let sum = xs[i] + ys[j];
56                // We leverage the secure MDS selection counter to also check this.
57                assert!(!sum.is_zero_vartime());
58                mds[i][j] = sum.invert().unwrap();
59            }
60        }
61
62        break (xs, ys, mds);
63    };
64
65    // Compute the inverse. All square Cauchy matrices have a non-zero determinant and
66    // thus are invertible. The inverse for a Cauchy matrix of the form:
67    //
68    //     a_ij = 1/(x_i - y_j); x_i - y_j != 0
69    //
70    // has elements b_ij given by:
71    //
72    //     b_ij = (x_j - y_i) A_j(y_i) B_i(x_j)    (Schechter 1959, Theorem 1)
73    //
74    // where A_i(x) and B_i(x) are the Lagrange polynomials for xs and ys respectively.
75    //
76    // We adapt this to the positive Cauchy formulation by negating ys.
77    let mut mds_inv = [[F::ZERO; T]; T];
78    let l = |xs: &[F], j, x: F| {
79        let x_j = xs[j];
80        xs.iter().enumerate().fold(F::ONE, |acc, (m, x_m)| {
81            if m == j {
82                acc
83            } else {
84                let denominator: F = x_j - x_m;
85                // We can invert freely; by construction, the elements of xs are distinct.
86                acc * (x - x_m) * denominator.invert().unwrap()
87            }
88        })
89    };
90    let neg_ys: Vec<_> = ys.iter().map(|y| -*y).collect();
91    for i in 0..T {
92        for j in 0..T {
93            mds_inv[i][j] = (xs[j] - neg_ys[i]) * l(&xs, j, neg_ys[i]) * l(&neg_ys, i, xs[j]);
94        }
95    }
96
97    (mds, mds_inv)
98}
99
#[cfg(test)]
mod tests {
    use super::super::pasta::Fp;
    use super::{generate_mds, Grain};

    use ff::Field;

    /// Generates a width-3 MDS matrix and checks that multiplying it by the returned
    /// inverse yields the identity matrix.
    #[test]
    fn poseidon_mds() {
        const T: usize = 3;
        let mut grain = Grain::new(super::super::grain::SboxType::Pow, T as u16, 8, 56);
        let (mds, mds_inv) = generate_mds::<Fp, T>(&mut grain, 0);

        // Entry (i, j) of MDS * MDS^-1 is the dot product of row i of `mds` with
        // column j of `mds_inv`.
        for (i, mds_row) in mds.iter().enumerate() {
            for j in 0..T {
                let dot = mds_row
                    .iter()
                    .zip(mds_inv.iter())
                    .fold(Fp::ZERO, |acc, (a_ik, inv_row_k)| acc + (*a_ik * inv_row_k[j]));
                let identity_entry = if i == j { Fp::ONE } else { Fp::ZERO };
                assert_eq!(dot, identity_entry);
            }
        }
    }
}