halo2curves_axiom/
fft.rs

1pub use crate::{CurveAffine, CurveExt};
2use ff::Field;
3use group::{GroupOpsOwned, ScalarMulOwned};
4
5/// This represents an element of a group with basic operations that can be
6/// performed. This allows an FFT implementation (for example) to operate
7/// generically over either a field or elliptic curve group.
8pub trait FftGroup<Scalar: Field>:
9    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
10{
11}
12
13impl<T, Scalar> FftGroup<Scalar> for T
14where
15    Scalar: Field,
16    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
17{
18}
19
20/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
21/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
22/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
23/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
24/// transformed into the evaluations of this polynomial at each of the $n$
25/// distinct powers of $\omega$. This transformation is invertible by providing
26/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
27/// by $n$.
28///
29/// This will use multithreading if beneficial.
30pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
31    fn bitreverse(mut n: usize, l: usize) -> usize {
32        let mut r = 0;
33        for _ in 0..l {
34            r = (r << 1) | (n & 1);
35            n >>= 1;
36        }
37        r
38    }
39
40    let threads = rayon::current_num_threads();
41    let log_threads = threads.ilog2();
42    let n = a.len();
43    assert_eq!(n, 1 << log_n);
44
45    for k in 0..n {
46        let rk = bitreverse(k, log_n as usize);
47        if k < rk {
48            a.swap(rk, k);
49        }
50    }
51
52    // precompute twiddle factors
53    let twiddles: Vec<_> = (0..(n / 2))
54        .scan(Scalar::ONE, |w, _| {
55            let tw = *w;
56            *w *= &omega;
57            Some(tw)
58        })
59        .collect();
60
61    if log_n <= log_threads {
62        let mut chunk = 2_usize;
63        let mut twiddle_chunk = n / 2;
64        for _ in 0..log_n {
65            a.chunks_mut(chunk).for_each(|coeffs| {
66                let (left, right) = coeffs.split_at_mut(chunk / 2);
67
68                // case when twiddle factor is one
69                let (a, left) = left.split_at_mut(1);
70                let (b, right) = right.split_at_mut(1);
71                let t = b[0];
72                b[0] = a[0];
73                a[0] += &t;
74                b[0] -= &t;
75
76                left.iter_mut()
77                    .zip(right.iter_mut())
78                    .enumerate()
79                    .for_each(|(i, (a, b))| {
80                        let mut t = *b;
81                        t *= &twiddles[(i + 1) * twiddle_chunk];
82                        *b = *a;
83                        *a += &t;
84                        *b -= &t;
85                    });
86            });
87            chunk *= 2;
88            twiddle_chunk /= 2;
89        }
90    } else {
91        recursive_butterfly_arithmetic(a, n, 1, &twiddles)
92    }
93}
94
95/// This perform recursive butterfly arithmetic
96pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
97    a: &mut [G],
98    n: usize,
99    twiddle_chunk: usize,
100    twiddles: &[Scalar],
101) {
102    if n == 2 {
103        let t = a[1];
104        a[1] = a[0];
105        a[0] += &t;
106        a[1] -= &t;
107    } else {
108        let (left, right) = a.split_at_mut(n / 2);
109        rayon::join(
110            || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
111            || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
112        );
113
114        // case when twiddle factor is one
115        let (a, left) = left.split_at_mut(1);
116        let (b, right) = right.split_at_mut(1);
117        let t = b[0];
118        b[0] = a[0];
119        a[0] += &t;
120        b[0] -= &t;
121
122        left.iter_mut()
123            .zip(right.iter_mut())
124            .enumerate()
125            .for_each(|(i, (a, b))| {
126                let mut t = *b;
127                t *= &twiddles[(i + 1) * twiddle_chunk];
128                *b = *a;
129                *a += &t;
130                *b -= &t;
131            });
132    }
133}