p3_mds/
coset_mds.rs
1use alloc::vec::Vec;
2
3use p3_field::{FieldAlgebra, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::butterflies::{dif_butterfly, dit_butterfly, twiddle_free_butterfly};
8use crate::MdsPermutation;
9
10#[derive(Clone, Debug)]
17pub struct CosetMds<F, const N: usize> {
18 fft_twiddles: Vec<F>,
19 ifft_twiddles: Vec<F>,
20 weights: [F; N],
21}
22
23impl<F, const N: usize> Default for CosetMds<F, N>
24where
25 F: TwoAdicField,
26{
27 fn default() -> Self {
28 let log_n = log2_strict_usize(N);
29
30 let root = F::two_adic_generator(log_n);
31 let root_inv = root.inverse();
32 let mut fft_twiddles: Vec<F> = root.powers().take(N / 2).collect();
33 let mut ifft_twiddles: Vec<F> = root_inv.powers().take(N / 2).collect();
34 reverse_slice_index_bits(&mut fft_twiddles);
35 reverse_slice_index_bits(&mut ifft_twiddles);
36
37 let shift = F::GENERATOR;
38 let mut weights: [F; N] = shift
39 .powers()
40 .take(N)
41 .collect::<Vec<_>>()
42 .try_into()
43 .unwrap();
44 reverse_slice_index_bits(&mut weights);
45 Self {
46 fft_twiddles,
47 ifft_twiddles,
48 weights,
49 }
50 }
51}
52
53impl<FA, const N: usize> Permutation<[FA; N]> for CosetMds<FA::F, N>
54where
55 FA: FieldAlgebra,
56 FA::F: TwoAdicField,
57{
58 fn permute(&self, mut input: [FA; N]) -> [FA; N] {
59 self.permute_mut(&mut input);
60 input
61 }
62
63 fn permute_mut(&self, values: &mut [FA; N]) {
64 bowers_g_t(values, &self.ifft_twiddles);
66
67 for (value, weight) in values.iter_mut().zip(self.weights) {
69 *value = value.clone() * FA::from_f(weight);
70 }
71
72 bowers_g(values, &self.fft_twiddles);
74 }
75}
76
77impl<FA, const N: usize> MdsPermutation<FA, N> for CosetMds<FA::F, N>
78where
79 FA: FieldAlgebra,
80 FA::F: TwoAdicField,
81{
82}
83
84#[inline]
87fn bowers_g<FA: FieldAlgebra, const N: usize>(values: &mut [FA; N], twiddles: &[FA::F]) {
88 let log_n = log2_strict_usize(N);
89 for log_half_block_size in 0..log_n {
90 bowers_g_layer(values, log_half_block_size, twiddles);
91 }
92}
93
94#[inline]
97fn bowers_g_t<FA: FieldAlgebra, const N: usize>(values: &mut [FA; N], twiddles: &[FA::F]) {
98 let log_n = log2_strict_usize(N);
99 for log_half_block_size in (0..log_n).rev() {
100 bowers_g_t_layer(values, log_half_block_size, twiddles);
101 }
102}
103
104#[inline]
106fn bowers_g_layer<FA: FieldAlgebra, const N: usize>(
107 values: &mut [FA; N],
108 log_half_block_size: usize,
109 twiddles: &[FA::F],
110) {
111 let log_block_size = log_half_block_size + 1;
112 let half_block_size = 1 << log_half_block_size;
113 let num_blocks = N >> log_block_size;
114
115 for hi in 0..half_block_size {
117 let lo = hi + half_block_size;
118 twiddle_free_butterfly(values, hi, lo);
119 }
120
121 for (block, &twiddle) in (1..num_blocks).zip(&twiddles[1..]) {
122 let block_start = block << log_block_size;
123 for hi in block_start..block_start + half_block_size {
124 let lo = hi + half_block_size;
125 dif_butterfly(values, hi, lo, twiddle);
126 }
127 }
128}
129
130#[inline]
132fn bowers_g_t_layer<FA: FieldAlgebra, const N: usize>(
133 values: &mut [FA; N],
134 log_half_block_size: usize,
135 twiddles: &[FA::F],
136) {
137 let log_block_size = log_half_block_size + 1;
138 let half_block_size = 1 << log_half_block_size;
139 let num_blocks = N >> log_block_size;
140
141 for hi in 0..half_block_size {
143 let lo = hi + half_block_size;
144 twiddle_free_butterfly(values, hi, lo);
145 }
146
147 for (block, &twiddle) in (1..num_blocks).zip(&twiddles[1..]) {
148 let block_start = block << log_block_size;
149 for hi in block_start..block_start + half_block_size {
150 let lo = hi + half_block_size;
151 dit_butterfly(values, hi, lo, twiddle);
152 }
153 }
154}
155
156#[cfg(test)]
157mod tests {
158 use p3_baby_bear::BabyBear;
159 use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
160 use p3_field::{Field, FieldAlgebra};
161 use p3_symmetric::Permutation;
162 use rand::{thread_rng, Rng};
163
164 use crate::coset_mds::CosetMds;
165
166 #[test]
167 fn matches_naive() {
168 type F = BabyBear;
169 const N: usize = 8;
170
171 let mut rng = thread_rng();
172 let mut arr: [F; N] = rng.gen();
173
174 let shift = F::GENERATOR;
175 let mut coset_lde_naive = NaiveDft.coset_lde(arr.to_vec(), 0, shift);
176 coset_lde_naive
177 .iter_mut()
178 .for_each(|x| *x *= F::from_canonical_usize(N));
179 CosetMds::default().permute_mut(&mut arr);
180 assert_eq!(coset_lde_naive, arr);
181 }
182}