p3_mds/
integrated_coset_mds.rs

1use alloc::vec::Vec;
2
3use p3_field::{Algebra, Field, Powers, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::MdsPermutation;
8use crate::butterflies::{bowers_g_layer, bowers_g_t_layer_integrated};
9
/// Like `CosetMds`, with a few differences:
/// - (Bit reversed, a la Bowers) DIF + DIT rather than DIT + DIF
/// - We skip bit reversals of the inputs and outputs
/// - We don't weight by `1/N`, since this doesn't affect the MDS property
/// - We integrate the coset shifts into the DIT's twiddle factors, i.e. into the
///   forward-transform (Bowers G^T) pass that runs second
///
/// Instances are built with [`Default`], which precomputes both twiddle tables.
/// `N` is expected to be a power of two, since `log2_strict_usize(N)` is applied to it.
#[derive(Clone, Debug)]
pub struct IntegratedCosetMds<F, const N: usize> {
    /// Twiddles for the first (Bowers G / DIF) pass: the first `N / 2` powers of the
    /// inverse root of unity, stored in bit-reversed order and shared by every layer.
    ifft_twiddles: Vec<F>,
    /// Per-layer twiddles for the second (Bowers G^T / DIT) pass. Entry `layer` holds
    /// `N >> (layer + 1)` bit-reversed values with the coset shift already folded in.
    fft_twiddles: Vec<Vec<F>>,
}
20
21impl<F: TwoAdicField, const N: usize> Default for IntegratedCosetMds<F, N> {
22    fn default() -> Self {
23        let log_n = log2_strict_usize(N);
24        let root = F::two_adic_generator(log_n);
25        let root_inv = root.inverse();
26        let coset_shift = F::GENERATOR;
27
28        let mut ifft_twiddles = root_inv.powers().collect_n(N / 2);
29        reverse_slice_index_bits(&mut ifft_twiddles);
30
31        let fft_twiddles: Vec<Vec<F>> = (0..log_n)
32            .map(|layer| {
33                let shift_power = coset_shift.exp_power_of_2(layer);
34                let powers = Powers {
35                    base: root.exp_power_of_2(layer),
36                    current: shift_power,
37                };
38                let mut twiddles = powers.collect_n(N >> (layer + 1));
39                reverse_slice_index_bits(&mut twiddles);
40                twiddles
41            })
42            .collect();
43
44        Self {
45            ifft_twiddles,
46            fft_twiddles,
47        }
48    }
49}
50
51impl<F: Field, A: Algebra<F>, const N: usize> Permutation<[A; N]> for IntegratedCosetMds<F, N> {
52    fn permute_mut(&self, values: &mut [A; N]) {
53        let log_n = log2_strict_usize(N);
54
55        // Bit-reversed DIF, aka Bowers G
56        for layer in 0..log_n {
57            bowers_g_layer(values, layer, &self.ifft_twiddles);
58        }
59
60        // Bit-reversed DIT, aka Bowers G^T
61        for layer in (0..log_n).rev() {
62            bowers_g_t_layer_integrated(values, layer, &self.fft_twiddles[layer]);
63        }
64    }
65}
66
67impl<F: Field, A: Algebra<F>, const N: usize> MdsPermutation<A, N> for IntegratedCosetMds<F, N> {}
68
#[cfg(test)]
mod tests {
    use core::array;

    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::TwoAdicField;
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;
    use p3_util::reverse_slice_index_bits;
    use rand::distr::{Distribution, StandardUniform};
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use crate::integrated_coset_mds::IntegratedCosetMds;

    /// Compares `IntegratedCosetMds` against a `NaiveDft` reference. The reference
    /// bit-reverses the input and takes a coset LDE (no added bits, shift
    /// `F::GENERATOR`). It then bit-reverses the result and scales it by `N`, which
    /// compensates for the integrated version skipping the `1/N` weight.
    fn matches_naive_for<F, const N: usize>()
    where
        F: TwoAdicField,
        StandardUniform: Distribution<F>,
    {
        let mut rng = SmallRng::seed_from_u64(1);
        let mut actual: [F; N] = array::from_fn(|_| rng.random());

        let expected = {
            let mut input = actual.to_vec();
            reverse_slice_index_bits(&mut input);

            let mut lde = NaiveDft.coset_lde(input, 0, F::GENERATOR);
            reverse_slice_index_bits(&mut lde);

            let n_scalar = F::from_usize(N);
            for value in &mut lde {
                *value *= n_scalar;
            }
            lde
        };

        IntegratedCosetMds::<F, N>::default().permute_mut(&mut actual);
        assert_eq!(expected, actual);
    }

    /// Emits one `#[test]` per `name: size` pair, all for the same field.
    macro_rules! matches_naive_tests {
        ($field:ty => $($name:ident: $n:expr),+ $(,)?) => {
            $(
                #[test]
                fn $name() {
                    matches_naive_for::<$field, $n>();
                }
            )+
        };
    }

    matches_naive_tests!(BabyBear =>
        matches_naive_baby_bear_1: 1,
        matches_naive_baby_bear_2: 2,
        matches_naive_baby_bear_4: 4,
        matches_naive_baby_bear_8: 8,
        matches_naive_baby_bear_16: 16,
        matches_naive_baby_bear_32: 32,
    );

    matches_naive_tests!(Goldilocks =>
        matches_naive_goldilocks_1: 1,
        matches_naive_goldilocks_2: 2,
        matches_naive_goldilocks_4: 4,
        matches_naive_goldilocks_8: 8,
        matches_naive_goldilocks_16: 16,
        matches_naive_goldilocks_32: 32,
    );
}