p3_mds/
coset_mds.rs

1use alloc::vec::Vec;
2
3use p3_field::{Algebra, Field, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::MdsPermutation;
8use crate::butterflies::{bowers_g_layer, bowers_g_t_layer};
9
/// An MDS permutation built from a Reed-Solomon code.
///
/// The input is interpreted as the evaluations of a polynomial of degree `< N` over the
/// order-`N` two-adic subgroup `H`. The output is that polynomial's evaluations over the coset
/// `F::GENERATOR * H`, each multiplied by `N` (the inverse-DFT step skips its `1/N`
/// normalization). Those coset evaluations are the parity symbols of a systematic Reed-Solomon
/// code. Reed-Solomon codes are MDS, and scaling by the nonzero constant `N` keeps every square
/// submatrix nonsingular, so the map is an MDS permutation.
#[derive(Clone, Debug)]
pub struct CosetMds<F, const N: usize> {
    /// As built by `Default`: `ω^0, …, ω^{N/2 - 1}` for the order-`N` subgroup generator `ω`,
    /// in bit-reversed index order. Consumed by the forward (G) butterfly network.
    fft_twiddles: Vec<F>,
    /// As built by `Default`: `ω^0, ω^{-1}, …, ω^{-(N/2 - 1)}`, in bit-reversed index order.
    /// Consumed by the transposed (G^T) butterfly network.
    ifft_twiddles: Vec<F>,
    /// As built by `Default`: `F::GENERATOR^i` for `i in 0..N`, in bit-reversed index order,
    /// used to shift the interpolated coefficients onto the coset.
    weights: [F; N],
}
22
23impl<F, const N: usize> Default for CosetMds<F, N>
24where
25    F: TwoAdicField,
26{
27    fn default() -> Self {
28        let log_n = log2_strict_usize(N);
29
30        let root = F::two_adic_generator(log_n);
31        let root_inv = root.inverse();
32        let mut fft_twiddles: Vec<F> = root.powers().collect_n(N / 2);
33        let mut ifft_twiddles: Vec<F> = root_inv.powers().collect_n(N / 2);
34        reverse_slice_index_bits(&mut fft_twiddles);
35        reverse_slice_index_bits(&mut ifft_twiddles);
36
37        let shift = F::GENERATOR;
38        let mut weights: [F; N] = shift.powers().collect_n(N).try_into().unwrap();
39        reverse_slice_index_bits(&mut weights);
40        Self {
41            fft_twiddles,
42            ifft_twiddles,
43            weights,
44        }
45    }
46}
47
48impl<F: TwoAdicField, A: Algebra<F>, const N: usize> Permutation<[A; N]> for CosetMds<F, N> {
49    fn permute_mut(&self, values: &mut [A; N]) {
50        // Inverse DFT, except we skip bit reversal and rescaling by 1/N.
51        bowers_g_t(values, &self.ifft_twiddles);
52
53        // Multiply by powers of the coset shift (see default coset LDE impl for an explanation)
54        for (value, weight) in values.iter_mut().zip(self.weights) {
55            *value = value.clone() * weight;
56        }
57
58        // DFT, assuming bit-reversed input.
59        bowers_g(values, &self.fft_twiddles);
60    }
61}
62
63impl<F: TwoAdicField, A: Algebra<F>, const N: usize> MdsPermutation<A, N> for CosetMds<F, N> {}
64
65/// Executes the Bowers G network. This is like a DFT, except it assumes the input is in
66/// bit-reversed order.
67#[inline]
68fn bowers_g<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
69    let log_n = log2_strict_usize(N);
70    for log_half_block_size in 0..log_n {
71        bowers_g_layer(values, log_half_block_size, twiddles);
72    }
73}
74
75/// Executes the Bowers G^T network. This is like an inverse DFT, except we skip rescaling by
76/// `1/N`, and the output is bit-reversed.
77#[inline]
78fn bowers_g_t<F: Field, A: Algebra<F>, const N: usize>(values: &mut [A; N], twiddles: &[F]) {
79    let log_n = log2_strict_usize(N);
80    for log_half_block_size in (0..log_n).rev() {
81        bowers_g_t_layer(values, log_half_block_size, twiddles);
82    }
83}
84
#[cfg(test)]
mod tests {
    use core::array;

    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::TwoAdicField;
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use crate::coset_mds::CosetMds;

    /// Compares `CosetMds::<F, N>` with a naive coset LDE on a seeded random input. The LDE uses
    /// no added bits and the shift `F::GENERATOR`.
    ///
    /// The naive result is multiplied by `N` first, because the MDS permutation skips the `1/N`
    /// normalization of its inverse DFT.
    fn matches_naive_for<F, const N: usize>()
    where
        F: TwoAdicField,
        StandardUniform: Distribution<F>,
    {
        let mut rng = SmallRng::seed_from_u64(1);
        let input: [F; N] = array::from_fn(|_| rng.random());

        let n_as_field = F::from_usize(N);
        let mut expected = NaiveDft.coset_lde(input.to_vec(), 0, F::GENERATOR);
        for x in &mut expected {
            *x *= n_as_field;
        }

        let mut actual = input;
        CosetMds::<F, N>::default().permute_mut(&mut actual);
        assert_eq!(expected, actual);
    }

    /// Expands each `name => size` pair into a `#[test]` that runs `matches_naive_for` for the
    /// given field at that size.
    macro_rules! matches_naive_tests {
        ($field:ty; $($name:ident => $n:expr),+ $(,)?) => {
            $(
                #[test]
                fn $name() {
                    matches_naive_for::<$field, $n>();
                }
            )+
        };
    }

    matches_naive_tests!(BabyBear;
        matches_naive_baby_bear_1 => 1,
        matches_naive_baby_bear_2 => 2,
        matches_naive_baby_bear_4 => 4,
        matches_naive_baby_bear_8 => 8,
        matches_naive_baby_bear_16 => 16,
        matches_naive_baby_bear_32 => 32,
    );

    matches_naive_tests!(Goldilocks;
        matches_naive_goldilocks_1 => 1,
        matches_naive_goldilocks_2 => 2,
        matches_naive_goldilocks_4 => 4,
        matches_naive_goldilocks_8 => 8,
        matches_naive_goldilocks_16 => 16,
        matches_naive_goldilocks_32 => 32,
    );
}