1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::vec::Vec;
4use core::cell::RefCell;
5use core::mem::{transmute, MaybeUninit};
6
7use itertools::{izip, Itertools};
8use p3_field::{Field, Powers, TwoAdicField};
9use p3_matrix::bitrev::{BitReversableMatrix, BitReversalPerm, BitReversedMatrixView};
10use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
11use p3_matrix::util::reverse_matrix_index_bits;
12use p3_matrix::Matrix;
13use p3_maybe_rayon::prelude::*;
14use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
15use tracing::{debug_span, instrument};
16
17use crate::butterflies::{Butterfly, DitButterfly};
18use crate::TwoAdicSubgroupDft;
19
/// A parallel radix-2 decimation-in-time (DIT) implementation of [`TwoAdicSubgroupDft`].
///
/// Twiddle factors are computed lazily and memoized in `RefCell`-wrapped maps. Repeated
/// transforms of the same size (and, for coset transforms, the same shift) reuse earlier work.
/// Nothing in this type evicts cache entries, so the maps only grow.
///
/// Because the caches live in `RefCell`s, the type is not `Sync`: a single instance cannot be
/// shared by reference across threads. The butterfly passes themselves still run in parallel.
#[derive(Default, Clone, Debug)]
pub struct Radix2DitParallel<F> {
    /// Forward twiddles, keyed by `log_h` (log2 of the transform height).
    twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,

    /// Per-layer twiddles for coset DFTs, keyed by `(log_h, shift)`.
    // The nested map/vector type is spelled out inline rather than hidden behind an alias,
    // hence the lint suppression.
    #[allow(clippy::type_complexity)]
    coset_twiddles: RefCell<BTreeMap<(usize, F), Vec<Vec<F>>>>,

    /// Inverse twiddles (powers of the inverse of the subgroup generator), keyed by `log_h`.
    inverse_twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,
}
39
/// Twiddle factors for one transform size, kept in both natural and bit-reversed order.
#[derive(Default, Clone, Debug)]
struct VectorPair<F> {
    /// The first `2^log_h / 2` powers of `F::two_adic_generator(log_h)` (or of its inverse),
    /// in natural order.
    twiddles: Vec<F>,
    /// The same values as `twiddles`, permuted by reversing the bits of each index.
    bitrev_twiddles: Vec<F>,
}
46
47#[instrument(level = "debug", skip_all)]
48fn compute_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
49 let half_h = (1 << log_h) >> 1;
50 let root = F::two_adic_generator(log_h);
51 let twiddles: Vec<F> = root.powers().take(half_h).collect();
52 let mut bit_reversed_twiddles = twiddles.clone();
53 reverse_slice_index_bits(&mut bit_reversed_twiddles);
54 VectorPair {
55 twiddles,
56 bitrev_twiddles: bit_reversed_twiddles,
57 }
58}
59
60#[instrument(level = "debug", skip_all)]
61fn compute_coset_twiddles<F: TwoAdicField + Ord>(log_h: usize, shift: F) -> Vec<Vec<F>> {
62 let mid = log_h.div_ceil(2);
66 let h = 1 << log_h;
67 let root = F::two_adic_generator(log_h);
68
69 (0..log_h)
70 .map(|layer| {
71 let shift_power = shift.exp_power_of_2(layer);
72 let powers = Powers {
73 base: root.exp_power_of_2(layer),
74 current: shift_power,
75 };
76 let mut twiddles: Vec<_> = powers.take(h >> (layer + 1)).collect();
77 let layer_rev = log_h - 1 - layer;
78 if layer_rev >= mid {
79 reverse_slice_index_bits(&mut twiddles);
80 }
81 twiddles
82 })
83 .collect()
84}
85
86#[instrument(level = "debug", skip_all)]
87fn compute_inverse_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
88 let half_h = (1 << log_h) >> 1;
89 let root_inv = F::two_adic_generator(log_h).inverse();
90 let twiddles: Vec<F> = root_inv.powers().take(half_h).collect();
91 let mut bit_reversed_twiddles = twiddles.clone();
92
93 reverse_slice_index_bits(&mut bit_reversed_twiddles);
95
96 VectorPair {
97 twiddles,
98 bitrev_twiddles: bit_reversed_twiddles,
99 }
100}
101
/// Parallel radix-2 DIT implementation of [`TwoAdicSubgroupDft`].
impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
    // Results are handed back as a bit-reversal view over the row-major output buffer.
    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;

    /// Computes the forward DFT of every column of `mat`.
    ///
    /// The height of `mat` goes through `log2_strict_usize`, so it is expected to be a power of
    /// two. The butterfly network is split after `mid = log_h.div_ceil(2)` layers:
    /// 1. The rows are bit-reversed and [`first_half`] runs with the natural-order twiddles.
    /// 2. The rows are bit-reversed again and [`second_half`] runs the remaining layers with
    ///    the bit-reversed twiddles.
    ///
    /// The result is wrapped by `bit_reverse_rows()`.
    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
        let h = mat.height();
        let log_h = log2_strict_usize(h);

        // Compute twiddle factors, or take memoized ones if already available. The `RefCell`
        // borrow is held until this call returns.
        let mut twiddles_ref_mut = self.twiddles.borrow_mut();
        let twiddles = twiddles_ref_mut
            .entry(log_h)
            .or_insert_with(|| compute_twiddles(log_h));

        let mid = log_h.div_ceil(2);

        // First half: an ordinary DIT over bit-reversed rows.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &twiddles.twiddles);

        // Second half: reverse the row-index bits again and continue in bit-reversed order.
        reverse_matrix_index_bits(&mut mat);
        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);

        mat.bit_reverse_rows()
    }

    /// Low-degree extension of every column of `mat` onto a domain `2^added_bits` times larger,
    /// shifted by `shift`.
    ///
    /// Steps, as implemented below:
    /// 1. An inverse DFT (inverse twiddles, with a `1/h` scale folded into the second half)
    ///    turns the `h` rows into coefficients. Unlike `dft_batch`, no final bit-reversal is
    ///    applied, so the rows stay in the order [`second_half`] leaves them in.
    /// 2. The buffer is grown to `w * (h << added_bits)` elements. Let
    ///    `g_big = F::two_adic_generator(log_h + added_bits)`. For every coset index
    ///    `k in 1..2^added_bits`, the coefficients are evaluated out-of-place with shift
    ///    `g_big^k * shift`. The result goes into the `h`-row slot at position
    ///    `reverse_bits_len(k, added_bits)` of the enlarged buffer.
    /// 3. The original `h` rows (slot 0) are transformed in place with shift `shift`.
    ///
    /// The full buffer is returned behind a `BitReversalPerm` view.
    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
    fn coset_lde_batch(
        &self,
        mut mat: RowMajorMatrix<F>,
        added_bits: usize,
        shift: F,
    ) -> Self::Evaluations {
        let w = mat.width;
        let h = mat.height();
        let log_h = log2_strict_usize(h);
        let mid = log_h.div_ceil(2);

        // Memoized inverse twiddles. This `RefCell` borrow lasts for the whole call. The coset
        // helpers called below only borrow `coset_twiddles`, so there is no double borrow.
        let mut inverse_twiddles_ref_mut = self.inverse_twiddles.borrow_mut();
        let inverse_twiddles = inverse_twiddles_ref_mut
            .entry(log_h)
            .or_insert_with(|| compute_inverse_twiddles(log_h));

        // Inverse DFT, first half: an ordinary DIT over bit-reversed rows.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &inverse_twiddles.twiddles);

        // Inverse DFT, second half, with the 1/h normalization folded into the same pass.
        reverse_matrix_index_bits(&mut mat);
        let scale = Some(F::from_canonical_usize(h).inverse());
        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
        // No final bit-reversal here (unlike `dft_batch`).

        let lde_elems = w * (h << added_bits);
        let elems_to_add = lde_elems - w * h;
        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));

        let g_big = F::two_adic_generator(log_h + added_bits);

        // Split the allocation into two parts:
        // - the initialized first coset (`w * h` values, currently the coefficients);
        // - the uninitialized spare capacity that will receive the other cosets.
        let mat_ptr = mat.values.as_mut_ptr();
        // SAFETY: `h` is `values.len() / w`, so `w * h <= len`. After `reserve_exact` the
        // allocation has room for at least `lde_elems` elements, so offsetting by `w * h`
        // stays inside it.
        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
        // SAFETY: the first `w * h` elements lie within `len` and are initialized.
        // NOTE(review): this assumes the matrix invariant `values.len() == w * h` holds;
        // otherwise trailing elements would also be covered by `rest_slice` below.
        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
        // SAFETY: the next `lde_elems - w * h` elements are inside the reserved capacity. They
        // are exposed only as `MaybeUninit<F>` and do not overlap `first_slice`.
        let rest_slice: &mut [MaybeUninit<F>] =
            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
        // One `h`-row view per remaining coset (`2^added_bits - 1` of them).
        let mut rest_cosets_mat = rest_slice
            .chunks_exact_mut(w * h)
            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
            .collect_vec();

        // Evaluate every coset except coset 0 out-of-place. The coefficients in the first slot
        // are still untouched at this point.
        for coset_idx in 1..(1 << added_bits) {
            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
            let coset_idx = reverse_bits_len(coset_idx, added_bits);
            // `- 1` because slot 0 is `first_coset_mat`, not part of `rest_cosets_mat`. A
            // nonzero index stays nonzero under bit reversal, so this cannot underflow.
            let dest = &mut rest_cosets_mat[coset_idx - 1];
            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
        }

        // Only now overwrite the coefficients with the evaluations on coset 0, in place.
        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);

        // SAFETY: the first `w * h` elements were already initialized. Bit reversal permutes
        // `1..2^added_bits`, so the loop above passed each of the `2^added_bits - 1`
        // remaining slots to `coset_dft_oop` exactly once, and that function initializes its
        // entire destination. Hence all `lde_elems` elements are initialized.
        unsafe {
            mat.values.set_len(lde_elems);
        }
        BitReversalPerm::new_view(mat)
    }
}
190
191#[instrument(level = "debug", skip_all)]
192fn coset_dft<F: TwoAdicField + Ord>(
193 dft: &Radix2DitParallel<F>,
194 mat: &mut RowMajorMatrixViewMut<F>,
195 shift: F,
196) {
197 let log_h = log2_strict_usize(mat.height());
198 let mid = log_h.div_ceil(2);
199
200 let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
201 let twiddles = twiddles_ref_mut
202 .entry((log_h, shift))
203 .or_insert_with(|| compute_coset_twiddles(log_h, shift));
204
205 first_half_general(mat, mid, twiddles);
207
208 reverse_matrix_index_bits(mat);
210
211 second_half_general(mat, mid, twiddles);
212}
213
/// Out-of-place coset DFT: transforms `src` with the given `shift` and writes the result into
/// `dst_maybe`.
///
/// `dst_maybe` must have the same dimensions as `src` (this is asserted). It may start out
/// uninitialized; every element is written before this function returns. Twiddles come from,
/// and are memoized in, `dft.coset_twiddles` under the key `(log_h, shift)`.
///
/// The first butterfly layer reads `src` and writes `dst_maybe`. All remaining work runs in
/// place on the destination, following the same sequence as [`coset_dft`].
#[instrument(level = "debug", skip_all)]
fn coset_dft_oop<F: TwoAdicField + Ord>(
    dft: &Radix2DitParallel<F>,
    src: &RowMajorMatrixView<F>,
    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
    shift: F,
) {
    assert_eq!(src.dimensions(), dst_maybe.dimensions());

    let log_h = log2_strict_usize(dst_maybe.height());

    // A single row needs no butterflies, so copy it straight across. This early return also
    // keeps the `log_h - 1` in `first_half_general_oop` from underflowing.
    if log_h == 0 {
        // SAFETY: `MaybeUninit<F>` has the same size, alignment and ABI as `F`, and every
        // initialized `F` is a valid `MaybeUninit<F>`.
        // NOTE(review): this additionally relies on `RowMajorMatrixView<F>` and
        // `RowMajorMatrixView<MaybeUninit<F>>` sharing a layout — confirm against the
        // `p3_matrix` definition.
        let src_maybe = unsafe {
            transmute::<&RowMajorMatrixView<F>, &RowMajorMatrixView<MaybeUninit<F>>>(src)
        };
        dst_maybe.copy_from(src_maybe);
        return;
    }

    let mid = log_h.div_ceil(2);

    let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
    let twiddles = twiddles_ref_mut
        .entry((log_h, shift))
        .or_insert_with(|| compute_coset_twiddles(log_h, shift));

    // First half. Its out-of-place layer 0 is what initializes `dst_maybe`.
    first_half_general_oop(src, dst_maybe, mid, twiddles);

    // SAFETY: `first_half_general_oop` has written every element of `dst_maybe`, so it can now
    // be viewed as initialized `F`s.
    // NOTE(review): like the transmute above, this relies on the two view types sharing a
    // layout.
    let dst = unsafe {
        transmute::<&mut RowMajorMatrixViewMut<MaybeUninit<F>>, &mut RowMajorMatrixViewMut<F>>(
            dst_maybe,
        )
    };

    reverse_matrix_index_bits(dst);

    second_half_general(dst, mid, twiddles);
}
258
259#[instrument(level = "debug", skip_all)]
261fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
262 let log_h = log2_strict_usize(mat.height());
263
264 mat.par_row_chunks_exact_mut(1 << mid)
266 .for_each(|mut submat| {
267 let mut backwards = false;
268 for layer in 0..mid {
269 let layer_rev = log_h - 1 - layer;
270 let layer_pow = 1 << layer_rev;
271 dit_layer(
272 &mut submat,
273 layer,
274 twiddles.iter().copied().step_by(layer_pow),
275 backwards,
276 );
277 backwards = !backwards;
278 }
279 });
280}
281
282#[instrument(level = "debug", skip_all)]
285fn first_half_general<F: Field>(
286 mat: &mut RowMajorMatrixViewMut<F>,
287 mid: usize,
288 twiddles: &[Vec<F>],
289) {
290 let log_h = log2_strict_usize(mat.height());
291 mat.par_row_chunks_exact_mut(1 << mid)
292 .for_each(|mut submat| {
293 let mut backwards = false;
294 for layer in 0..mid {
295 let layer_rev = log_h - 1 - layer;
296 dit_layer(
297 &mut submat,
298 layer,
299 twiddles[layer_rev].iter().copied(),
300 backwards,
301 );
302 backwards = !backwards;
303 }
304 });
305}
306
/// Out-of-place variant of [`first_half_general`].
///
/// Layer 0 reads from `src` and writes into `dst_maybe`. Layers `1..mid` then run in place on
/// the now-initialized destination. Both matrices are split into matching chunks of `2^mid`
/// rows, which are processed in parallel.
///
/// Layer 0 always runs here, so callers are expected to pass `mid >= 1`.
#[instrument(level = "debug", skip_all)]
fn first_half_general_oop<F: Field>(
    src: &RowMajorMatrixView<F>,
    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
    mid: usize,
    twiddles: &[Vec<F>],
) {
    let log_h = log2_strict_usize(src.height());
    src.par_row_chunks_exact(1 << mid)
        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
        .for_each(|(src_submat, mut dst_submat_maybe)| {
            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());

            // Layer 0 is done out-of-place. It is the layer that initializes this chunk of the
            // destination.
            let layer_rev = log_h - 1;
            dit_layer_oop(
                &src_submat,
                &mut dst_submat_maybe,
                0,
                twiddles[layer_rev].iter().copied(),
            );

            // SAFETY: `dit_layer_oop` has written every row of this chunk, so it can be viewed
            // as initialized `F`s.
            // NOTE(review): the by-value transmute also relies on the two view types having
            // identical layouts — confirm.
            let mut dst_submat = unsafe {
                transmute::<RowMajorMatrixViewMut<MaybeUninit<F>>, RowMajorMatrixViewMut<F>>(
                    dst_submat_maybe,
                )
            };

            // Remaining layers run in place. Layer 0 went forwards, so layer 1 starts
            // backwards, and the direction keeps alternating.
            let mut backwards = true;
            for layer in 1..mid {
                let layer_rev = log_h - 1 - layer;
                dit_layer(
                    &mut dst_submat,
                    layer,
                    twiddles[layer_rev].iter().copied(),
                    backwards,
                );
                backwards = !backwards;
            }
        });
}
355
356#[instrument(level = "debug", skip_all)]
362#[inline(always)] fn second_half<F: Field>(
364 mat: &mut RowMajorMatrix<F>,
365 mid: usize,
366 twiddles_rev: &[F],
367 scale: Option<F>,
368) {
369 let log_h = log2_strict_usize(mat.height());
370
371 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
373 .enumerate()
374 .for_each(|(thread, mut submat)| {
375 let mut backwards = false;
376 if let Some(scale) = scale {
377 submat.scale(scale);
378 }
379 for layer in mid..log_h {
380 let first_block = thread << (layer - mid);
381 dit_layer_rev(
382 &mut submat,
383 log_h,
384 layer,
385 twiddles_rev[first_block..].iter().copied(),
386 backwards,
387 );
388 backwards = !backwards;
389 }
390 });
391}
392
393#[instrument(level = "debug", skip_all)]
396fn second_half_general<F: Field>(
397 mat: &mut RowMajorMatrixViewMut<F>,
398 mid: usize,
399 twiddles_rev: &[Vec<F>],
400) {
401 let log_h = log2_strict_usize(mat.height());
402 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
403 .enumerate()
404 .for_each(|(thread, mut submat)| {
405 let mut backwards = false;
406 for layer in mid..log_h {
407 let layer_rev = log_h - 1 - layer;
408 let first_block = thread << (layer - mid);
409 dit_layer_rev(
410 &mut submat,
411 log_h,
412 layer,
413 twiddles_rev[layer_rev][first_block..].iter().copied(),
414 backwards,
415 );
416 backwards = !backwards;
417 }
418 });
419}
420
421fn dit_layer<F: Field>(
423 submat: &mut RowMajorMatrixViewMut<'_, F>,
424 layer: usize,
425 twiddles: impl Iterator<Item = F> + Clone,
426 backwards: bool,
427) {
428 let half_block_size = 1 << layer;
429 let block_size = half_block_size * 2;
430 let width = submat.width();
431 debug_assert!(submat.height() >= block_size);
432
433 let process_block = |block: &mut [F]| {
434 let (lows, highs) = block.split_at_mut(half_block_size * width);
435
436 for (lo, hi, twiddle) in izip!(
437 lows.chunks_mut(width),
438 highs.chunks_mut(width),
439 twiddles.clone()
440 ) {
441 DitButterfly(twiddle).apply_to_rows(lo, hi);
442 }
443 };
444
445 let blocks = submat.values.chunks_mut(block_size * width);
446 if backwards {
447 for block in blocks.rev() {
448 process_block(block);
449 }
450 } else {
451 for block in blocks {
452 process_block(block);
453 }
454 }
455}
456
457fn dit_layer_oop<F: Field>(
459 src: &RowMajorMatrixView<F>,
460 dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
461 layer: usize,
462 twiddles: impl Iterator<Item = F> + Clone,
463) {
464 debug_assert_eq!(src.dimensions(), dst.dimensions());
465 let half_block_size = 1 << layer;
466 let block_size = half_block_size * 2;
467 let width = dst.width();
468 debug_assert!(dst.height() >= block_size);
469
470 let src_chunks = src.values.chunks(block_size * width);
471 let dst_chunks = dst.values.chunks_mut(block_size * width);
472 for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
473 let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
474 let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
475
476 for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
477 src_lows.chunks(width),
478 dst_lows.chunks_mut(width),
479 src_highs.chunks(width),
480 dst_highs.chunks_mut(width),
481 twiddles.clone()
482 ) {
483 DitButterfly(twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
484 }
485 }
486}
487
488fn dit_layer_rev<F: Field>(
491 submat: &mut RowMajorMatrixViewMut<'_, F>,
492 log_h: usize,
493 layer: usize,
494 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
495 backwards: bool,
496) {
497 let layer_rev = log_h - 1 - layer;
498
499 let half_block_size = 1 << layer_rev;
500 let block_size = half_block_size * 2;
501 let width = submat.width();
502 debug_assert!(submat.height() >= block_size);
503
504 let blocks_and_twiddles = submat
505 .values
506 .chunks_mut(block_size * width)
507 .zip(twiddles_rev);
508 if backwards {
509 for (block, twiddle) in blocks_and_twiddles.rev() {
510 let (lo, hi) = block.split_at_mut(half_block_size * width);
511 DitButterfly(twiddle).apply_to_rows(lo, hi)
512 }
513 } else {
514 for (block, twiddle) in blocks_and_twiddles {
515 let (lo, hi) = block.split_at_mut(half_block_size * width);
516 DitButterfly(twiddle).apply_to_rows(lo, hi)
517 }
518 }
519}