1extern crate alloc;
3
4use alloc::vec::Vec;
5use core::cell::RefCell;
6use core::iter;
7
8use itertools::izip;
9use p3_dft::TwoAdicSubgroupDft;
10use p3_field::{Field, FieldAlgebra};
11use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView};
12use p3_matrix::dense::RowMajorMatrix;
13use p3_matrix::Matrix;
14use p3_maybe_rayon::prelude::*;
15use tracing::{debug_span, instrument};
16
17mod backward;
18mod forward;
19
20use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData};
21
22#[instrument(level = "debug", skip_all)]
24fn coset_shift_and_scale_rows<F: Field>(
25 out: &mut [F],
26 out_ncols: usize,
27 mat: &[F],
28 ncols: usize,
29 shift: F,
30 scale: F,
31) {
32 let powers = shift.shifted_powers(scale).take(ncols).collect::<Vec<_>>();
33 out.par_chunks_exact_mut(out_ncols)
34 .zip(mat.par_chunks_exact(ncols))
35 .for_each(|(out_row, in_row)| {
36 izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| {
37 *out = coeff * weight;
38 });
39 });
40}
41
/// DFT engine used with `MontyField31`: forward transforms run a decimation-in-frequency
/// FFT and inverse transforms run a decimation-in-time FFT.
///
/// The twiddle tables are cached and grown on demand. They live behind `RefCell`s so that
/// the transform methods can extend them through `&self`; because `RefCell` is not `Sync`,
/// one instance cannot be shared between threads by reference.
#[derive(Default, Debug, Clone)]
pub struct RecursiveDft<F> {
    /// Cached forward twiddle table (rows of roots of unity); empty in a `Default` instance.
    twiddles: RefCell<Vec<Vec<F>>>,
    /// Cached inverse twiddle table, kept in step with `twiddles`; empty in a `Default`
    /// instance.
    inv_twiddles: RefCell<Vec<Vec<F>>>,
}
54
55impl<MP: FieldParameters + TwoAdicData> RecursiveDft<MontyField31<MP>> {
56 pub fn new(n: usize) -> Self {
57 let res = Self {
58 twiddles: RefCell::default(),
59 inv_twiddles: RefCell::default(),
60 };
61 res.update_twiddles(n);
62 res
63 }
64
65 #[inline]
66 fn decimation_in_freq_dft(
67 mat: &mut [MontyField31<MP>],
68 ncols: usize,
69 twiddles: &[Vec<MontyField31<MP>>],
70 ) {
71 if ncols > 1 {
72 let lg_fft_len = p3_util::log2_ceil_usize(ncols);
73 let roots_idx = (twiddles.len() + 1) - lg_fft_len;
74 let twiddles = &twiddles[roots_idx..];
75
76 mat.par_chunks_exact_mut(ncols)
77 .for_each(|v| MontyField31::forward_fft(v, twiddles))
78 }
79 }
80
81 #[inline]
82 fn decimation_in_time_dft(
83 mat: &mut [MontyField31<MP>],
84 ncols: usize,
85 twiddles: &[Vec<MontyField31<MP>>],
86 ) {
87 if ncols > 1 {
88 let lg_fft_len = p3_util::log2_ceil_usize(ncols);
89 let roots_idx = (twiddles.len() + 1) - lg_fft_len;
90 let twiddles = &twiddles[roots_idx..];
91
92 mat.par_chunks_exact_mut(ncols)
93 .for_each(|v| MontyField31::backward_fft(v, twiddles))
94 }
95 }
96
97 #[instrument(skip_all)]
99 fn update_twiddles(&self, fft_len: usize) {
100 let curr_max_fft_len = 2 << self.twiddles.borrow().len();
107 if fft_len > curr_max_fft_len {
108 let new_twiddles = MontyField31::roots_of_unity_table(fft_len);
109 let new_inv_twiddles = new_twiddles
112 .iter()
113 .map(|ts| {
114 iter::once(MontyField31::ONE)
116 .chain(
117 ts[1..]
118 .iter()
119 .rev()
120 .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)),
123 )
124 .collect()
125 })
126 .collect();
127 self.twiddles.replace(new_twiddles);
128 self.inv_twiddles.replace(new_inv_twiddles);
129 }
130 }
131}
132
133impl<MP: MontyParameters + FieldParameters + TwoAdicData> TwoAdicSubgroupDft<MontyField31<MP>>
160 for RecursiveDft<MontyField31<MP>>
161{
162 type Evaluations = BitReversedMatrixView<RowMajorMatrix<MontyField31<MP>>>;
163
164 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
165 fn dft_batch(&self, mut mat: RowMajorMatrix<MontyField31<MP>>) -> Self::Evaluations
166 where
167 MP: MontyParameters + FieldParameters + TwoAdicData,
168 {
169 let nrows = mat.height();
170 let ncols = mat.width();
171 if nrows <= 1 {
172 return mat.bit_reverse_rows();
173 }
174
175 let mut scratch = debug_span!("allocate scratch space")
176 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
177
178 self.update_twiddles(nrows);
179 let twiddles = self.twiddles.borrow();
180
181 debug_span!("pre-transpose", nrows, ncols)
183 .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
184
185 debug_span!("dft batch", n_dfts = ncols, fft_len = nrows)
186 .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles));
187
188 debug_span!("post-transpose", nrows = ncols, ncols = nrows)
190 .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
191
192 mat.bit_reverse_rows()
193 }
194
195 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
196 fn idft_batch(&self, mat: RowMajorMatrix<MontyField31<MP>>) -> RowMajorMatrix<MontyField31<MP>>
197 where
198 MP: MontyParameters + FieldParameters + TwoAdicData,
199 {
200 let nrows = mat.height();
201 let ncols = mat.width();
202 if nrows <= 1 {
203 return mat;
204 }
205
206 let mut scratch = debug_span!("allocate scratch space")
207 .in_scope(|| RowMajorMatrix::default(nrows, ncols));
208
209 let mut mat =
210 debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix());
211
212 self.update_twiddles(nrows);
213 let inv_twiddles = self.inv_twiddles.borrow();
214
215 debug_span!("pre-transpose", nrows, ncols)
217 .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows));
218
219 debug_span!("idft", n_dfts = ncols, fft_len = nrows)
220 .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles));
221
222 debug_span!("post-transpose", nrows = ncols, ncols = nrows)
224 .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols));
225
226 let inv_len = MontyField31::from_canonical_usize(nrows).inverse();
227 debug_span!("scale").in_scope(|| mat.scale(inv_len));
228 mat
229 }
230
231 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
232 fn coset_lde_batch(
233 &self,
234 mat: RowMajorMatrix<MontyField31<MP>>,
235 added_bits: usize,
236 shift: MontyField31<MP>,
237 ) -> Self::Evaluations {
238 let nrows = mat.height();
239 let ncols = mat.width();
240 let result_nrows = nrows << added_bits;
241
242 if nrows == 1 {
243 let dupd_rows = core::iter::repeat(mat.values)
244 .take(result_nrows)
245 .flatten()
246 .collect();
247 return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows();
248 }
249
250 let input_size = nrows * ncols;
251 let output_size = result_nrows * ncols;
252
253 let mat = mat.bit_reverse_rows().to_row_major_matrix();
254
255 let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| {
257 let output = MontyField31::<MP>::zero_vec(output_size);
259 let padded = MontyField31::<MP>::zero_vec(output_size);
260 (output, padded)
261 });
262
263 let coeffs = &mut output[..input_size];
266
267 debug_span!("pre-transpose", nrows, ncols)
268 .in_scope(|| transpose::transpose(&mat.values, coeffs, ncols, nrows));
269
270 self.update_twiddles(result_nrows);
272 let inv_twiddles = self.inv_twiddles.borrow();
273 debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows)
274 .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles));
275
276 let inv_len = MontyField31::from_canonical_usize(nrows).inverse();
281 coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len);
282
283 let twiddles = self.twiddles.borrow();
287
288 debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows)
290 .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles));
291
292 debug_span!("post-transpose", nrows = ncols, ncols = result_nrows)
294 .in_scope(|| transpose::transpose(&padded, &mut output, result_nrows, ncols));
295
296 RowMajorMatrix::new(output, ncols).bit_reverse_rows()
297 }
298}