halo2_axiom/
fft.rs
1use ff::Field;
4
5use self::recursive::FFTData;
6use crate::arithmetic::FftGroup;
7
8pub mod baseline;
9pub mod parallel;
10pub mod recursive;
11
12pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
14 a: &mut [G],
15 omega: Scalar,
16 log_n: u32,
17 data: &FFTData<Scalar>,
18 inverse: bool,
19) {
20 #[cfg(target_arch = "x86_64")]
22 parallel::fft(a, omega, log_n, data, inverse);
23 #[cfg(not(target_arch = "x86_64"))]
24 recursive::fft(a, omega, log_n, data, inverse)
25}
26
#[cfg(test)]
mod tests {
    use ark_std::{end_timer, start_timer};
    use ff::Field;
    use halo2curves::bn256::Fr as Scalar;
    use rand_core::OsRng;

    use crate::{arithmetic::best_fft, fft, multicore, poly::EvaluationDomain};

    /// Samples `n` independently drawn random scalars.
    ///
    /// Each element is sampled separately on purpose. A vector of `n` identical
    /// values (e.g. `vec![Scalar::random(OsRng); n]`) transforms, under a primitive
    /// n-th root of unity, to `[n * x, 0, 0, ...]`. Comparisons between FFT
    /// implementations would then still pass if one of them permuted or
    /// mis-twiddled almost all of its outputs.
    fn random_input(n: usize) -> Vec<Scalar> {
        (0..n).map(|_| Scalar::random(OsRng)).collect()
    }

    /// Checks that the baseline, parallel and recursive FFT implementations
    /// produce identical output for the same random input, timing each run.
    #[test]
    fn test_fft_recursive() {
        let k = 22;

        let domain = EvaluationDomain::<Scalar>::new(1, k);
        let n = domain.get_n() as usize;
        let omega = domain.get_omega();
        let data = domain.get_fft_data(n);

        let input = random_input(n);

        let num_threads = multicore::current_num_threads();

        let mut baseline_out = input.clone();
        let start = start_timer!(|| format!(
            "baseline fft {} ({})",
            baseline_out.len(),
            num_threads
        ));
        fft::baseline::fft(&mut baseline_out, omega, k, data, false);
        end_timer!(start);

        let mut parallel_out = input.clone();
        let start = start_timer!(|| format!(
            "parallel fft {} ({})",
            parallel_out.len(),
            num_threads
        ));
        fft::parallel::fft(&mut parallel_out, omega, k, data, false);
        end_timer!(start);

        let mut recursive_out = input;
        let start = start_timer!(|| format!(
            "recursive fft {} ({})",
            recursive_out.len(),
            num_threads
        ));
        fft::recursive::fft(&mut recursive_out, omega, k, data, false);
        end_timer!(start);

        // Compare element-wise (rather than whole vectors) so a failure reports the
        // first mismatching index instead of dumping millions of field elements.
        for (i, ((expected, recursive), parallel)) in baseline_out
            .iter()
            .zip(&recursive_out)
            .zip(&parallel_out)
            .enumerate()
        {
            assert_eq!(
                expected, recursive,
                "baseline and recursive FFT differ at index {}",
                i
            );
            assert_eq!(
                expected, parallel,
                "baseline and parallel FFT differ at index {}",
                i
            );
        }
    }

    /// Checks that a forward recursive FFT followed by an inverse recursive FFT
    /// (using `omega_inv`) recovers the original input after scaling by `1 / n`.
    #[test]
    fn test_ifft_recursive() {
        let k = 22;

        let domain = EvaluationDomain::<Scalar>::new(1, k);
        let n = domain.get_n() as usize;
        let data = domain.get_fft_data(n);

        let input = random_input(n);

        let mut round_trip = input.clone();
        fft::recursive::fft(&mut round_trip, domain.get_omega(), k, data, false);
        fft::recursive::fft(&mut round_trip, domain.get_omega_inv(), k, data, true);

        // This test expects the inverse pass to leave its output scaled by `n`,
        // so the `1 / n` factor is applied here before comparing.
        let ifft_divisor = Scalar::from(n as u64).invert().unwrap();

        for (i, (original, transformed)) in input.iter().zip(&round_trip).enumerate() {
            assert_eq!(
                *original,
                *transformed * ifft_divisor,
                "FFT/IFFT round trip mismatch at index {}",
                i
            );
        }
    }

    /// Smoke test: runs `best_fft` on a small (`2^3`-element) random input.
    /// NOTE(review): the name suggests it is meant to be run under a leak checker;
    /// it asserts nothing itself — confirm the intended harness.
    #[test]
    fn test_mem_leak() {
        let j = 1;
        let k = 3;
        let domain = EvaluationDomain::<Scalar>::new(j, k);
        let omega = domain.get_omega();
        let n = 1usize << k;
        let data = domain.get_fft_data(n);
        let mut a = random_input(n);

        best_fft(&mut a, omega, k, data, false);
    }
}
136}