p3_mds/
integrated_coset_mds.rs
1use alloc::vec::Vec;
2
3use p3_field::{FieldAlgebra, Powers, TwoAdicField};
4use p3_symmetric::Permutation;
5use p3_util::{log2_strict_usize, reverse_slice_index_bits};
6
7use crate::butterflies::{dif_butterfly, dit_butterfly, twiddle_free_butterfly};
8use crate::MdsPermutation;
9
/// An MDS permutation on `[FA; N]` built from two butterfly networks over the two-adic subgroup
/// of size `N`.
///
/// `permute_mut` first runs `log2(N)` Bowers G (DIF) layers using powers of the inverse `N`-th
/// root of unity, then `log2(N)` Bowers G^T (DIT) layers whose twiddles already include powers of
/// the coset shift `F::GENERATOR`, so no separate shift multiplications are performed.
/// Inputs and outputs are left in bit-reversed order and the result is not divided by `N`.
/// The unit test checks that the output equals `N` times the naive coset LDE (shift
/// `F::GENERATOR`, no added bits) of the bit-reversed input, bit-reversed back.
///
/// `N` is expected to be a power of two (construction calls `log2_strict_usize(N)`).
#[derive(Clone, Debug)]
pub struct IntegratedCosetMds<F, const N: usize> {
    /// The first `N / 2` powers of the inverse `N`-th root of unity, in bit-reversed order.
    /// A single table shared by every Bowers G layer; entry `b` is the twiddle of block `b`.
    ifft_twiddles: Vec<F>,
    /// One table per layer. Table `layer` holds `N >> (layer + 1)` twiddles: a geometric sequence
    /// starting at `shift^(2^layer)` with ratio `root^(2^layer)`, stored in bit-reversed order.
    fft_twiddles: Vec<Vec<F>>,
}
20
21impl<F: TwoAdicField, const N: usize> Default for IntegratedCosetMds<F, N> {
22 fn default() -> Self {
23 let log_n = log2_strict_usize(N);
24 let root = F::two_adic_generator(log_n);
25 let root_inv = root.inverse();
26 let coset_shift = F::GENERATOR;
27
28 let mut ifft_twiddles: Vec<F> = root_inv.powers().take(N / 2).collect();
29 reverse_slice_index_bits(&mut ifft_twiddles);
30
31 let fft_twiddles: Vec<Vec<F>> = (0..log_n)
32 .map(|layer| {
33 let shift_power = coset_shift.exp_power_of_2(layer);
34 let powers = Powers {
35 base: root.exp_power_of_2(layer),
36 current: shift_power,
37 };
38 let mut twiddles: Vec<_> = powers.take(N >> (layer + 1)).collect();
39 reverse_slice_index_bits(&mut twiddles);
40 twiddles
41 })
42 .collect();
43
44 Self {
45 ifft_twiddles,
46 fft_twiddles,
47 }
48 }
49}
50
51impl<FA: FieldAlgebra, const N: usize> Permutation<[FA; N]> for IntegratedCosetMds<FA::F, N> {
52 fn permute(&self, mut input: [FA; N]) -> [FA; N] {
53 self.permute_mut(&mut input);
54 input
55 }
56
57 fn permute_mut(&self, values: &mut [FA; N]) {
58 let log_n = log2_strict_usize(N);
59
60 for layer in 0..log_n {
62 bowers_g_layer(values, layer, &self.ifft_twiddles);
63 }
64
65 for layer in (0..log_n).rev() {
67 bowers_g_t_layer(values, layer, &self.fft_twiddles[layer]);
68 }
69 }
70}
71
72impl<FA: FieldAlgebra, const N: usize> MdsPermutation<FA, N> for IntegratedCosetMds<FA::F, N> {}
73
74#[inline]
75fn bowers_g_layer<FA: FieldAlgebra, const N: usize>(
76 values: &mut [FA; N],
77 log_half_block_size: usize,
78 twiddles: &[FA::F],
79) {
80 let log_block_size = log_half_block_size + 1;
81 let half_block_size = 1 << log_half_block_size;
82 let num_blocks = N >> log_block_size;
83
84 for hi in 0..half_block_size {
86 let lo = hi + half_block_size;
87 twiddle_free_butterfly(values, hi, lo);
88 }
89
90 for (block, &twiddle) in (1..num_blocks).zip(&twiddles[1..]) {
91 let block_start = block << log_block_size;
92 for hi in block_start..block_start + half_block_size {
93 let lo = hi + half_block_size;
94 dif_butterfly(values, hi, lo, twiddle);
95 }
96 }
97}
98
99#[inline]
100fn bowers_g_t_layer<FA: FieldAlgebra, const N: usize>(
101 values: &mut [FA; N],
102 log_half_block_size: usize,
103 twiddles: &[FA::F],
104) {
105 let log_block_size = log_half_block_size + 1;
106 let half_block_size = 1 << log_half_block_size;
107 let num_blocks = N >> log_block_size;
108
109 for (block, &twiddle) in (0..num_blocks).zip(twiddles) {
110 let block_start = block << log_block_size;
111 for hi in block_start..block_start + half_block_size {
112 let lo = hi + half_block_size;
113 dit_butterfly(values, hi, lo, twiddle);
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{Field, FieldAlgebra};
    use p3_symmetric::Permutation;
    use p3_util::reverse_slice_index_bits;
    use rand::{thread_rng, Rng};

    use crate::integrated_coset_mds::IntegratedCosetMds;

    type F = BabyBear;

    /// Checks `IntegratedCosetMds<F, N>` against a naive reference for a random input.
    ///
    /// Reference: bit-reverse the input, take its coset LDE (shift `F::GENERATOR`, no added
    /// bits) with `NaiveDft`, bit-reverse the result and scale by `N`, since the fast transform
    /// works in bit-reversed order and skips the `1/N` normalization.
    fn check_matches_naive<const N: usize>() {
        let mut rng = thread_rng();
        // `from_fn` works for any `N`, independent of which array sizes rand can sample directly.
        let mut arr: [F; N] = core::array::from_fn(|_| rng.gen());

        let mut arr_rev = arr.to_vec();
        reverse_slice_index_bits(&mut arr_rev);

        let shift = F::GENERATOR;
        let mut coset_lde_naive = NaiveDft.coset_lde(arr_rev, 0, shift);
        reverse_slice_index_bits(&mut coset_lde_naive);
        let scale = F::from_canonical_usize(N);
        coset_lde_naive.iter_mut().for_each(|x| *x *= scale);

        IntegratedCosetMds::<F, N>::default().permute_mut(&mut arr);
        assert_eq!(coset_lde_naive, arr);
    }

    #[test]
    fn matches_naive() {
        check_matches_naive::<16>();
    }

    #[test]
    fn matches_naive_other_sizes() {
        check_matches_naive::<2>();
        check_matches_naive::<4>();
        check_matches_naive::<8>();
        check_matches_naive::<32>();
    }
}