// halo2curves/fft.rs

use ff::Field;
use group::{GroupOpsOwned, ScalarMulOwned};

pub use crate::{CurveAffine, CurveExt};

6/// This represents an element of a group with basic operations that can be
7/// performed. This allows an FFT implementation (for example) to operate
8/// generically over either a field or elliptic curve group.
9pub trait FftGroup<Scalar: Field>:
10    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
11{
12}
13
14impl<T, Scalar> FftGroup<Scalar> for T
15where
16    Scalar: Field,
17    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
18{
19}
20
21/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
22/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
23/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
24/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
25/// transformed into the evaluations of this polynomial at each of the $n$
26/// distinct powers of $\omega$. This transformation is invertible by providing
27/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
28/// by $n$.
29///
30/// This will use multithreading if beneficial.
31pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
32    fn bitreverse(mut n: usize, l: usize) -> usize {
33        let mut r = 0;
34        for _ in 0..l {
35            r = (r << 1) | (n & 1);
36            n >>= 1;
37        }
38        r
39    }
40
41    let threads = rayon::current_num_threads();
42    let log_threads = threads.ilog2();
43    let n = a.len();
44    assert_eq!(n, 1 << log_n);
45
46    for k in 0..n {
47        let rk = bitreverse(k, log_n as usize);
48        if k < rk {
49            a.swap(rk, k);
50        }
51    }
52
53    // precompute twiddle factors
54    let twiddles: Vec<_> = (0..(n / 2))
55        .scan(Scalar::ONE, |w, _| {
56            let tw = *w;
57            *w *= &omega;
58            Some(tw)
59        })
60        .collect();
61
62    if log_n <= log_threads {
63        let mut chunk = 2_usize;
64        let mut twiddle_chunk = n / 2;
65        for _ in 0..log_n {
66            a.chunks_mut(chunk).for_each(|coeffs| {
67                let (left, right) = coeffs.split_at_mut(chunk / 2);
68
69                // case when twiddle factor is one
70                let (a, left) = left.split_at_mut(1);
71                let (b, right) = right.split_at_mut(1);
72                let t = b[0];
73                b[0] = a[0];
74                a[0] += &t;
75                b[0] -= &t;
76
77                left.iter_mut()
78                    .zip(right.iter_mut())
79                    .enumerate()
80                    .for_each(|(i, (a, b))| {
81                        let mut t = *b;
82                        t *= &twiddles[(i + 1) * twiddle_chunk];
83                        *b = *a;
84                        *a += &t;
85                        *b -= &t;
86                    });
87            });
88            chunk *= 2;
89            twiddle_chunk /= 2;
90        }
91    } else {
92        recursive_butterfly_arithmetic(a, n, 1, &twiddles)
93    }
94}
95
96/// This perform recursive butterfly arithmetic
97pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
98    a: &mut [G],
99    n: usize,
100    twiddle_chunk: usize,
101    twiddles: &[Scalar],
102) {
103    if n == 2 {
104        let t = a[1];
105        a[1] = a[0];
106        a[0] += &t;
107        a[1] -= &t;
108    } else {
109        let (left, right) = a.split_at_mut(n / 2);
110        rayon::join(
111            || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
112            || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
113        );
114
115        // case when twiddle factor is one
116        let (a, left) = left.split_at_mut(1);
117        let (b, right) = right.split_at_mut(1);
118        let t = b[0];
119        b[0] = a[0];
120        a[0] += &t;
121        b[0] -= &t;
122
123        left.iter_mut()
124            .zip(right.iter_mut())
125            .enumerate()
126            .for_each(|(i, (a, b))| {
127                let mut t = *b;
128                t *= &twiddles[(i + 1) * twiddle_chunk];
129                *b = *a;
130                *a += &t;
131                *b -= &t;
132            });
133    }
134}