// p3_monty_31/dft/backward.rs

//! Discrete Fourier Transform, in-place, decimation-in-time
//!
//! Straightforward recursive algorithm, "unrolled" up to size 256.
//!
//! Inspired by Bernstein's djbfft: https://cr.yp.to/djbfft.html

extern crate alloc;
use alloc::vec::Vec;

use itertools::izip;
use p3_field::{Field, FieldAlgebra, PackedFieldPow2, PackedValue};
use p3_util::log2_strict_usize;

use crate::utils::monty_reduce;
use crate::{FieldParameters, MontyField31, TwoAdicData};

17#[inline(always)]
18fn backward_butterfly<T: FieldAlgebra + Copy>(x: T, y: T, roots: T) -> (T, T) {
19    let t = y * roots;
20    (x + t, x - t)
21}
22
23#[inline(always)]
24fn backward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
25    x: T,
26    y: T,
27    roots: T,
28) -> (T, T) {
29    let (x, y) = x.interleave(y, HALF_RADIX);
30    let (x, y) = backward_butterfly(x, y, roots);
31    x.interleave(y, HALF_RADIX)
32}
33
34#[inline]
35fn backward_pass_packed<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
36    let packed_roots = T::pack_slice(roots);
37    let n = input.len();
38    let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
39
40    izip!(xs, ys, packed_roots)
41        .for_each(|(x, y, &roots)| (*x, *y) = backward_butterfly(*x, *y, roots));
42}
43
44#[inline]
45fn backward_iterative_layer_1<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
46    let packed_roots = T::pack_slice(roots);
47    let n = input.len();
48    let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
49    let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
50    let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
51
52    izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
53        (*x, *y) = backward_butterfly(*x, *y, root);
54        (*z, *w) = backward_butterfly(*z, *w, root);
55    });
56}
57
58#[inline]
59fn backward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
60    input: &mut [T],
61    roots: &[T::Scalar],
62) {
63    // roots[0] == 1
64    // roots <-- [1, roots[1], ..., roots[HALF_RADIX-1], 1, roots[1], ...]
65    let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
66
67    input.chunks_exact_mut(2).for_each(|pair| {
68        let (x, y) = backward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
69        pair[0] = x;
70        pair[1] = y;
71    });
72}
73
74#[inline]
75fn backward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
76    input.chunks_exact_mut(2).for_each(|pair| {
77        let x = pair[0];
78        let y = pair[1];
79        let (mut x, y) = x.interleave(y, 1);
80        let t = x - y; // roots[0] == 1
81        x += y;
82        let (x, y) = x.interleave(t, 1);
83        pair[0] = x;
84        pair[1] = y;
85    });
86}
87
impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
    /// Breadth-first DIT FFT for smallish vectors (must be >= 64)
    ///
    /// Applies one butterfly layer: within each chunk of `2 * m` scalar
    /// elements, combines the first half with the second half using the
    /// `m` twiddle factors in `roots`.
    #[inline]
    fn backward_iterative_layer(
        packed_input: &mut [<Self as Field>::Packing],
        roots: &[Self],
        m: usize,
    ) {
        debug_assert_eq!(roots.len(), m);
        let packed_roots = <Self as Field>::Packing::pack_slice(roots);

        // lg_m >= 4, so m = 2^lg_m >= 2^4, hence packing_width divides m
        let packed_m = m / <Self as Field>::Packing::WIDTH;
        packed_input
            .chunks_exact_mut(2 * packed_m)
            .for_each(|layer_chunk| {
                // SAFETY: packed_m is exactly half the chunk length, so the
                // split point is in bounds.
                let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };

                izip!(xs, ys, packed_roots)
                    .for_each(|(x, y, &root)| (*x, *y) = backward_butterfly(*x, *y, root));
            });
    }

    /// Perform the first four butterfly layers (an overall radix-16 step),
    /// using the interleaved packed specialisations whenever the packing
    /// width is large enough, and the generic layer otherwise.
    #[inline]
    fn backward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
        // Rather surprisingly, a version similar where the separate
        // loops in each call to backward_iterative_packed() are
        // combined into one, was not only not faster, but was
        // actually a bit slower.

        // Radix 2
        backward_iterative_packed_radix_2(input);

        // Radix 4: the 4th roots are every other entry of the 8th roots.
        let roots4 = [MP::INV_ROOTS_8.as_ref()[0], MP::INV_ROOTS_8.as_ref()[2]];
        if <Self as Field>::Packing::WIDTH >= 4 {
            backward_iterative_packed::<2, _>(input, &roots4);
        } else {
            Self::backward_iterative_layer(input, &roots4, 2);
        }

        // Radix 8
        if <Self as Field>::Packing::WIDTH >= 8 {
            backward_iterative_packed::<4, _>(input, MP::INV_ROOTS_8.as_ref());
        } else {
            Self::backward_iterative_layer(input, MP::INV_ROOTS_8.as_ref(), 4);
        }

        // Radix 16
        if <Self as Field>::Packing::WIDTH >= 16 {
            backward_iterative_packed::<8, _>(input, MP::INV_ROOTS_16.as_ref());
        } else {
            Self::backward_iterative_layer(input, MP::INV_ROOTS_16.as_ref(), 8);
        }
    }

    /// Iterative (breadth-first) backward FFT over packed vectors.
    ///
    /// Runs a specialised radix-16 pass, then one generic layer per
    /// remaining level, with the last two layers specialised again.
    /// Requires the total scalar length to be at least 2^6 = 64 (asserted).
    fn backward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
        assert!(packed_input.len() >= 2);
        let packing_width = <Self as Field>::Packing::WIDTH;
        let n = packed_input.len() * packing_width;
        let lg_n = log2_strict_usize(n);

        // Start loop after doing radix 16 separately. This value is determined by the largest
        // packing width we will encounter, which is 16 at the moment for AVX512. Specifically
        // it is log_2(max{possible packing widths}) = lg(16) = 4.
        const FIRST_LOOP_LAYER: usize = 4;

        // How many layers have we specialised after the main loop
        const NUM_SPECIALISATIONS: usize = 2;

        // Needed to avoid overlap of the 2 specialisations at the start
        // with the radix-16 specialisation at the end of the loop
        assert!(lg_n >= FIRST_LOOP_LAYER + NUM_SPECIALISATIONS);

        Self::backward_iterative_packed_radix_16(packed_input);

        for lg_m in FIRST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS) {
            // s indexes root_table from the largest layer downwards.
            let s = lg_n - lg_m - 1;
            let m = 1 << lg_m;

            let roots = &root_table[s];
            debug_assert_eq!(roots.len(), m);

            Self::backward_iterative_layer(packed_input, roots, m);
        }
        // Specialise the last few iterations; improves performance a little.
        backward_iterative_layer_1(packed_input, &root_table[1]); // lg_m == lg_n - 2, s == 1
        backward_pass_packed(packed_input, &root_table[0]); // lg_m == lg_n - 1, s == 0
    }

    /// One scalar butterfly layer across the whole input: combines the
    /// first half with the second half using `roots` (length input.len()/2).
    #[inline]
    fn backward_pass(input: &mut [Self], roots: &[Self]) {
        let half_n = input.len() / 2;
        assert_eq!(roots.len(), half_n);

        // Safe because 0 <= half_n <= input.len()
        let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };

        // The first root is 1, so the first butterfly needs no multiply.
        let s = xs[0] + ys[0];
        let t = xs[0] - ys[0];
        xs[0] = s;
        ys[0] = t;

        izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
            (*x, *y) = backward_butterfly(*x, *y, root);
        });
    }

    /// Length-2 backward FFT: a single butterfly with twiddle 1.
    #[inline(always)]
    fn backward_2(a: &mut [Self]) {
        assert_eq!(a.len(), 2);

        let s = a[0] + a[1];
        let t = a[0] - a[1];
        a[0] = s;
        a[1] = t;
    }

    /// Length-4 backward FFT, fully unrolled.
    #[inline(always)]
    fn backward_4(a: &mut [Self]) {
        assert_eq!(a.len(), 4);

        // Read in bit-reversed order
        let a0 = a[0];
        let a2 = a[1];
        let a1 = a[2];
        let a3 = a[3];

        // Expanding the calculation of t3 saves one instruction
        // (adding PRIME keeps the subtraction non-negative, since both
        // values are < PRIME, so the raw u32 difference is valid input to
        // monty_reduce).
        let t1 = MP::PRIME + a1.value - a3.value;
        let t3 = MontyField31::new_monty(monty_reduce::<MP>(
            t1 as u64 * MP::INV_ROOTS_8.as_ref()[2].value as u64,
        ));
        let t5 = a1 + a3;
        let t4 = a0 + a2;
        let t2 = a0 - a2;

        a[0] = t4 + t5;
        a[1] = t2 + t3;
        a[2] = t4 - t5;
        a[3] = t2 - t3;
    }

    /// Length-8 backward FFT: two length-4 FFTs plus one combining pass.
    #[inline(always)]
    fn backward_8(a: &mut [Self]) {
        assert_eq!(a.len(), 8);

        // Safe because a.len() == 8
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::backward_4(a0);
        Self::backward_4(a1);

        Self::backward_pass(a, MP::INV_ROOTS_8.as_ref());
    }

    /// Length-16 backward FFT: two length-8 FFTs plus one combining pass.
    #[inline(always)]
    fn backward_16(a: &mut [Self]) {
        assert_eq!(a.len(), 16);

        // Safe because a.len() == 16
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::backward_8(a0);
        Self::backward_8(a1);

        Self::backward_pass(a, MP::INV_ROOTS_16.as_ref());
    }

    /// Length-32 backward FFT: two length-16 FFTs plus one combining pass
    /// using the first (largest) entry of `root_table`.
    #[inline(always)]
    fn backward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
        assert_eq!(a.len(), 32);

        // Safe because a.len() == 32
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::backward_16(a0);
        Self::backward_16(a1);

        Self::backward_pass(a, &root_table[0]);
    }

    /// Recursive backward FFT driver over packed vectors.
    ///
    /// Assumes the scalar length is >= 64 so the iterative base case's
    /// layer specialisations hold for all current packing widths.
    /// Small inputs (<= ITERATIVE_FFT_THRESHOLD scalars) use the iterative
    /// code path; larger ones recurse on the two halves and combine.
    #[inline]
    fn backward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
        const ITERATIVE_FFT_THRESHOLD: usize = 1024;

        let n = input.len() * <Self as Field>::Packing::WIDTH;
        if n <= ITERATIVE_FFT_THRESHOLD {
            Self::backward_iterative(input, root_table);
        } else {
            // root_table must match the transform size exactly.
            assert_eq!(n, 1 << (root_table.len() + 1));

            // Safe because input.len() > ITERATIVE_FFT_THRESHOLD / WIDTH >= 1,
            // so the midpoint is in bounds.
            let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };
            // Each half uses the root tables one level down.
            Self::backward_fft_recur(a0, &root_table[1..]);
            Self::backward_fft_recur(a1, &root_table[1..]);

            backward_pass_packed(input, &root_table[0]);
        }
    }

    /// Public entry point: in-place backward FFT of `input`.
    ///
    /// `input.len()` must be 1 or equal 2^(root_table.len() + 1)
    /// (asserted). Sizes up to 32 dispatch to the unrolled scalar
    /// routines; larger sizes are packed and handled recursively.
    #[inline]
    pub fn backward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
        let n = input.len();
        if n == 1 {
            // Length-1 transform is the identity.
            return;
        }

        assert_eq!(n, 1 << (root_table.len() + 1));
        match n {
            32 => Self::backward_32(input, root_table),
            16 => Self::backward_16(input),
            8 => Self::backward_8(input),
            4 => Self::backward_4(input),
            2 => Self::backward_2(input),
            _ => {
                let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
                Self::backward_fft_recur(packed_input, root_table)
            }
        }
    }
}