p3_mds/
integrated_coset_mds.rs

1use alloc::vec::Vec;
2
3use p3_field::{FieldAlgebra, Powers, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::butterflies::{dif_butterfly, dit_butterfly, twiddle_free_butterfly};
8use crate::MdsPermutation;
9
/// A variant of `CosetMds` that differs from it as follows:
/// - The transform is a bit-reversed (Bowers-style) DIF pass followed by a DIT pass,
///   instead of DIT followed by DIF.
/// - No bit reversal is applied to the inputs or the outputs.
/// - The result is not scaled by `1/N`; that factor does not affect the MDS property.
/// - The coset shift is folded into the precomputed twiddle factors of the second
///   pass (`fft_twiddles`), so no separate shifting step is performed.
#[derive(Debug, Clone)]
pub struct IntegratedCosetMds<F, const N: usize> {
    /// The first `N / 2` powers of the inverse two-adic root, stored in bit-reversed
    /// index order. The same slice is handed to every layer of the first pass.
    ifft_twiddles: Vec<F>,
    /// One twiddle vector per layer of the second pass, indexed by layer. Each vector
    /// is stored in bit-reversed index order and already includes the coset shift.
    fft_twiddles: Vec<Vec<F>>,
}
20
21impl<F: TwoAdicField, const N: usize> Default for IntegratedCosetMds<F, N> {
22    fn default() -> Self {
23        let log_n = log2_strict_usize(N);
24        let root = F::two_adic_generator(log_n);
25        let root_inv = root.inverse();
26        let coset_shift = F::GENERATOR;
27
28        let mut ifft_twiddles: Vec<F> = root_inv.powers().take(N / 2).collect();
29        reverse_slice_index_bits(&mut ifft_twiddles);
30
31        let fft_twiddles: Vec<Vec<F>> = (0..log_n)
32            .map(|layer| {
33                let shift_power = coset_shift.exp_power_of_2(layer);
34                let powers = Powers {
35                    base: root.exp_power_of_2(layer),
36                    current: shift_power,
37                };
38                let mut twiddles: Vec<_> = powers.take(N >> (layer + 1)).collect();
39                reverse_slice_index_bits(&mut twiddles);
40                twiddles
41            })
42            .collect();
43
44        Self {
45            ifft_twiddles,
46            fft_twiddles,
47        }
48    }
49}
50
51impl<FA: FieldAlgebra, const N: usize> Permutation<[FA; N]> for IntegratedCosetMds<FA::F, N> {
52    fn permute(&self, mut input: [FA; N]) -> [FA; N] {
53        self.permute_mut(&mut input);
54        input
55    }
56
57    fn permute_mut(&self, values: &mut [FA; N]) {
58        let log_n = log2_strict_usize(N);
59
60        // Bit-reversed DIF, aka Bowers G
61        for layer in 0..log_n {
62            bowers_g_layer(values, layer, &self.ifft_twiddles);
63        }
64
65        // Bit-reversed DIT, aka Bowers G^T
66        for layer in (0..log_n).rev() {
67            bowers_g_t_layer(values, layer, &self.fft_twiddles[layer]);
68        }
69    }
70}
71
72impl<FA: FieldAlgebra, const N: usize> MdsPermutation<FA, N> for IntegratedCosetMds<FA::F, N> {}
73
74#[inline]
75fn bowers_g_layer<FA: FieldAlgebra, const N: usize>(
76    values: &mut [FA; N],
77    log_half_block_size: usize,
78    twiddles: &[FA::F],
79) {
80    let log_block_size = log_half_block_size + 1;
81    let half_block_size = 1 << log_half_block_size;
82    let num_blocks = N >> log_block_size;
83
84    // Unroll first iteration with a twiddle factor of 1.
85    for hi in 0..half_block_size {
86        let lo = hi + half_block_size;
87        twiddle_free_butterfly(values, hi, lo);
88    }
89
90    for (block, &twiddle) in (1..num_blocks).zip(&twiddles[1..]) {
91        let block_start = block << log_block_size;
92        for hi in block_start..block_start + half_block_size {
93            let lo = hi + half_block_size;
94            dif_butterfly(values, hi, lo, twiddle);
95        }
96    }
97}
98
99#[inline]
100fn bowers_g_t_layer<FA: FieldAlgebra, const N: usize>(
101    values: &mut [FA; N],
102    log_half_block_size: usize,
103    twiddles: &[FA::F],
104) {
105    let log_block_size = log_half_block_size + 1;
106    let half_block_size = 1 << log_half_block_size;
107    let num_blocks = N >> log_block_size;
108
109    for (block, &twiddle) in (0..num_blocks).zip(twiddles) {
110        let block_start = block << log_block_size;
111        for hi in block_start..block_start + half_block_size {
112            let lo = hi + half_block_size;
113            dit_butterfly(values, hi, lo, twiddle);
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{Field, FieldAlgebra};
    use p3_symmetric::Permutation;
    use p3_util::reverse_slice_index_bits;
    use rand::{thread_rng, Rng};

    use crate::integrated_coset_mds::IntegratedCosetMds;

    type F = BabyBear;
    const N: usize = 16;

    /// Checks the permutation on a random input against a reference built from
    /// `NaiveDft::coset_lde`: the input is bit-reversed, transformed with shift
    /// `F::GENERATOR`, bit-reversed again, and every entry is multiplied by `N`.
    #[test]
    fn matches_naive() {
        let mut state: [F; N] = thread_rng().gen();

        // Build the reference from a copy, before `state` is permuted in place.
        let mut input_rev = state.to_vec();
        reverse_slice_index_bits(&mut input_rev);
        let mut expected = NaiveDft.coset_lde(input_rev, 0, F::GENERATOR);
        reverse_slice_index_bits(&mut expected);

        // The permutation omits the 1/N weighting, so scale the reference by N.
        let scale = F::from_canonical_usize(N);
        for x in &mut expected {
            *x *= scale;
        }

        IntegratedCosetMds::default().permute_mut(&mut state);
        assert_eq!(expected, state);
    }
}