use crate::arithmetic::{self, log2_floor, FftGroup};

use crate::multicore;
pub use ff::Field;
pub use halo2curves::{CurveAffine, CurveExt};

use super::recursive::FFTData;

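/// Log2 size of the "low" half of the sparse twiddle lookup table. The
/// table stores `2^SPARSE_TWIDDLE_DEGREE` consecutive powers of omega
/// followed by powers of `omega^(2^SPARSE_TWIDDLE_DEGREE)`, so any
/// `omega^k` can be assembled as the product of one entry from each half
/// instead of materializing all `n` powers.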
pub const SPARSE_TWIDDLE_DEGREE: u32 = 10;

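/// Chooses between the serial and parallel FFTs: the parallel path splits
/// the work into `split_m = 2^log2_floor(threads)` column FFTs of length
/// `sub_n = n / split_m`, and is only applicable when `sub_n >= split_m`;
/// smaller instances fall back to `serial_fft`.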
fn best_fft_opt<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
    let threads = multicore::current_num_threads();
    let log_split = log2_floor(threads) as usize;
    let n = a.len();
    let sub_n = n >> log_split;
    let split_m = 1 << log_split;

    if sub_n >= split_m {
        parallel_fft(a, omega, log_n);
    } else {
        serial_fft(a, omega, log_n);
    }
}

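/// In-place radix-2 Cooley-Tukey FFT: permutes the input into bit-reversed
/// order, then applies `log_n` rounds of butterflies, recomputing each
/// round's twiddle `w_m = omega^(n / 2m)` on the fly.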
fn serial_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);

    // Reorder into bit-reversed order so the butterflies can run in place.
    for k in 0..n as usize {
        let rk = arithmetic::bitreverse(k, log_n as usize);
        if k < rk {
            a.swap(rk, k);
        }
    }

    let mut m = 1;
    for _ in 0..log_n {
        let w_m: Scalar = omega.pow_vartime([u64::from(n / (2 * m)), 0, 0, 0]);

        let mut k = 0;
        while k < n {
            let mut w = Scalar::ONE;
            for j in 0..m {
                // Butterfly:
                // (a[k+j], a[k+j+m]) <- (a[k+j] + w*a[k+j+m], a[k+j] - w*a[k+j+m]).
                let mut t = a[(k + j + m) as usize];
                t *= &w;
                a[(k + j + m) as usize] = a[(k + j) as usize];
                a[(k + j + m) as usize] -= &t;
                a[(k + j) as usize] += &t;
                w *= &w_m;
            }

            k += 2 * m;
        }

        m *= 2;
    }
}

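/// Same butterfly structure as `serial_fft`, but for a sub-FFT whose
/// effective root of unity is `omega^twiddle_scale`: each round's twiddle
/// `omega^(twiddle_scale * n / 2m)` is assembled from the low and high
/// halves of the sparse lookup table rather than recomputed. The caller
/// is expected to supply input already in bit-reversed order.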
fn serial_split_fft<Scalar: Field, G: FftGroup<Scalar>>(
    a: &mut [G],
    twiddle_lut: &[Scalar],
    twiddle_scale: usize,
    log_n: u32,
) {
    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);

    let mut m = 1;
    for _ in 0..log_n {
        // Look up w_m = omega^(twiddle_scale * n / 2m) in the two-level table.
        let omega_idx = twiddle_scale * n as usize / (2 * m as usize);
        let low_idx = omega_idx % (1 << SPARSE_TWIDDLE_DEGREE);
        let high_idx = omega_idx >> SPARSE_TWIDDLE_DEGREE;
        let mut w_m = twiddle_lut[low_idx];
        if high_idx > 0 {
            w_m *= twiddle_lut[(1 << SPARSE_TWIDDLE_DEGREE) + high_idx];
        }

        let mut k = 0;
        while k < n {
            let mut w = Scalar::ONE;
            for j in 0..m {
                let mut t = a[(k + j + m) as usize];
                t *= &w;
                a[(k + j + m) as usize] = a[(k + j) as usize];
                a[(k + j + m) as usize] -= &t;
                a[(k + j) as usize] += &t;
                w *= &w_m;
            }

            k += 2 * m;
        }

        m *= 2;
    }
}

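/// One column FFT of the split-radix (four-step) decomposition: gathers
/// the `split_m` elements `a[i * sub_n + sub_fft_offset]` in bit-reversed
/// order, runs a small FFT over them with root `omega^sub_n`, then scales
/// element `i` by the inter-stage twiddle `omega^(sub_fft_offset * i)`
/// before writing the column into `tmp`.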
fn split_radix_fft<Scalar: Field, G: FftGroup<Scalar>>(
    tmp: &mut [G],
    a: &[G],
    twiddle_lut: &[Scalar],
    n: usize,
    sub_fft_offset: usize,
    log_split: usize,
) {
    let split_m = 1 << log_split;
    let sub_n = n >> log_split;

    // Gather the column into a small scratch buffer, in bit-reversed
    // order, so `serial_split_fft` can run in place.
    let tmp_filler_val = tmp[0];
    let mut t1 = vec![tmp_filler_val; split_m];
    for i in 0..split_m {
        t1[arithmetic::bitreverse(i, log_split)] = a[i * sub_n + sub_fft_offset];
    }
    serial_split_fft(&mut t1, twiddle_lut, sub_n, log_split as u32);

    // Apply the inter-stage twiddles omega^(sub_fft_offset * i).
    let sparse_degree = SPARSE_TWIDDLE_DEGREE;
    let omega_idx = sub_fft_offset;
    let low_idx = omega_idx % (1 << sparse_degree);
    let high_idx = omega_idx >> sparse_degree;
    let mut omega = twiddle_lut[low_idx];
    if high_idx > 0 {
        omega *= twiddle_lut[(1 << sparse_degree) + high_idx];
    }
    let mut w_m = Scalar::ONE;
    for i in 0..split_m {
        t1[i] *= &w_m;
        tmp[i] = t1[i];
        w_m *= omega;
    }
}

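/// Builds the sparse twiddle lookup table for an `omega` of order
/// `2^log_n`. When `sparse_degree > log_n` the table is simply all `n`
/// powers of omega. Otherwise it is two tables concatenated:
/// `lut[i] = omega^i` for `i < 2^sparse_degree`, followed by
/// `lut[2^sparse_degree + i] = omega^(i * 2^sparse_degree)`, so that any
/// power can be recovered as
/// `omega^k = lut[k % 2^sparse_degree] * lut[2^sparse_degree + (k >> sparse_degree)]`.
/// With `with_last_level == false` the high table is halved, covering
/// exponents only up to `n/2 - 1`.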
fn generate_twiddle_lookup_table<F: Field>(
    omega: F,
    log_n: u32,
    sparse_degree: u32,
    with_last_level: bool,
) -> Vec<F> {
    let without_last_level = !with_last_level;
    let is_lut_len_large = sparse_degree > log_n;

    // Small domain: a dense table of all n powers is cheaper than the
    // two-level scheme.
    if is_lut_len_large {
        let mut twiddle_lut = vec![F::ZERO; (1 << log_n) as usize];
        parallelize(&mut twiddle_lut, |twiddle_lut, start| {
            let mut w_n = omega.pow_vartime([start as u64, 0, 0, 0]);
            for twiddle_lut in twiddle_lut.iter_mut() {
                *twiddle_lut = w_n;
                w_n *= omega;
            }
        });
        return twiddle_lut;
    }

    // Low table: omega^i for i < 2^sparse_degree.
    let low_degree_lut_len = 1 << sparse_degree;
    let high_degree_lut_len = 1 << (log_n - sparse_degree - without_last_level as u32);
    let mut twiddle_lut = vec![F::ZERO; low_degree_lut_len + high_degree_lut_len];
    parallelize(
        &mut twiddle_lut[..low_degree_lut_len],
        |twiddle_lut, start| {
            let mut w_n = omega.pow_vartime([start as u64, 0, 0, 0]);
            for twiddle_lut in twiddle_lut.iter_mut() {
                *twiddle_lut = w_n;
                w_n *= omega;
            }
        },
    );
    // High table: consecutive powers of omega^(2^sparse_degree).
    let high_degree_omega = omega.pow_vartime([(1 << sparse_degree) as u64, 0, 0, 0]);
    parallelize(
        &mut twiddle_lut[low_degree_lut_len..],
        |twiddle_lut, start| {
            let mut w_n = high_degree_omega.pow_vartime([start as u64, 0, 0, 0]);
            for twiddle_lut in twiddle_lut.iter_mut() {
                *twiddle_lut = w_n;
                w_n *= high_degree_omega;
            }
        },
    );
    twiddle_lut
}

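/// Four-step parallel FFT. The length-`n` input is treated as a
/// `split_m x sub_n` matrix, with `split_m` derived from the thread
/// count: every column is transformed by `split_radix_fft` (which also
/// applies the inter-stage twiddles), the data is transposed so each row
/// is contiguous, each row gets an independent `serial_fft` with root
/// `omega^split_m`, and a final shuffle interleaves the rows back into
/// natural output order.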
fn parallel_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
    let n = a.len();
    assert_eq!(n, 1 << log_n);

    let log_split = log2_floor(multicore::current_num_threads()) as usize;
    let split_m = 1 << log_split;
    let sub_n = n >> log_split;
    let twiddle_lut = generate_twiddle_lookup_table(omega, log_n, SPARSE_TWIDDLE_DEGREE, true);

    // Step 1: column FFTs (with inter-stage twiddles) into `tmp`.
    let tmp_filler_val = a[0];
    let mut tmp = vec![tmp_filler_val; n];
    multicore::scope(|scope| {
        let a = &*a;
        let twiddle_lut = &*twiddle_lut;
        for (chunk_idx, tmp) in tmp.chunks_mut(sub_n).enumerate() {
            scope.spawn(move |_| {
                let split_fft_offset = (chunk_idx * sub_n) >> log_split;
                for (i, tmp) in tmp.chunks_mut(split_m).enumerate() {
                    let split_fft_offset = split_fft_offset + i;
                    split_radix_fft(tmp, a, twiddle_lut, n, split_fft_offset, log_split);
                }
            });
        }
    });

    // Step 2: transpose so that each row is a contiguous slice of `a`.
    parallelize(a, |a, start| {
        for (idx, a) in a.iter_mut().enumerate() {
            let idx = start + idx;
            let i = idx / sub_n;
            let j = idx % sub_n;
            *a = tmp[j * split_m + i];
        }
    });

    // Step 3: independent row FFTs over the smaller domain.
    let new_omega = omega.pow_vartime([split_m as u64, 0, 0, 0]);
    multicore::scope(|scope| {
        for a in a.chunks_mut(sub_n) {
            scope.spawn(move |_| {
                serial_fft(a, new_omega, log_n - log_split as u32);
            });
        }
    });

    // Step 4: interleave the rows back into natural output order.
    let mask = (1 << log_split) - 1;
    parallelize(&mut tmp, |tmp, start| {
        for (idx, tmp) in tmp.iter_mut().enumerate() {
            let idx = start + idx;
            *tmp = a[idx];
        }
    });
    parallelize(a, |a, start| {
        for (idx, a) in a.iter_mut().enumerate() {
            let idx = start + idx;
            *a = tmp[sub_n * (idx & mask) + (idx >> log_split)];
        }
    });
}

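/// Splits `v` into one chunk per thread and runs `f` on each chunk in
/// parallel, passing the chunk's starting index in `v` so callers can
/// recover absolute positions. Inputs too small to split usefully (fewer
/// elements per thread than there are threads) are processed as a single
/// chunk.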
fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
    let n = v.len();
    let num_threads = multicore::current_num_threads();
    let mut chunk = n / num_threads;
    if chunk < num_threads {
        chunk = n;
    }

    multicore::scope(|scope| {
        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
            let f = f.clone();
            scope.spawn(move |_| {
                let start = chunk_num * chunk;
                f(v, start);
            });
        }
    });
}

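/// Public entry point for this backend. `_data` and `_inverse` are unused
/// here: they keep the signature in step with the recursive FFT (which
/// precomputes state in `FFTData`), and an inverse transform is performed
/// by passing the inverse root of unity as `omega`.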
pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
    data_in: &mut [G],
    omega: Scalar,
    log_n: u32,
    _data: &FFTData<Scalar>,
    _inverse: bool,
) {
    best_fft_opt(data_in, omega, log_n)
}
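
// A minimal consistency-check sketch, assuming `halo2curves::bn256::Fr`
// and `rand_core::OsRng` are available as (dev-)dependencies of this
// crate; it exercises `parallel_fft` against `serial_fft` on random data
// rather than pinning down reference outputs.
#[cfg(test)]
mod tests {
    use super::*;
    use ff::PrimeField;
    use halo2curves::bn256::Fr;
    use rand_core::OsRng;

    #[test]
    fn parallel_fft_matches_serial_fft() {
        let log_n = 10u32;
        let n = 1usize << log_n;

        // Derive a primitive 2^log_n-th root of unity by repeatedly
        // squaring the field's 2^S root of unity.
        let mut omega = Fr::ROOT_OF_UNITY;
        for _ in log_n..Fr::S {
            omega = omega.square();
        }

        let input: Vec<Fr> = (0..n).map(|_| Fr::random(OsRng)).collect();
        let mut serial = input.clone();
        let mut parallel = input;

        serial_fft(&mut serial, omega, log_n);
        parallel_fft(&mut parallel, omega, log_n);

        assert_eq!(serial, parallel);
    }
}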