p3_monty_31/dft/
mod.rs

1//! An implementation of the FFT for `MontyField31`
2extern crate alloc;
3
4use alloc::vec::Vec;
5use core::cell::RefCell;
6use core::iter;
7
8use itertools::izip;
9use p3_dft::TwoAdicSubgroupDft;
10use p3_field::{Field, FieldAlgebra};
11use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_matrix::Matrix;
14use p3_maybe_rayon::prelude::*;
15use tracing::{debug_span, instrument};
16
17mod backward;
18mod forward;
19
20use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
21
22/// Multiply each element of column `j` of `mat` by `shift**j`.
23#[instrument(level = "debug", skip_all)]
24fn coset_shift_and_scale_rows<F: Field>(
25    out: &mut [F],
26    out_ncols: usize,
27    mat: &[F],
28    ncols: usize,
29    shift: F,
30    scale: F,
31) {
32    let powers = shift.shifted_powers(scale).take(ncols).collect::<Vec<_>>();
33    out.par_chunks_exact_mut(out_ncols)
34        .zip(mat.par_chunks_exact(ncols))
35        .for_each(|(out_row, in_row)| {
36            izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
37                *out = coeff * weight;
38            });
39        });
40}
41
/// Recursive DFT.
///
/// The forward transform is decimation-in-frequency; the backward (inverse)
/// transform is decimation-in-time.
#[derive(Debug, Clone, Default)]
pub struct RecursiveDft<F> {
    /// Memoized forward twiddle table. It is sized for the largest transform
    /// requested so far; shorter transforms use a suffix of it.
    ///
    /// TODO: `RefCell` is not `Sync`, so one instance cannot be shared across
    /// threads. Consider an `RwLock`, or a different design altogether.
    twiddles: RefCell<Vec<Vec<F>>>,
    /// Memoized inverse twiddle table, laid out like `twiddles`.
    inv_twiddles: RefCell<Vec<Vec<F>>>,
}
54
55impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
56    pub fn new(n: usize) -> Self {
57        let res = Self {
58            twiddles: RefCell::default(),
59            inv_twiddles: RefCell::default(),
60        };
61        res.update_twiddles(n);
62        res
63    }
64
65    #[inline]
66    fn decimation_in_freq_dft(
67        mat: &mut [MontyField31<MP>],
68        ncols: usize,
69        twiddles: &[Vec<MontyField31<MP>>],
70    ) {
71        if ncols > 1 {
72            let lg_fft_len = p3_util::log2_ceil_usize(ncols);
73            let roots_idx = (twiddles.len() + 1) - lg_fft_len;
74            let twiddles = &twiddles[roots_idx..];
75
76            mat.par_chunks_exact_mut(ncols)
77                .for_each(|v| MontyField31::forward_fft(v, twiddles))
78        }
79    }
80
81    #[inline]
82    fn decimation_in_time_dft(
83        mat: &mut [MontyField31<MP>],
84        ncols: usize,
85        twiddles: &[Vec<MontyField31<MP>>],
86    ) {
87        if ncols > 1 {
88            let lg_fft_len = p3_util::log2_ceil_usize(ncols);
89            let roots_idx = (twiddles.len() + 1) - lg_fft_len;
90            let twiddles = &twiddles[roots_idx..];
91
92            mat.par_chunks_exact_mut(ncols)
93                .for_each(|v| MontyField31::backward_fft(v, twiddles))
94        }
95    }
96
97    /// Compute twiddle factors, or take memoized ones if already available.
98    #[instrument(skip_all)]
99    fn update_twiddles(&self, fft_len: usize) {
100        // TODO: This recomputes the entire table from scratch if we
101        // need it to be bigger, which is wasteful.
102
103        // As we don't save the twiddles for the final layer where
104        // the only twiddle is 1, roots_of_unity_table(fft_len)
105        // returns a vector of twiddles of length log_2(fft_len) - 1.
106        let curr_max_fft_len = 2 << self.twiddles.borrow().len();
107        if fft_len > curr_max_fft_len {
108            let new_twiddles = MontyField31::roots_of_unity_table(fft_len);
109            // We can obtain the inverse twiddles by reversing and
110            // negating the twiddles.
111            let new_inv_twiddles = new_twiddles
112                .iter()
113                .map(|ts| {
114                    // The first twiddle is still one, we reverse and negate the rest...
115                    iter::once(MontyField31::ONE)
116                        .chain(
117                            ts[1..]
118                                .iter()
119                                .rev()
120                                // A twiddle t is never zero, so negation simplifies
121                                // to P - t.
122                                .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
123                        )
124                        .collect()
125                })
126                .collect();
127            self.twiddles.replace(new_twiddles);
128            self.inv_twiddles.replace(new_inv_twiddles);
129        }
130    }
131}
132
133/// DFT implementation that uses DIT for the inverse "backward"
134/// direction and DIF for the "forward" direction.
135///
136/// The API mandates that the LDE is applied column-wise on the
137/// _row-major_ input. This is awkward for memory coherence, so the
138/// algorithm here transposes the input and operates on the rows in
139/// the typical way, then transposes back again for the output. Even
140/// for modestly large inputs, the cost of the two tranposes
141/// outweighed by the improved performance from operating row-wise.
142///
143/// The choice of DIT for inverse and DIF for "forward" transform mean
144/// that a (coset) LDE
145///
146/// - IDFT / zero extend / DFT
147///
148/// expands to
149///
150///   - bit-reverse input
151///   - invDFT DIT
152///     - result is in "correct" order
153///   - coset shift and zero extend result
154///   - DFT DIF on result
155///     - output is bit-reversed, as required for FRI.
156///
157/// Hence the only bit-reversal that needs to take place is on the input.
158///
159impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
160    for RecursiveDft<MontyField31<MP>>
161{
162    type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
163
164    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
165    fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
166    where
167        MP: MontyParameters + FieldParameters + TwoAdicData,
168    {
169        let nrows = mat.height();
170        let ncols = mat.width();
171        if nrows <= 1 {
172            return mat.bit_reverse_rows();
173        }
174
175        let mut scratch = debug_span!("allocate scratch space")
176            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
177
178        self.update_twiddles(nrows);
179        let twiddles = self.twiddles.borrow();
180
181        // transpose input
182        debug_span!("pre-transpose", nrows, ncols)
183            .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
184
185        debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
186            .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
187
188        // transpose output
189        debug_span!("post-transpose", nrows = ncols, ncols = nrows)
190            .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
191
192        mat.bit_reverse_rows()
193    }
194
195    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
196    fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
197    where
198        MP: MontyParameters + FieldParameters + TwoAdicData,
199    {
200        let nrows = mat.height();
201        let ncols = mat.width();
202        if nrows <= 1 {
203            return mat;
204        }
205
206        let mut scratch = debug_span!("allocate scratch space")
207            .in_scope(|| RowMajorMatrix::default(nrows, ncols));
208
209        let mut mat =
210            debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
211
212        self.update_twiddles(nrows);
213        let inv_twiddles = self.inv_twiddles.borrow();
214
215        // transpose input
216        debug_span!("pre-transpose", nrows, ncols)
217            .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
218
219        debug_span!("idft", n_dfts = ncols, fft_len = nrows)
220            .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
221
222        // transpose output
223        debug_span!("post-transpose", nrows = ncols, ncols = nrows)
224            .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
225
226        let inv_len = MontyField31::from_canonical_usize(nrows).inverse();
227        debug_span!("scale").in_scope(|| mat.scale(inv_len));
228        mat
229    }
230
231    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
232    fn coset_lde_batch(
233        &self,
234        mat: RowMajorMatrix<MontyField31<MP>>,
235        added_bits: usize,
236        shift: MontyField31<MP>,
237    ) -> Self::Evaluations {
238        let nrows = mat.height();
239        let ncols = mat.width();
240        let result_nrows = nrows << added_bits;
241
242        if nrows == 1 {
243            let dupd_rows = core::iter::repeat(mat.values)
244                .take(result_nrows)
245                .flatten()
246                .collect();
247            return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
248        }
249
250        let input_size = nrows * ncols;
251        let output_size = result_nrows * ncols;
252
253        let mat = mat.bit_reverse_rows().to_row_major_matrix();
254
255        // Allocate space for the output and the intermediate state.
256        let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
257            // Safety: These are pretty dodgy, but work because MontyField31 is #[repr(transparent)]
258            let output = MontyField31::<MP>::zero_vec(output_size);
259            let padded = MontyField31::<MP>::zero_vec(output_size);
260            (output, padded)
261        });
262
263        // `coeffs` will hold the result of the inverse FFT; use the
264        // output storage as scratch space.
265        let coeffs = &mut output[..input_size];
266
267        debug_span!("pre-transpose", nrows, ncols)
268            .in_scope(|| transpose::transpose(&mat.values, coeffs, ncols, nrows));
269
270        // Apply inverse DFT; result is not yet normalised.
271        self.update_twiddles(result_nrows);
272        let inv_twiddles = self.inv_twiddles.borrow();
273        debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
274            .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
275
276        // At this point the inverse FFT of each column of `mat` appears
277        // as a row in `coeffs`.
278
279        // Normalise inverse DFT and coset shift in one go.
280        let inv_len = MontyField31::from_canonical_usize(nrows).inverse();
281        coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
282
283        // `padded` is implicitly zero padded since it was initialised
284        // to zeros when declared above.
285
286        let twiddles = self.twiddles.borrow();
287
288        // Apply DFT
289        debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
290            .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
291
292        // transpose output
293        debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
294            .in_scope(|| transpose::transpose(&padded, &mut output, result_nrows, ncols));
295
296        RowMajorMatrix::new(output, ncols).bit_reverse_rows()
297    }
298}