halo2_axiom/
fft.rs

//! This module dispatches between the different FFT implementations. The dispatcher `fft` selects the implementation at compile time based on the target architecture (`cfg(target_arch)`).
2
3use ff::Field;
4
5use self::recursive::FFTData;
6use crate::arithmetic::FftGroup;
7
8pub mod baseline;
9pub mod parallel;
10pub mod recursive;
11
12/// Runtime dispatcher to concrete FFT implementation
13pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
14    a: &mut [G],
15    omega: Scalar,
16    log_n: u32,
17    data: &FFTData<Scalar>,
18    inverse: bool,
19) {
20    // Empirically, the parallel implementation requires less memory bandwidth, which is more performant on x86_64.
21    #[cfg(target_arch = "x86_64")]
22    parallel::fft(a, omega, log_n, data, inverse);
23    #[cfg(not(target_arch = "x86_64"))]
24    recursive::fft(a, omega, log_n, data, inverse)
25}
26
#[cfg(test)]
mod tests {
    use ark_std::{end_timer, start_timer};
    use ff::Field;
    use halo2curves::bn256::Fr as Scalar;
    use rand_core::OsRng;

    use crate::{arithmetic::best_fft, fft, multicore, poly::EvaluationDomain};

    /// Returns `n` independently sampled random scalars.
    ///
    /// `vec![Scalar::random(OsRng); n]` must not be used for this: the `vec!`
    /// repeat form evaluates its element expression once and clones the result,
    /// so every entry would be the same value and the FFT tests would only ever
    /// see a constant (degenerate) input vector.
    fn random_scalars(n: usize) -> Vec<Scalar> {
        (0..n).map(|_| Scalar::random(OsRng)).collect()
    }

    /// Runs the baseline, parallel and recursive forward FFTs on the same random
    /// input and checks that all three produce identical output.
    #[test]
    fn test_fft_recursive() {
        let k = 22;

        let domain = EvaluationDomain::<Scalar>::new(1, k);
        let n = domain.get_n() as usize;

        let input = random_scalars(n);

        let num_threads = multicore::current_num_threads();

        // Every working buffer is a copy of `input`, so its length is `n`; that
        // value is used both for the timer labels and for `get_fft_data`.
        let mut a = input.clone();
        let start = start_timer!(|| format!("baseline fft {} ({})", n, num_threads));
        fft::baseline::fft(
            &mut a,
            domain.get_omega(),
            k,
            domain.get_fft_data(n),
            false,
        );
        end_timer!(start);

        let mut c = input.clone();
        let start = start_timer!(|| format!("parallel fft {} ({})", n, num_threads));
        fft::parallel::fft(
            &mut c,
            domain.get_omega(),
            k,
            domain.get_fft_data(n),
            false,
        );
        end_timer!(start);

        // The last run can consume `input` directly; no further copy is needed.
        let mut b = input;
        let start = start_timer!(|| format!("recursive fft {} ({})", n, num_threads));
        fft::recursive::fft(
            &mut b,
            domain.get_omega(),
            k,
            domain.get_fft_data(n),
            false,
        );
        end_timer!(start);

        // Compare element by element so a failure reports the offending index
        // instead of dumping the whole vectors.
        for (i, ((base, rec), par)) in a.iter().zip(&b).zip(&c).enumerate() {
            assert_eq!(base, rec, "baseline vs recursive mismatch at index {}", i);
            assert_eq!(base, par, "baseline vs parallel mismatch at index {}", i);
        }
    }

    /// Applies the recursive FFT forward and then with `inverse = true`, and
    /// checks that the result, scaled by the inverse of `n`, equals the input.
    #[test]
    fn test_ifft_recursive() {
        let k = 22;

        let domain = EvaluationDomain::<Scalar>::new(1, k);
        let n = domain.get_n() as usize;

        let input = random_scalars(n);

        let mut a = input.clone();
        fft::recursive::fft(
            &mut a,
            domain.get_omega(),
            k,
            domain.get_fft_data(n),
            false,
        );
        fft::recursive::fft(
            &mut a,
            domain.get_omega_inv(), // doesn't actually do anything
            k,
            domain.get_fft_data(n),
            true,
        );

        // The inverse pass above leaves the values unscaled, so the 1/n factor
        // is applied here before comparing with the original input.
        let ifft_divisor = Scalar::from(n as u64).invert().unwrap();

        for (i, (expected, actual)) in input.iter().zip(&a).enumerate() {
            assert_eq!(
                *expected,
                *actual * ifft_divisor,
                "round-trip mismatch at index {}",
                i
            );
        }
    }

    /// Smoke test: runs `best_fft` once on a size-`2^3` domain with random
    /// input. There are no assertions; the test passes if the call completes.
    #[test]
    fn test_mem_leak() {
        let j = 1;
        let k = 3;
        let domain = EvaluationDomain::new(j, k);
        let omega = domain.get_omega();
        let l = 1 << k;
        let data = domain.get_fft_data(l);
        let mut a = random_scalars(l);

        best_fft(&mut a, omega, k, data, false);
    }
}