p3_dft/
radix_2_dit_parallel.rs

1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::vec::Vec;
4use core::cell::RefCell;
5use core::mem::{transmute, MaybeUninit};
6
7use itertools::{izip, Itertools};
8use p3_field::{Field, Powers, TwoAdicField};
9use p3_matrix::bitrev::{BitReversableMatrix, BitReversalPerm, BitReversedMatrixView};
10use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
11use p3_matrix::util::reverse_matrix_index_bits;
12use p3_matrix::Matrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
15use tracing::{debug_span, instrument};
16
17use crate::butterflies::{Butterfly, DitButterfly};
18use crate::TwoAdicSubgroupDft;
19
20/// A parallel FFT algorithm which divides a butterfly network's layers into two halves.
21///
22/// For the first half, we apply a butterfly network with smaller blocks in earlier layers,
23/// i.e. either DIT or Bowers G. Then we bit-reverse, and for the second half, we continue executing
24/// the same network but in bit-reversed order. This way we're always working with small blocks,
25/// so within each half, we can have a certain amount of parallelism with no cross-thread
26/// communication.
27#[derive(Default, Clone, Debug)]
28pub struct Radix2DitParallel<F> {
29    /// Twiddles based on roots of unity, used in the forward DFT.
30    twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,
31
32    /// A map from `(log_h, shift)` to forward DFT twiddles with that coset shift baked in.
33    #[allow(clippy::type_complexity)]
34    coset_twiddles: RefCell<BTreeMap<(usize, F), Vec<Vec<F>>>>,
35
36    /// Twiddles based on inverse roots of unity, used in the inverse DFT.
37    inverse_twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,
38}
39
/// Two copies of the same twiddle factors: one in natural order and one with its indices
/// bit-reversed.
#[derive(Clone, Debug, Default)]
struct VectorPair<F> {
    /// Twiddle factors in natural order.
    twiddles: Vec<F>,
    /// The same twiddle factors, permuted into bit-reversed index order.
    bitrev_twiddles: Vec<F>,
}
46
47#[instrument(level = "debug", skip_all)]
48fn compute_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
49    let half_h = (1 << log_h) >> 1;
50    let root = F::two_adic_generator(log_h);
51    let twiddles: Vec<F> = root.powers().take(half_h).collect();
52    let mut bit_reversed_twiddles = twiddles.clone();
53    reverse_slice_index_bits(&mut bit_reversed_twiddles);
54    VectorPair {
55        twiddles,
56        bitrev_twiddles: bit_reversed_twiddles,
57    }
58}
59
60#[instrument(level = "debug", skip_all)]
61fn compute_coset_twiddles<F: TwoAdicField + Ord>(log_h: usize, shift: F) -> Vec<Vec<F>> {
62    // In general either div_floor or div_ceil would work, but here we prefer div_ceil because it
63    // lets us assume below that the "first half" of the network has at least one layer of
64    // butterflies, even in the case of log_h = 1.
65    let mid = log_h.div_ceil(2);
66    let h = 1 << log_h;
67    let root = F::two_adic_generator(log_h);
68
69    (0..log_h)
70        .map(|layer| {
71            let shift_power = shift.exp_power_of_2(layer);
72            let powers = Powers {
73                base: root.exp_power_of_2(layer),
74                current: shift_power,
75            };
76            let mut twiddles: Vec<_> = powers.take(h >> (layer + 1)).collect();
77            let layer_rev = log_h - 1 - layer;
78            if layer_rev >= mid {
79                reverse_slice_index_bits(&mut twiddles);
80            }
81            twiddles
82        })
83        .collect()
84}
85
86#[instrument(level = "debug", skip_all)]
87fn compute_inverse_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
88    let half_h = (1 << log_h) >> 1;
89    let root_inv = F::two_adic_generator(log_h).inverse();
90    let twiddles: Vec<F> = root_inv.powers().take(half_h).collect();
91    let mut bit_reversed_twiddles = twiddles.clone();
92
93    // In the middle of the coset LDE, we're in bit-reversed order.
94    reverse_slice_index_bits(&mut bit_reversed_twiddles);
95
96    VectorPair {
97        twiddles,
98        bitrev_twiddles: bit_reversed_twiddles,
99    }
100}
101
102impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
103    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
104
105    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
106        let h = mat.height();
107        let log_h = log2_strict_usize(h);
108
109        // Compute twiddle factors, or take memoized ones if already available.
110        let mut twiddles_ref_mut = self.twiddles.borrow_mut();
111        let twiddles = twiddles_ref_mut
112            .entry(log_h)
113            .or_insert_with(|| compute_twiddles(log_h));
114
115        let mid = log_h.div_ceil(2);
116
117        // The first half looks like a normal DIT.
118        reverse_matrix_index_bits(&mut mat);
119        first_half(&mut mat, mid, &twiddles.twiddles);
120
121        // For the second half, we flip the DIT, working in bit-reversed order.
122        reverse_matrix_index_bits(&mut mat);
123        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);
124
125        mat.bit_reverse_rows()
126    }
127
128    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
129    fn coset_lde_batch(
130        &self,
131        mut mat: RowMajorMatrix<F>,
132        added_bits: usize,
133        shift: F,
134    ) -> Self::Evaluations {
135        let w = mat.width;
136        let h = mat.height();
137        let log_h = log2_strict_usize(h);
138        let mid = log_h.div_ceil(2);
139
140        let mut inverse_twiddles_ref_mut = self.inverse_twiddles.borrow_mut();
141        let inverse_twiddles = inverse_twiddles_ref_mut
142            .entry(log_h)
143            .or_insert_with(|| compute_inverse_twiddles(log_h));
144
145        // The first half looks like a normal DIT.
146        reverse_matrix_index_bits(&mut mat);
147        first_half(&mut mat, mid, &inverse_twiddles.twiddles);
148
149        // For the second half, we flip the DIT, working in bit-reversed order.
150        reverse_matrix_index_bits(&mut mat);
151        // We'll also scale by 1/h, as per the usual inverse DFT algorithm.
152        let scale = Some(F::from_canonical_usize(h).inverse());
153        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
154        // We skip the final bit-reversal, since the next FFT expects bit-reversed input.
155
156        let lde_elems = w * (h << added_bits);
157        let elems_to_add = lde_elems - w * h;
158        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));
159
160        let g_big = F::two_adic_generator(log_h + added_bits);
161
162        let mat_ptr = mat.values.as_mut_ptr();
163        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
164        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
165        let rest_slice: &mut [MaybeUninit<F>] =
166            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
167        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
168        let mut rest_cosets_mat = rest_slice
169            .chunks_exact_mut(w * h)
170            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
171            .collect_vec();
172
173        for coset_idx in 1..(1 << added_bits) {
174            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
175            let coset_idx = reverse_bits_len(coset_idx, added_bits);
176            let dest = &mut rest_cosets_mat[coset_idx - 1]; // - 1 because we removed the first matrix.
177            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
178        }
179
180        // Now run a forward DFT on the very first coset, this time in-place.
181        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);
182
183        // SAFETY: We wrote all values above.
184        unsafe {
185            mat.values.set_len(lde_elems);
186        }
187        BitReversalPerm::new_view(mat)
188    }
189}
190
191#[instrument(level = "debug", skip_all)]
192fn coset_dft<F: TwoAdicField + Ord>(
193    dft: &Radix2DitParallel<F>,
194    mat: &mut RowMajorMatrixViewMut<F>,
195    shift: F,
196) {
197    let log_h = log2_strict_usize(mat.height());
198    let mid = log_h.div_ceil(2);
199
200    let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
201    let twiddles = twiddles_ref_mut
202        .entry((log_h, shift))
203        .or_insert_with(|| compute_coset_twiddles(log_h, shift));
204
205    // The first half looks like a normal DIT.
206    first_half_general(mat, mid, twiddles);
207
208    // For the second half, we flip the DIT, working in bit-reversed order.
209    reverse_matrix_index_bits(mat);
210
211    second_half_general(mat, mid, twiddles);
212}
213
214/// Like `coset_dft`, except out-of-place.
215#[instrument(level = "debug", skip_all)]
216fn coset_dft_oop<F: TwoAdicField + Ord>(
217    dft: &Radix2DitParallel<F>,
218    src: &RowMajorMatrixView<F>,
219    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
220    shift: F,
221) {
222    assert_eq!(src.dimensions(), dst_maybe.dimensions());
223
224    let log_h = log2_strict_usize(dst_maybe.height());
225
226    if log_h == 0 {
227        // This is an edge case where first_half_general_oop doesn't work, as it expects there to be
228        // at least one layer in the network, so we just copy instead.
229        let src_maybe = unsafe {
230            transmute::<&RowMajorMatrixView<F>, &RowMajorMatrixView<MaybeUninit<F>>>(src)
231        };
232        dst_maybe.copy_from(src_maybe);
233        return;
234    }
235
236    let mid = log_h.div_ceil(2);
237
238    let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
239    let twiddles = twiddles_ref_mut
240        .entry((log_h, shift))
241        .or_insert_with(|| compute_coset_twiddles(log_h, shift));
242
243    // The first half looks like a normal DIT.
244    first_half_general_oop(src, dst_maybe, mid, twiddles);
245
246    // dst is now initialized.
247    let dst = unsafe {
248        transmute::<&mut RowMajorMatrixViewMut<MaybeUninit<F>>, &mut RowMajorMatrixViewMut<F>>(
249            dst_maybe,
250        )
251    };
252
253    // For the second half, we flip the DIT, working in bit-reversed order.
254    reverse_matrix_index_bits(dst);
255
256    second_half_general(dst, mid, twiddles);
257}
258
259/// This can be used as the first half of a DIT butterfly network.
260#[instrument(level = "debug", skip_all)]
261fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
262    let log_h = log2_strict_usize(mat.height());
263
264    // max block size: 2^mid
265    mat.par_row_chunks_exact_mut(1 << mid)
266        .for_each(|mut submat| {
267            let mut backwards = false;
268            for layer in 0..mid {
269                let layer_rev = log_h - 1 - layer;
270                let layer_pow = 1 << layer_rev;
271                dit_layer(
272                    &mut submat,
273                    layer,
274                    twiddles.iter().copied().step_by(layer_pow),
275                    backwards,
276                );
277                backwards = !backwards;
278            }
279        });
280}
281
282/// Like `first_half`, except supporting different twiddle factors per layer, enabling coset shifts
283/// to be baked into them.
284#[instrument(level = "debug", skip_all)]
285fn first_half_general<F: Field>(
286    mat: &mut RowMajorMatrixViewMut<F>,
287    mid: usize,
288    twiddles: &[Vec<F>],
289) {
290    let log_h = log2_strict_usize(mat.height());
291    mat.par_row_chunks_exact_mut(1 << mid)
292        .for_each(|mut submat| {
293            let mut backwards = false;
294            for layer in 0..mid {
295                let layer_rev = log_h - 1 - layer;
296                dit_layer(
297                    &mut submat,
298                    layer,
299                    twiddles[layer_rev].iter().copied(),
300                    backwards,
301                );
302                backwards = !backwards;
303            }
304        });
305}
306
307/// Like `first_half_general`, except out-of-place.
308///
309/// Assumes there's at least one layer in the network, i.e. `src.height() > 1`.
310/// Undefined behavior otherwise.
311#[instrument(level = "debug", skip_all)]
312fn first_half_general_oop<F: Field>(
313    src: &RowMajorMatrixView<F>,
314    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
315    mid: usize,
316    twiddles: &[Vec<F>],
317) {
318    let log_h = log2_strict_usize(src.height());
319    src.par_row_chunks_exact(1 << mid)
320        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
321        .for_each(|(src_submat, mut dst_submat_maybe)| {
322            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());
323
324            // The first layer is special, done out-of-place.
325            // (Recall from the mid definition that there must be at least one layer here.)
326            let layer_rev = log_h - 1;
327            dit_layer_oop(
328                &src_submat,
329                &mut dst_submat_maybe,
330                0,
331                twiddles[layer_rev].iter().copied(),
332            );
333
334            // submat is now initialized.
335            let mut dst_submat = unsafe {
336                transmute::<RowMajorMatrixViewMut<MaybeUninit<F>>, RowMajorMatrixViewMut<F>>(
337                    dst_submat_maybe,
338                )
339            };
340
341            // Subsequent layers.
342            let mut backwards = true;
343            for layer in 1..mid {
344                let layer_rev = log_h - 1 - layer;
345                dit_layer(
346                    &mut dst_submat,
347                    layer,
348                    twiddles[layer_rev].iter().copied(),
349                    backwards,
350                );
351                backwards = !backwards;
352            }
353        });
354}
355
356/// This can be used as the second half of a DIT butterfly network. It works in bit-reversed order.
357///
358/// The optional `scale` parameter is used to scale the matrix by a constant factor. Normally that
359/// would be a separate step, but it's best to merge it into a butterfly network a just to avoid a
360/// separate pass through main memory.
361#[instrument(level = "debug", skip_all)]
362#[inline(always)] // To avoid branch on scale
363fn second_half<F: Field>(
364    mat: &mut RowMajorMatrix<F>,
365    mid: usize,
366    twiddles_rev: &[F],
367    scale: Option<F>,
368) {
369    let log_h = log2_strict_usize(mat.height());
370
371    // max block size: 2^(log_h - mid)
372    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
373        .enumerate()
374        .for_each(|(thread, mut submat)| {
375            let mut backwards = false;
376            if let Some(scale) = scale {
377                submat.scale(scale);
378            }
379            for layer in mid..log_h {
380                let first_block = thread << (layer - mid);
381                dit_layer_rev(
382                    &mut submat,
383                    log_h,
384                    layer,
385                    twiddles_rev[first_block..].iter().copied(),
386                    backwards,
387                );
388                backwards = !backwards;
389            }
390        });
391}
392
393/// Like `second_half`, except supporting different twiddle factors per layer, enabling coset shifts
394/// to be baked into them.
395#[instrument(level = "debug", skip_all)]
396fn second_half_general<F: Field>(
397    mat: &mut RowMajorMatrixViewMut<F>,
398    mid: usize,
399    twiddles_rev: &[Vec<F>],
400) {
401    let log_h = log2_strict_usize(mat.height());
402    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
403        .enumerate()
404        .for_each(|(thread, mut submat)| {
405            let mut backwards = false;
406            for layer in mid..log_h {
407                let layer_rev = log_h - 1 - layer;
408                let first_block = thread << (layer - mid);
409                dit_layer_rev(
410                    &mut submat,
411                    log_h,
412                    layer,
413                    twiddles_rev[layer_rev][first_block..].iter().copied(),
414                    backwards,
415                );
416                backwards = !backwards;
417            }
418        });
419}
420
421/// One layer of a DIT butterfly network.
422fn dit_layer<F: Field>(
423    submat: &mut RowMajorMatrixViewMut<'_, F>,
424    layer: usize,
425    twiddles: impl Iterator<Item = F> + Clone,
426    backwards: bool,
427) {
428    let half_block_size = 1 << layer;
429    let block_size = half_block_size * 2;
430    let width = submat.width();
431    debug_assert!(submat.height() >= block_size);
432
433    let process_block = |block: &mut [F]| {
434        let (lows, highs) = block.split_at_mut(half_block_size * width);
435
436        for (lo, hi, twiddle) in izip!(
437            lows.chunks_mut(width),
438            highs.chunks_mut(width),
439            twiddles.clone()
440        ) {
441            DitButterfly(twiddle).apply_to_rows(lo, hi);
442        }
443    };
444
445    let blocks = submat.values.chunks_mut(block_size * width);
446    if backwards {
447        for block in blocks.rev() {
448            process_block(block);
449        }
450    } else {
451        for block in blocks {
452            process_block(block);
453        }
454    }
455}
456
457/// One layer of a DIT butterfly network.
458fn dit_layer_oop<F: Field>(
459    src: &RowMajorMatrixView<F>,
460    dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
461    layer: usize,
462    twiddles: impl Iterator<Item = F> + Clone,
463) {
464    debug_assert_eq!(src.dimensions(), dst.dimensions());
465    let half_block_size = 1 << layer;
466    let block_size = half_block_size * 2;
467    let width = dst.width();
468    debug_assert!(dst.height() >= block_size);
469
470    let src_chunks = src.values.chunks(block_size * width);
471    let dst_chunks = dst.values.chunks_mut(block_size * width);
472    for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
473        let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
474        let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
475
476        for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
477            src_lows.chunks(width),
478            dst_lows.chunks_mut(width),
479            src_highs.chunks(width),
480            dst_highs.chunks_mut(width),
481            twiddles.clone()
482        ) {
483            DitButterfly(twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
484        }
485    }
486}
487
488/// Like `dit_layer`, except the matrix and twiddles are encoded in bit-reversed order.
489/// This can also be viewed as a layer of the Bowers G^T network.
490fn dit_layer_rev<F: Field>(
491    submat: &mut RowMajorMatrixViewMut<'_, F>,
492    log_h: usize,
493    layer: usize,
494    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
495    backwards: bool,
496) {
497    let layer_rev = log_h - 1 - layer;
498
499    let half_block_size = 1 << layer_rev;
500    let block_size = half_block_size * 2;
501    let width = submat.width();
502    debug_assert!(submat.height() >= block_size);
503
504    let blocks_and_twiddles = submat
505        .values
506        .chunks_mut(block_size * width)
507        .zip(twiddles_rev);
508    if backwards {
509        for (block, twiddle) in blocks_and_twiddles.rev() {
510            let (lo, hi) = block.split_at_mut(half_block_size * width);
511            DitButterfly(twiddle).apply_to_rows(lo, hi)
512        }
513    } else {
514        for (block, twiddle) in blocks_and_twiddles {
515            let (lo, hi) = block.split_at_mut(half_block_size * width);
516            DitButterfly(twiddle).apply_to_rows(lo, hi)
517        }
518    }
519}