p3_dft/
radix_2_bowers.rs

1use alloc::vec::Vec;
2
3use p3_field::{Field, Powers, TwoAdicField};
4use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
5use p3_matrix::util::reverse_matrix_index_bits;
6use p3_matrix::Matrix;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{log2_strict_usize, reverse_bits, reverse_slice_index_bits};
9use tracing::instrument;
10
11use crate::butterflies::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly};
12use crate::util::divide_by_height;
13use crate::TwoAdicSubgroupDft;
14
/// The Bowers G FFT algorithm.
///
/// See: "Improved Twiddle Access for Fast Fourier Transforms".
///
/// This is a unit struct: it carries no state, so constructing, cloning, or defaulting it is
/// free. `Debug` is derived so the type can be embedded in other `Debug` structures and logged.
#[derive(Default, Clone, Debug)]
pub struct Radix2Bowers;
19
20impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Bowers {
21    type Evaluations = RowMajorMatrix<F>;
22
23    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
24        reverse_matrix_index_bits(&mut mat);
25        bowers_g(&mut mat.as_view_mut());
26        mat
27    }
28
29    /// Compute the inverse DFT of each column in `mat`.
30    fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
31        bowers_g_t(&mut mat.as_view_mut());
32        divide_by_height(&mut mat);
33        reverse_matrix_index_bits(&mut mat);
34        mat
35    }
36
37    fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
38        bowers_g_t(&mut mat.as_view_mut());
39        divide_by_height(&mut mat);
40        mat = mat.bit_reversed_zero_pad(added_bits);
41        bowers_g(&mut mat.as_view_mut());
42        mat
43    }
44
45    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
46    fn coset_lde_batch(
47        &self,
48        mut mat: RowMajorMatrix<F>,
49        added_bits: usize,
50        shift: F,
51    ) -> RowMajorMatrix<F> {
52        let h = mat.height();
53        let h_inv = F::from_canonical_usize(h).inverse();
54
55        bowers_g_t(&mut mat.as_view_mut());
56
57        // Rescale coefficients in two ways:
58        // - divide by height (since we're doing an inverse DFT)
59        // - multiply by powers of the coset shift (see default coset LDE impl for an explanation)
60        let weights = Powers {
61            base: shift,
62            current: h_inv,
63        }
64        .take(h);
65        for (row, weight) in weights.enumerate() {
66            // reverse_bits because mat is encoded in bit-reversed order
67            mat.scale_row(reverse_bits(row, h), weight);
68        }
69
70        mat = mat.bit_reversed_zero_pad(added_bits);
71
72        bowers_g(&mut mat.as_view_mut());
73
74        mat
75    }
76}
77
78/// Executes the Bowers G network. This is like a DFT, except it assumes the input is in
79/// bit-reversed order.
80fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
81    let h = mat.height();
82    let log_h = log2_strict_usize(h);
83
84    let root = F::two_adic_generator(log_h);
85    let mut twiddles: Vec<_> = root.powers().take(h / 2).map(DifButterfly).collect();
86    reverse_slice_index_bits(&mut twiddles);
87
88    let log_h = log2_strict_usize(mat.height());
89    for log_half_block_size in 0..log_h {
90        butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
91    }
92}
93
94/// Executes the Bowers G^T network. This is like an inverse DFT, except we skip rescaling by
95/// 1/height, and the output is bit-reversed.
96fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
97    let h = mat.height();
98    let log_h = log2_strict_usize(h);
99
100    let root_inv = F::two_adic_generator(log_h).inverse();
101    let mut twiddles: Vec<_> = root_inv.powers().take(h / 2).map(DitButterfly).collect();
102    reverse_slice_index_bits(&mut twiddles);
103
104    let log_h = log2_strict_usize(mat.height());
105    for log_half_block_size in (0..log_h).rev() {
106        butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
107    }
108}
109
110fn butterfly_layer<F: Field, B: Butterfly<F>>(
111    mat: &mut RowMajorMatrixViewMut<F>,
112    half_block_size: usize,
113    twiddles: &[B],
114) {
115    mat.par_row_chunks_exact_mut(2 * half_block_size)
116        .enumerate()
117        .for_each(|(block, mut chunks)| {
118            let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
119            hi_chunks
120                .par_rows_mut()
121                .zip(lo_chunks.par_rows_mut())
122                .for_each(|(hi_chunk, lo_chunk)| {
123                    if block == 0 {
124                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk)
125                    } else {
126                        twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
127                    }
128                });
129        });
130}