// File: p3_monty_31/dft/forward.rs

1//! Discrete Fourier Transform, in-place, decimation-in-frequency
2//!
3//! Straightforward recursive algorithm, "unrolled" up to size 256.
4//!
5//! Inspired by Bernstein's djbfft: https://cr.yp.to/djbfft.html
6
7extern crate alloc;
8
9use alloc::vec::Vec;
10
11use itertools::izip;
12use p3_field::{Field, FieldAlgebra, PackedFieldPow2, PackedValue, TwoAdicField};
13use p3_util::log2_strict_usize;
14
15use crate::utils::monty_reduce;
16use crate::{FieldParameters, MontyField31, TwoAdicData};
17
18impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
19    /// Given a field element `gen` of order n where `n = 2^lg_n`,
20    /// return a vector of vectors `table` where table[i] is the
21    /// vector of twiddle factors for an fft of length n/2^i. The
22    /// values g_i^k for k >= i/2 are skipped as these are just the
23    /// negatives of the other roots (using g_i^{i/2} = -1).  The
24    /// value gen^0 = 1 is included to aid consistency between the
25    /// packed and non-packed variants.
26    pub fn roots_of_unity_table(n: usize) -> Vec<Vec<Self>> {
27        let lg_n = log2_strict_usize(n);
28        let gen = Self::two_adic_generator(lg_n);
29        let half_n = 1 << (lg_n - 1);
30        // nth_roots = [1, g, g^2, g^3, ..., g^{n/2 - 1}]
31        let nth_roots: Vec<_> = gen.powers().take(half_n).collect();
32
33        (0..(lg_n - 1))
34            .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
35            .collect()
36    }
37}
38
39#[inline(always)]
40fn forward_butterfly<T: FieldAlgebra + Copy>(x: T, y: T, roots: T) -> (T, T) {
41    let t = x - y;
42    (x + y, t * roots)
43}
44
45#[inline(always)]
46fn forward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
47    x: T,
48    y: T,
49    roots: T,
50) -> (T, T) {
51    let (x, y) = x.interleave(y, HALF_RADIX);
52    let (x, y) = forward_butterfly(x, y, roots);
53    x.interleave(y, HALF_RADIX)
54}
55
56#[inline]
57fn forward_pass_packed<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
58    let packed_roots = T::pack_slice(roots);
59    let n = input.len();
60    let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
61
62    izip!(xs, ys, packed_roots)
63        .for_each(|(x, y, &roots)| (*x, *y) = forward_butterfly(*x, *y, roots));
64}
65
66#[inline]
67fn forward_iterative_layer_1<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
68    let packed_roots = T::pack_slice(roots);
69    let n = input.len();
70    let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
71    let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
72    let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
73
74    izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
75        (*x, *y) = forward_butterfly(*x, *y, root);
76        (*z, *w) = forward_butterfly(*z, *w, root);
77    });
78}
79
80#[inline]
81fn forward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
82    input: &mut [T],
83    roots: &[T::Scalar],
84) {
85    // roots[0] == 1
86    // roots <-- [1, roots[1], ..., roots[HALF_RADIX-1], 1, roots[1], ...]
87    let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
88
89    input.chunks_exact_mut(2).for_each(|pair| {
90        let (x, y) = forward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
91        pair[0] = x;
92        pair[1] = y;
93    });
94}
95
96#[inline]
97fn forward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
98    input.chunks_exact_mut(2).for_each(|pair| {
99        let x = pair[0];
100        let y = pair[1];
101        let (mut x, y) = x.interleave(y, 1);
102        let t = x - y; // roots[0] == 1
103        x += y;
104        let (x, y) = x.interleave(t, 1);
105        pair[0] = x;
106        pair[1] = y;
107    });
108}
109
impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
    /// Apply one DIF layer with half-block size `m` (in scalars) to `packed_input`.
    ///
    /// `roots` must hold exactly the `m` twiddles for this layer. Each
    /// chunk of `2 * m` scalars is split in half and combined with a
    /// packed butterfly. Assumes the packing width divides `m` (the
    /// callers only use this with m >= packing width; see the comment
    /// below).
    #[inline]
    fn forward_iterative_layer(
        packed_input: &mut [<Self as Field>::Packing],
        roots: &[Self],
        m: usize,
    ) {
        debug_assert_eq!(roots.len(), m);
        let packed_roots = <Self as Field>::Packing::pack_slice(roots);

        // lg_m >= 4, so m = 2^lg_m >= 2^4, hence packing_width divides m
        let packed_m = m / <Self as Field>::Packing::WIDTH;
        packed_input
            .chunks_exact_mut(2 * packed_m)
            .for_each(|layer_chunk| {
                // SAFETY: packed_m <= layer_chunk.len(), so the split is in bounds.
                let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };

                izip!(xs, ys, packed_roots)
                    .for_each(|(x, y, &root)| (*x, *y) = forward_butterfly(*x, *y, root));
            });
    }

    /// Apply the last four DIF layers (radices 16, 8, 4 and 2) to `input`.
    ///
    /// Each radix uses the lane-interleaving variant when the packing
    /// width is large enough to hold a whole block, and falls back to
    /// the generic layer otherwise.
    #[inline]
    fn forward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
        // Rather surprisingly, a version similar where the separate
        // loops in each call to forward_iterative_packed() are
        // combined into one, was not only not faster, but was
        // actually a bit slower.

        // Radix 16
        if <Self as Field>::Packing::WIDTH >= 16 {
            forward_iterative_packed::<8, _>(input, MP::ROOTS_16.as_ref());
        } else {
            Self::forward_iterative_layer(input, MP::ROOTS_16.as_ref(), 8);
        }

        // Radix 8
        if <Self as Field>::Packing::WIDTH >= 8 {
            forward_iterative_packed::<4, _>(input, MP::ROOTS_8.as_ref());
        } else {
            Self::forward_iterative_layer(input, MP::ROOTS_8.as_ref(), 4);
        }

        // Radix 4: the twiddles are every other 8th root, i.e. [1, g^2].
        let roots4 = [MP::ROOTS_8.as_ref()[0], MP::ROOTS_8.as_ref()[2]];
        if <Self as Field>::Packing::WIDTH >= 4 {
            forward_iterative_packed::<2, _>(input, &roots4);
        } else {
            Self::forward_iterative_layer(input, &roots4, 2);
        }

        // Radix 2
        forward_iterative_packed_radix_2(input);
    }

    /// Breadth-first DIF FFT for smallish vectors (must be >= 64)
    ///
    /// "64" is in scalars: the assert below requires
    /// lg_n >= LAST_LOOP_LAYER + NUM_SPECIALISATIONS = 6, i.e. n >= 64.
    #[inline]
    fn forward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
        assert!(packed_input.len() >= 2);
        let packing_width = <Self as Field>::Packing::WIDTH;
        let n = packed_input.len() * packing_width;
        let lg_n = log2_strict_usize(n);

        // Stop loop early to do radix 16 separately. This value is determined by the largest
        // packing width we will encounter, which is 16 at the moment for AVX512. Specifically
        // it is log_2(max{possible packing widths}) = lg(16) = 4.
        const LAST_LOOP_LAYER: usize = 4;

        // How many layers have we specialised before the main loop
        const NUM_SPECIALISATIONS: usize = 2;

        // Needed to avoid overlap of the 2 specialisations at the start
        // with the radix-16 specialisation at the end of the loop
        assert!(lg_n >= LAST_LOOP_LAYER + NUM_SPECIALISATIONS);

        // Specialise the first NUM_SPECIALISATIONS iterations; improves performance a little.
        forward_pass_packed(packed_input, &root_table[0]); // lg_m == lg_n - 1, s == 0
        forward_iterative_layer_1(packed_input, &root_table[1]); // lg_m == lg_n - 2, s == 1

        // loop from lg_n-2 down to 4.
        for lg_m in (LAST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS)).rev() {
            let s = lg_n - lg_m - 1;
            let m = 1 << lg_m;

            let roots = &root_table[s];
            debug_assert_eq!(roots.len(), m);

            Self::forward_iterative_layer(packed_input, roots, m);
        }

        // Last 4 layers
        Self::forward_iterative_packed_radix_16(packed_input);
    }

    /// Scalar DIF butterfly: returns `(x + y, (x - y) * w)` with the
    /// product done via a single Montgomery reduction.
    ///
    /// Adding `MP::PRIME` before subtracting keeps `t` non-negative
    /// (assumes `x.value`, `y.value` < PRIME — the Monty representation
    /// invariant; TODO confirm against `new_monty`'s contract).
    #[inline(always)]
    fn forward_butterfly(x: Self, y: Self, w: Self) -> (Self, Self) {
        let t = MP::PRIME + x.value - y.value;
        (
            x + y,
            Self::new_monty(monty_reduce::<MP>(t as u64 * w.value as u64)),
        )
    }

    /// Apply one scalar DIF layer pairing the two halves of `input`.
    ///
    /// The k = 0 butterfly is special-cased because roots[0] == 1, so
    /// no multiplication (and no Montgomery reduction) is needed there.
    #[inline]
    fn forward_pass(input: &mut [Self], roots: &[Self]) {
        let half_n = input.len() / 2;
        assert_eq!(roots.len(), half_n);

        // SAFETY: half_n <= input.len(), so the split point is in bounds.
        let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };

        let s = xs[0] + ys[0];
        let t = xs[0] - ys[0];
        xs[0] = s;
        ys[0] = t;

        izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
            (*x, *y) = Self::forward_butterfly(*x, *y, root);
        });
    }

    /// Length-2 DFT: (a0, a1) -> (a0 + a1, a0 - a1).
    #[inline(always)]
    fn forward_2(a: &mut [Self]) {
        assert_eq!(a.len(), 2);

        let s = a[0] + a[1];
        let t = a[0] - a[1];
        a[0] = s;
        a[1] = t;
    }

    /// Length-4 DFT, fully unrolled. The only non-trivial twiddle is
    /// ROOTS_8[2] (the primitive 4th root of unity, per
    /// `roots_of_unity_table`'s layout).
    #[inline(always)]
    fn forward_4(a: &mut [Self]) {
        assert_eq!(a.len(), 4);

        // Expanding the calculation of t3 saves one instruction
        let t1 = MP::PRIME + a[1].value - a[3].value;
        let t3 = MontyField31::new_monty(monty_reduce::<MP>(
            t1 as u64 * MP::ROOTS_8.as_ref()[2].value as u64,
        ));
        let t5 = a[1] + a[3];
        let t4 = a[0] + a[2];
        let t2 = a[0] - a[2];

        // Return in bit-reversed order
        a[0] = t4 + t5;
        a[1] = t4 - t5;
        a[2] = t2 + t3;
        a[3] = t2 - t3;
    }

    /// Length-8 DFT: one pass with the 8th roots, then two length-4 DFTs.
    #[inline(always)]
    fn forward_8(a: &mut [Self]) {
        assert_eq!(a.len(), 8);

        Self::forward_pass(a, MP::ROOTS_8.as_ref());

        // Safe because a.len() == 8
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::forward_4(a0);
        Self::forward_4(a1);
    }

    /// Length-16 DFT: one pass with the 16th roots, then two length-8 DFTs.
    #[inline(always)]
    fn forward_16(a: &mut [Self]) {
        assert_eq!(a.len(), 16);

        Self::forward_pass(a, MP::ROOTS_16.as_ref());

        // Safe because a.len() == 16
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::forward_8(a0);
        Self::forward_8(a1);
    }

    /// Length-32 DFT: one pass with roots from `root_table`, then two
    /// length-16 DFTs.
    #[inline(always)]
    fn forward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
        assert_eq!(a.len(), 32);

        Self::forward_pass(a, &root_table[0]);

        // Safe because a.len() == 32
        let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
        Self::forward_16(a0);
        Self::forward_16(a1);
    }

    /// Assumes `input.len() >= 64`.
    ///
    /// Depth-first recursion: one packed pass splits the problem in two,
    /// recursing with the next layer's roots (`root_table[1..]`), until
    /// the size drops to ITERATIVE_FFT_THRESHOLD scalars, at which point
    /// the breadth-first iterative version takes over.
    #[inline]
    fn forward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
        const ITERATIVE_FFT_THRESHOLD: usize = 1024;

        let n = input.len() * <Self as Field>::Packing::WIDTH;
        if n <= ITERATIVE_FFT_THRESHOLD {
            Self::forward_iterative(input, root_table);
        } else {
            assert_eq!(n, 1 << (root_table.len() + 1));
            forward_pass_packed(input, &root_table[0]);

            // SAFETY: input.len() / 2 <= input.len(), so the split is in bounds.
            let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };

            Self::forward_fft_recur(a0, &root_table[1..]);
            Self::forward_fft_recur(a1, &root_table[1..]);
        }
    }

    /// In-place DIF FFT of `input`, whose length must be a power of two
    /// matching `root_table` (n == 2^(root_table.len() + 1), except the
    /// trivial n == 1 case). Output ordering follows the DIF convention
    /// (bit-reversed; see the comment in `forward_4`).
    ///
    /// Small sizes (<= 32) dispatch to the unrolled scalar kernels; all
    /// larger sizes go through the packed recursive implementation.
    #[inline]
    pub fn forward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
        let n = input.len();
        if n == 1 {
            return;
        }
        assert_eq!(n, 1 << (root_table.len() + 1));
        match n {
            32 => Self::forward_32(input, root_table),
            16 => Self::forward_16(input),
            8 => Self::forward_8(input),
            4 => Self::forward_4(input),
            2 => Self::forward_2(input),
            _ => {
                let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
                Self::forward_fft_recur(packed_input, root_table)
            }
        }
    }
}