// halo2_axiom/fft/baseline.rs

//! This contains the baseline FFT implementation.

3use ff::Field;
4
5use super::recursive::FFTData;
6use crate::{
7    arithmetic::{self, log2_floor, FftGroup},
8    multicore,
9};
10
11/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
12/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
13/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
14/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
15/// transformed into the evaluations of this polynomial at each of the $n$
16/// distinct powers of $\omega$. This transformation is invertible by providing
17/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
18/// by $n$.
19///
20/// This will use multithreading if beneficial.
21fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
22    let threads = multicore::current_num_threads();
23    let log_threads = log2_floor(threads);
24    let n = a.len();
25    assert_eq!(n, 1 << log_n);
26
27    for k in 0..n {
28        let rk = arithmetic::bitreverse(k, log_n as usize);
29        if k < rk {
30            a.swap(rk, k);
31        }
32    }
33
34    //let start = start_measure(format!("twiddles {} ({})", a.len(), threads), false);
35    // precompute twiddle factors
36    let twiddles: Vec<_> = (0..(n / 2))
37        .scan(Scalar::ONE, |w, _| {
38            let tw = *w;
39            *w *= &omega;
40            Some(tw)
41        })
42        .collect();
43    //stop_measure(start);
44
45    if log_n <= log_threads {
46        let mut chunk = 2_usize;
47        let mut twiddle_chunk = n / 2;
48        for _ in 0..log_n {
49            a.chunks_mut(chunk).for_each(|coeffs| {
50                let (left, right) = coeffs.split_at_mut(chunk / 2);
51
52                // case when twiddle factor is one
53                let (a, left) = left.split_at_mut(1);
54                let (b, right) = right.split_at_mut(1);
55                let t = b[0];
56                b[0] = a[0];
57                a[0] += &t;
58                b[0] -= &t;
59
60                left.iter_mut()
61                    .zip(right.iter_mut())
62                    .enumerate()
63                    .for_each(|(i, (a, b))| {
64                        let mut t = *b;
65                        t *= &twiddles[(i + 1) * twiddle_chunk];
66                        *b = *a;
67                        *a += &t;
68                        *b -= &t;
69                    });
70            });
71            chunk *= 2;
72            twiddle_chunk /= 2;
73        }
74    } else {
75        recursive_butterfly_arithmetic(a, n, 1, &twiddles)
76    }
77}
78
79/// This perform recursive butterfly arithmetic
80fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
81    a: &mut [G],
82    n: usize,
83    twiddle_chunk: usize,
84    twiddles: &[Scalar],
85) {
86    if n == 2 {
87        let t = a[1];
88        a[1] = a[0];
89        a[0] += &t;
90        a[1] -= &t;
91    } else {
92        let (left, right) = a.split_at_mut(n / 2);
93        multicore::join(
94            || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
95            || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
96        );
97
98        // case when twiddle factor is one
99        let (a, left) = left.split_at_mut(1);
100        let (b, right) = right.split_at_mut(1);
101        let t = b[0];
102        b[0] = a[0];
103        a[0] += &t;
104        b[0] -= &t;
105
106        left.iter_mut()
107            .zip(right.iter_mut())
108            .enumerate()
109            .for_each(|(i, (a, b))| {
110                let mut t = *b;
111                t *= &twiddles[(i + 1) * twiddle_chunk];
112                *b = *a;
113                *a += &t;
114                *b -= &t;
115            });
116    }
117}
118
119/// Generic adaptor
120pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
121    a: &mut [G],
122    omega: Scalar,
123    log_n: u32,
124    _data: &FFTData<Scalar>,
125    _inverse: bool,
126) {
127    best_fft(a, omega, log_n)
128}