1use alloc::vec;
2
3use p3_field::TwoAdicField;
4use p3_matrix::dense::RowMajorMatrix;
5use p3_matrix::Matrix;
6use p3_util::log2_strict_usize;
7
8use crate::TwoAdicSubgroupDft;
9
/// A stateless DFT backend that evaluates the transform directly from its
/// definition using nested loops.
///
/// For an `h x w` input this performs on the order of `h * h * w` field
/// multiplications, so it is not optimized for speed.
#[derive(Clone, Debug, Default)]
pub struct NaiveDft;
12
13impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for NaiveDft {
14 type Evaluations = RowMajorMatrix<F>;
15 fn dft_batch(&self, mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
16 let w = mat.width();
17 let h = mat.height();
18 let log_h = log2_strict_usize(h);
19 let g = F::two_adic_generator(log_h);
20
21 let mut res = RowMajorMatrix::new(vec![F::ZERO; w * h], w);
22 for (res_r, point) in g.powers().take(h).enumerate() {
23 for (src_r, point_power) in point.powers().take(h).enumerate() {
24 for c in 0..w {
25 res.values[res_r * w + c] += point_power * mat.values[src_r * w + c]
26 }
27 }
28 }
29
30 res
31 }
32}
33
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_field::{Field, FieldAlgebra};
    use p3_goldilocks::Goldilocks;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::thread_rng;

    use crate::{NaiveDft, TwoAdicSubgroupDft};

    /// Checks a hand-computed 2 x 3 transform over BabyBear.
    #[test]
    fn basic() {
        type F = BabyBear;
        let fe = F::from_canonical_u8;

        // Two rows of three columns; each column is transformed on its own.
        let input = RowMajorMatrix::new(
            vec![fe(5), fe(2), F::ZERO, fe(4), fe(3), F::ZERO],
            3,
        );

        // First output row is the sum of the input rows, second is their
        // difference.
        let expected = RowMajorMatrix::new(
            vec![fe(9), fe(5), F::ZERO, F::ONE, F::NEG_ONE, F::ZERO],
            3,
        );

        let actual = NaiveDft.dft_batch(input);
        assert_eq!(actual, expected)
    }

    /// A forward transform followed by the inverse must return the input.
    #[test]
    fn dft_idft_consistency() {
        type F = Goldilocks;

        let input = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 3);
        let round_trip = NaiveDft.idft_batch(NaiveDft.dft_batch(input.clone()));

        assert_eq!(input, round_trip);
    }

    /// Same round-trip check, but over the coset shifted by `F::GENERATOR`.
    #[test]
    fn coset_dft_idft_consistency() {
        type F = Goldilocks;

        let shift = F::GENERATOR;
        let input = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 3);
        let round_trip =
            NaiveDft.coset_idft_batch(NaiveDft.coset_dft_batch(input.clone(), shift), shift);

        assert_eq!(input, round_trip);
    }
}