halo2_axiom/
fft.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
//! FFT implementations ([`baseline`], [`parallel`], [`recursive`]) and a dispatcher, [`fft`], that picks one of them at compile time based on the target architecture.

use ff::Field;

use self::recursive::FFTData;
use crate::arithmetic::FftGroup;

pub mod baseline;
pub mod parallel;
pub mod recursive;

/// Dispatches to a concrete FFT implementation chosen at compile time by target architecture.
///
/// On `x86_64` this forwards to [`parallel::fft`]; on every other architecture it forwards to
/// [`recursive::fft`]. The choice is made with `#[cfg(target_arch = ...)]`, so it is fixed when
/// the crate is compiled and is not affected by any runtime setting.
///
/// All arguments are forwarded unchanged:
/// - `a`: the values to transform, modified in place.
/// - `omega`: the root of unity of the evaluation domain.
/// - `log_n`: log2 of the domain size (presumably `a.len() == 1 << log_n`; confirm at call sites).
/// - `data`: precomputed [`FFTData`] for the domain.
/// - `inverse`: `true` to run the inverse transform.
pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
    a: &mut [G],
    omega: Scalar,
    log_n: u32,
    data: &FFTData<Scalar>,
    inverse: bool,
) {
    // Empirically, the parallel implementation requires less memory bandwidth, which is more
    // performant on x86_64.
    #[cfg(target_arch = "x86_64")]
    parallel::fft(a, omega, log_n, data, inverse);
    #[cfg(not(target_arch = "x86_64"))]
    recursive::fft(a, omega, log_n, data, inverse);
}

#[cfg(test)]
mod tests {
    use ark_std::{end_timer, start_timer};
    use ff::Field;
    use halo2curves::bn256::Fr as Scalar;
    use rand_core::OsRng;

    use crate::{arithmetic::best_fft, fft, multicore, poly::EvaluationDomain};

    /// Returns `n` independently sampled random scalars.
    ///
    /// Each element is drawn separately on purpose. `vec![Scalar::random(OsRng); n]` would
    /// clone a single random value `n` times. The FFT of a constant vector is zero everywhere
    /// except index 0, which would hide twiddle-factor bugs.
    fn random_scalars(n: usize) -> Vec<Scalar> {
        (0..n).map(|_| Scalar::random(OsRng)).collect()
    }

    /// Checks that the baseline, parallel and recursive FFTs agree on a random input.
    #[test]
    fn test_fft_recursive() {
        let k = 22;

        let domain = EvaluationDomain::<Scalar>::new(1, k);
        let n = domain.get_n() as usize;

        let input = random_scalars(n);

        let num_threads = multicore::current_num_threads();

        let mut a = input.clone();
        let l_a = a.len();
        let start = start_timer!(|| format!("baseline fft {} ({})", l_a, num_threads));
        fft::baseline::fft(
            &mut a,
            domain.get_omega(),
            k,
            domain.get_fft_data(l_a),
            false,
        );
        end_timer!(start);

        let mut c = input.clone();
        let l_c = c.len();
        let start = start_timer!(|| format!("parallel fft {} ({})", l_c, num_threads));
        fft::parallel::fft(
            &mut c,
            domain.get_omega(),
            k,
            domain.get_fft_data(l_c),
            false,
        );
        end_timer!(start);

        let mut b = input;
        let l_b = b.len();
        let start = start_timer!(|| format!("recursive fft {} ({})", l_b, num_threads));
        fft::recursive::fft(
            &mut b,
            domain.get_omega(),
            k,
            domain.get_fft_data(l_b),
            false,
        );
        end_timer!(start);

        // All three transforms run in place, so every output still has length `n`.
        for (i, ((base, rec), par)) in a.iter().zip(&b).zip(&c).enumerate() {
            assert_eq!(base, rec, "baseline and recursive FFT differ at index {}", i);
            assert_eq!(base, par, "baseline and parallel FFT differ at index {}", i);
        }
    }

    /// Checks that a forward recursive FFT followed by an inverse one recovers the input,
    /// up to the factor `n` that this test divides out explicitly.
    #[test]
    fn test_ifft_recursive() {
        let k = 22;

        let domain = EvaluationDomain::<Scalar>::new(1, k);
        let n = domain.get_n() as usize;

        let input = random_scalars(n);

        let mut a = input.clone();
        let l_a = a.len();
        fft::recursive::fft(
            &mut a,
            domain.get_omega(),
            k,
            domain.get_fft_data(l_a),
            false,
        );
        fft::recursive::fft(
            &mut a,
            domain.get_omega_inv(), // doesn't actually do anything
            k,
            domain.get_fft_data(l_a),
            true,
        );

        // The round trip is compared after scaling by 1/n.
        let ifft_divisor = Scalar::from(n as u64).invert().unwrap();
        for (i, (expected, got)) in input.iter().zip(&a).enumerate() {
            assert_eq!(
                *expected,
                *got * ifft_divisor,
                "FFT/IFFT round trip mismatch at index {}",
                i
            );
        }
    }

    /// Runs `best_fft` on a tiny domain; intended as a smoke test for memory issues.
    #[test]
    fn test_mem_leak() {
        let j = 1;
        let k = 3;
        let domain = EvaluationDomain::<Scalar>::new(j, k);
        let omega = domain.get_omega();
        let l = 1 << k;
        let data = domain.get_fft_data(l);
        let mut a = random_scalars(l);

        best_fft(&mut a, omega, k, data, false);
    }
}