p3_dft/
naive.rs

1use alloc::vec;
2
3use p3_field::TwoAdicField;
4use p3_matrix::dense::RowMajorMatrix;
5use p3_matrix::Matrix;
6use p3_util::log2_strict_usize;
7
8use crate::TwoAdicSubgroupDft;
9
/// Stateless DFT backend that evaluates polynomials by direct summation.
///
/// The type carries no data; all behaviour lives in its
/// `TwoAdicSubgroupDft` implementation in this file.
#[derive(Clone, Debug, Default)]
pub struct NaiveDft;
12
13impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for NaiveDft {
14    type Evaluations = RowMajorMatrix<F>;
15    fn dft_batch(&self, mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
16        let w = mat.width();
17        let h = mat.height();
18        let log_h = log2_strict_usize(h);
19        let g = F::two_adic_generator(log_h);
20
21        let mut res = RowMajorMatrix::new(vec![F::ZERO; w * h], w);
22        for (res_r, point) in g.powers().take(h).enumerate() {
23            for (src_r, point_power) in point.powers().take(h).enumerate() {
24                for c in 0..w {
25                    res.values[res_r * w + c] += point_power * mat.values[src_r * w + c]
26                }
27            }
28        }
29
30        res
31    }
32}
33
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_field::{Field, FieldAlgebra};
    use p3_goldilocks::Goldilocks;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::thread_rng;

    use crate::{NaiveDft, TwoAdicSubgroupDft};

    /// Checks a hand-computed 2-point DFT of three small polynomials.
    #[test]
    fn basic() {
        type F = BabyBear;
        let fe = F::from_canonical_u8;

        // Columns hold coefficients, lowest degree in the first row:
        //   column 0: 5 + 4x
        //   column 1: 2 + 3x
        //   column 2: 0
        let coeffs = RowMajorMatrix::new(
            vec![fe(5), fe(2), F::ZERO, fe(4), fe(3), F::ZERO],
            3,
        );

        // Evaluating at 1 (first row) and -1 (second row) gives:
        //   column 0: 9,  1
        //   column 1: 5, -1
        //   column 2: 0,  0
        let expected = RowMajorMatrix::new(
            vec![fe(9), fe(5), F::ZERO, F::ONE, F::NEG_ONE, F::ZERO],
            3,
        );

        let evals = NaiveDft.dft_batch(coeffs);
        assert_eq!(evals, expected)
    }

    /// A DFT followed by the inverse DFT must give back the input.
    #[test]
    fn dft_idft_consistency() {
        type F = Goldilocks;

        let mut rng = thread_rng();
        let original = RowMajorMatrix::<F>::rand(&mut rng, 8, 3);

        let round_trip = NaiveDft.idft_batch(NaiveDft.dft_batch(original.clone()));
        assert_eq!(original, round_trip);
    }

    /// Same round trip as above, but on the coset shifted by `F::GENERATOR`.
    #[test]
    fn coset_dft_idft_consistency() {
        type F = Goldilocks;

        let shift = F::GENERATOR;
        let mut rng = thread_rng();
        let original = RowMajorMatrix::<F>::rand(&mut rng, 8, 3);

        let evals = NaiveDft.coset_dft_batch(original.clone(), shift);
        let round_trip = NaiveDft.coset_idft_batch(evals, shift);
        assert_eq!(original, round_trip);
    }
}