p3_mds/
coset_mds.rs

1use alloc::vec::Vec;
2
3use p3_field::{FieldAlgebra, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::butterflies::{dif_butterfly, dit_butterfly, twiddle_free_butterfly};
8use crate::MdsPermutation;
9
/// An MDS permutation derived from a Reed-Solomon code.
///
/// The input is read as the evaluations of a polynomial of degree below `N` over the
/// multiplicative subgroup of order `N`. The output is that polynomial's evaluations over a coset
/// of the subgroup, scaled by `N` because the inverse transform skips its `1/N` normalisation.
/// Those coset evaluations are the parity symbols of a systematic Reed-Solomon code. Since
/// Reed-Solomon codes are MDS, the resulting linear map is an MDS permutation.
///
/// `N` must be a power of two; all tables are built by the [`Default`] impl.
#[derive(Clone, Debug)]
pub struct CosetMds<F, const N: usize> {
    /// Forward-transform twiddles (powers of the order-`N` root of unity), bit-reversed.
    fft_twiddles: Vec<F>,
    /// Inverse-transform twiddles (powers of the inverse root of unity), bit-reversed.
    ifft_twiddles: Vec<F>,
    /// Powers of the coset shift, bit-reversed to match the inverse transform's output order.
    weights: [F; N],
}
22
23impl<F, const N: usize> Default for CosetMds<F, N>
24where
25    F: TwoAdicField,
26{
27    fn default() -> Self {
28        let log_n = log2_strict_usize(N);
29
30        let root = F::two_adic_generator(log_n);
31        let root_inv = root.inverse();
32        let mut fft_twiddles: Vec<F> = root.powers().take(N / 2).collect();
33        let mut ifft_twiddles: Vec<F> = root_inv.powers().take(N / 2).collect();
34        reverse_slice_index_bits(&mut fft_twiddles);
35        reverse_slice_index_bits(&mut ifft_twiddles);
36
37        let shift = F::GENERATOR;
38        let mut weights: [F; N] = shift
39            .powers()
40            .take(N)
41            .collect::<Vec<_>>()
42            .try_into()
43            .unwrap();
44        reverse_slice_index_bits(&mut weights);
45        Self {
46            fft_twiddles,
47            ifft_twiddles,
48            weights,
49        }
50    }
51}
52
53impl<FA, const N: usize> Permutation<[FA; N]> for CosetMds<FA::F, N>
54where
55    FA: FieldAlgebra,
56    FA::F: TwoAdicField,
57{
58    fn permute(&self, mut input: [FA; N]) -> [FA; N] {
59        self.permute_mut(&mut input);
60        input
61    }
62
63    fn permute_mut(&self, values: &mut [FA; N]) {
64        // Inverse DFT, except we skip bit reversal and rescaling by 1/N.
65        bowers_g_t(values, &self.ifft_twiddles);
66
67        // Multiply by powers of the coset shift (see default coset LDE impl for an explanation)
68        for (value, weight) in values.iter_mut().zip(self.weights) {
69            *value = value.clone() * FA::from_f(weight);
70        }
71
72        // DFT, assuming bit-reversed input.
73        bowers_g(values, &self.fft_twiddles);
74    }
75}
76
77impl<FA, const N: usize> MdsPermutation<FA, N> for CosetMds<FA::F, N>
78where
79    FA: FieldAlgebra,
80    FA::F: TwoAdicField,
81{
82}
83
84/// Executes the Bowers G network. This is like a DFT, except it assumes the input is in
85/// bit-reversed order.
86#[inline]
87fn bowers_g<FA: FieldAlgebra, const N: usize>(values: &mut [FA; N], twiddles: &[FA::F]) {
88    let log_n = log2_strict_usize(N);
89    for log_half_block_size in 0..log_n {
90        bowers_g_layer(values, log_half_block_size, twiddles);
91    }
92}
93
94/// Executes the Bowers G^T network. This is like an inverse DFT, except we skip rescaling by
95/// `1/N`, and the output is bit-reversed.
96#[inline]
97fn bowers_g_t<FA: FieldAlgebra, const N: usize>(values: &mut [FA; N], twiddles: &[FA::F]) {
98    let log_n = log2_strict_usize(N);
99    for log_half_block_size in (0..log_n).rev() {
100        bowers_g_t_layer(values, log_half_block_size, twiddles);
101    }
102}
103
104/// One layer of a Bowers G network. Equivalent to `bowers_g_t_layer` except for the butterfly.
105#[inline]
106fn bowers_g_layer<FA: FieldAlgebra, const N: usize>(
107    values: &mut [FA; N],
108    log_half_block_size: usize,
109    twiddles: &[FA::F],
110) {
111    let log_block_size = log_half_block_size + 1;
112    let half_block_size = 1 << log_half_block_size;
113    let num_blocks = N >> log_block_size;
114
115    // Unroll first iteration with a twiddle factor of 1.
116    for hi in 0..half_block_size {
117        let lo = hi + half_block_size;
118        twiddle_free_butterfly(values, hi, lo);
119    }
120
121    for (block, &twiddle) in (1..num_blocks).zip(&twiddles[1..]) {
122        let block_start = block << log_block_size;
123        for hi in block_start..block_start + half_block_size {
124            let lo = hi + half_block_size;
125            dif_butterfly(values, hi, lo, twiddle);
126        }
127    }
128}
129
130/// One layer of a Bowers G^T network. Equivalent to `bowers_g_layer` except for the butterfly.
131#[inline]
132fn bowers_g_t_layer<FA: FieldAlgebra, const N: usize>(
133    values: &mut [FA; N],
134    log_half_block_size: usize,
135    twiddles: &[FA::F],
136) {
137    let log_block_size = log_half_block_size + 1;
138    let half_block_size = 1 << log_half_block_size;
139    let num_blocks = N >> log_block_size;
140
141    // Unroll first iteration with a twiddle factor of 1.
142    for hi in 0..half_block_size {
143        let lo = hi + half_block_size;
144        twiddle_free_butterfly(values, hi, lo);
145    }
146
147    for (block, &twiddle) in (1..num_blocks).zip(&twiddles[1..]) {
148        let block_start = block << log_block_size;
149        for hi in block_start..block_start + half_block_size {
150            let lo = hi + half_block_size;
151            dit_butterfly(values, hi, lo, twiddle);
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{Field, FieldAlgebra};
    use p3_symmetric::Permutation;
    use rand::{thread_rng, Rng};

    use crate::coset_mds::CosetMds;

    type F = BabyBear;

    /// Checks `CosetMds<F, N>` against a naive coset LDE of random input.
    ///
    /// `CosetMds` skips the `1/N` rescaling of its inverse transform, so the naive result is
    /// multiplied by `N` before the comparison.
    fn check_matches_naive<const N: usize>() {
        let mut rng = thread_rng();
        let mut arr: [F; N] = core::array::from_fn(|_| rng.gen());

        let shift = F::GENERATOR;
        let mut expected = NaiveDft.coset_lde(arr.to_vec(), 0, shift);
        let scale = F::from_canonical_usize(N);
        expected.iter_mut().for_each(|x| *x *= scale);

        CosetMds::<F, N>::default().permute_mut(&mut arr);
        assert_eq!(expected, arr);
    }

    #[test]
    fn matches_naive() {
        check_matches_naive::<8>();
    }

    #[test]
    fn matches_naive_other_sizes() {
        // N = 2 has a single-entry twiddle table, so no layer ever runs a non-trivial block.
        check_matches_naive::<2>();
        check_matches_naive::<4>();
        check_matches_naive::<16>();
        check_matches_naive::<32>();
    }
}