use crate::{
    arithmetic::{self, parallelize, FftGroup},
    multicore,
};

pub use ff::Field;
pub use halo2curves::{CurveAffine, CurveExt};

/// One stage of the recursive FFT: a butterfly of the given radix applied to
/// sub-FFTs of the given length.
#[derive(Clone, Debug)]
pub struct FFTStage {
    radix: usize,
    length: usize,
}

/// Decomposes the FFT size into butterfly stages, preferring radix 4 and
/// falling back to radix 2 when the remaining size is not divisible by 4.
fn get_stages(size: usize, radixes: Vec<usize>) -> Vec<FFTStage> {
    let mut stages: Vec<FFTStage> = vec![];

    let mut n = size;

    // Use the caller-specified radices first.
    for &radix in &radixes {
        n /= radix;
        stages.push(FFTStage { radix, length: n });
    }

    // Fill in the rest of the decomposition. Start at radix 4 and drop to
    // radix 2 once 4 no longer divides the remaining size; `size` is
    // expected to be a power of two, so this always terminates.
    let mut p = 4;
    while n > 1 {
        while n % p != 0 {
            if p == 4 {
                p = 2;
            }
        }
        n /= p;
        stages.push(FFTStage {
            radix: p,
            length: n,
        });
    }

    stages
}
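
// A hypothetical sanity check of the decomposition above (the test name and
// expected values are ours, derived from the radix-4-first rule): a size-32
// FFT should split as 4 * 4 * 2, with each stage recording the sub-FFT length.
#[test]
fn get_stages_prefers_radix_4() {
    let stages = get_stages(32, vec![]);
    let radixes: Vec<usize> = stages.iter().map(|s| s.radix).collect();
    let lengths: Vec<usize> = stages.iter().map(|s| s.length).collect();
    assert_eq!(radixes, vec![4, 4, 2]);
    assert_eq!(lengths, vec![8, 2, 1]);
}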

/// Precomputed stage decomposition and twiddle tables for FFTs of size `n`.
#[derive(Clone, Debug)]
pub struct FFTData<F: arithmetic::Field> {
    n: usize,

    stages: Vec<FFTStage>,

    // Per-stage twiddle factors for the forward and inverse transforms.
    f_twiddles: Vec<Vec<F>>,
    inv_twiddles: Vec<Vec<F>>,
}

impl<F: arithmetic::Field> Default for FFTData<F> {
    fn default() -> Self {
        Self {
            n: Default::default(),
            stages: Default::default(),
            f_twiddles: Default::default(),
            inv_twiddles: Default::default(),
        }
    }
}

impl<F: arithmetic::Field> FFTData<F> {
    /// Creates FFT data for size `n`, precomputing the stage decomposition
    /// and the per-stage twiddle tables for both transform directions.
    pub fn new(n: usize, omega: F, omega_inv: F) -> Self {
        let stages = get_stages(n, vec![]);
        let mut f_twiddles = vec![];
        let mut inv_twiddles = vec![];
        let mut scratch = vec![F::ZERO; n];

        // Build twiddles for the inverse (inv == 0) and forward (inv == 1)
        // transforms.
        for inv in 0..2 {
            let inverse = inv == 0;
            let o = if inverse { omega_inv } else { omega };
            let stage_twiddles = if inverse {
                &mut inv_twiddles
            } else {
                &mut f_twiddles
            };

            let twiddles = &mut scratch;

            // Fill `twiddles` with the powers 1, o, o^2, ..., o^(n-1).
            parallelize(twiddles, |twiddles, start| {
                let w_m = o;
                let mut w = o.pow_vartime([start as u64]);
                for value in twiddles.iter_mut() {
                    *value = w;
                    w *= w_m;
                }
            });

            // Build the per-stage tables.
            let num_stages = stages.len();
            stage_twiddles.resize(num_stages, vec![]);
            for l in 0..num_stages {
                let radix = stages[l].radix;
                let stage_length = stages[l].length;

                let num_twiddles = stage_length * (radix - 1);
                stage_twiddles[l].resize(num_twiddles + 1, F::ZERO);

                // The extra slot at the end holds o^(3n/4), the primitive
                // fourth root of unity `j` used by the radix-4 butterfly.
                stage_twiddles[l][num_twiddles] = twiddles[(twiddles.len() * 3) / 4];

                let stride = n / (stage_length * radix);
                let mut tws = vec![0usize; radix - 1];
                for i in 0..stage_length {
                    for j in 0..radix - 1 {
                        stage_twiddles[l][i * (radix - 1) + j] = twiddles[tws[j]];
                        tws[j] += (j + 1) * stride;
                    }
                }
            }
        }

        Self {
            n,
            stages,
            f_twiddles,
            inv_twiddles,
        }
    }

    /// Returns the FFT size this data was built for.
    pub fn get_n(&self) -> usize {
        self.n
    }
}

/// Radix-2 butterfly, combining two half-size sub-FFTs in place. The first
/// pair needs no multiplication because the first twiddle factor is one.
fn butterfly_2<Scalar: Field, G: FftGroup<Scalar>>(
    out: &mut [G],
    twiddles: &[Scalar],
    stage_length: usize,
) {
    let mut out_offset = 0;
    let mut out_offset2 = stage_length;

    let t = out[out_offset2];
    out[out_offset2] = out[out_offset] - &t;
    out[out_offset] += &t;
    out_offset2 += 1;
    out_offset += 1;

    for twiddle in twiddles[1..stage_length].iter() {
        let t = out[out_offset2] * twiddle;
        out[out_offset2] = out[out_offset] - &t;
        out[out_offset] += &t;
        out_offset2 += 1;
        out_offset += 1;
    }
}
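
// A minimal sketch of what the radix-2 butterfly computes (assumes the bn256
// scalar field from halo2curves is available): with stage_length = 1 it is
// exactly a size-2 DFT, mapping [a, b] to [a + b, a - b].
#[test]
fn butterfly_2_is_a_size_2_dft() {
    use halo2curves::bn256::Fr;

    let (a, b) = (Fr::from(3u64), Fr::from(5u64));
    let mut out = [a, b];
    // twiddles[0] is implicitly one and is never read for stage_length = 1.
    butterfly_2(&mut out, &[Fr::ONE], 1);
    assert_eq!(out, [a + b, a - b]);
}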

/// Multithreaded version of [`butterfly_2`], splitting the two halves into
/// chunks that are processed on separate threads.
fn butterfly_2_parallel<Scalar: Field, G: FftGroup<Scalar>>(
    out: &mut [G],
    twiddles: &[Scalar],
    _stage_length: usize,
    num_threads: usize,
) {
    let n = out.len();
    let mut chunk = n / num_threads;
    if chunk < num_threads {
        chunk = n;
    }

    multicore::scope(|scope| {
        let (part_a, part_b) = out.split_at_mut(n / 2);
        for (i, (part0, part1)) in part_a
            .chunks_mut(chunk)
            .zip(part_b.chunks_mut(chunk))
            .enumerate()
        {
            scope.spawn(move |_| {
                let offset = i * chunk;
                for k in 0..part0.len() {
                    let t = part1[k] * &twiddles[offset + k];
                    part1[k] = part0[k] - &t;
                    part0[k] += &t;
                }
            });
        }
    });
}
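
// A consistency sketch (again assuming bn256's Fr): the parallel radix-2
// butterfly should agree with the serial one. Note twiddles[0] must be one,
// as it is in the tables built by `FFTData::new`, because the serial version
// skips that multiplication while the parallel version performs it.
#[test]
fn butterfly_2_parallel_matches_serial() {
    use halo2curves::bn256::Fr;

    let mut twiddles: Vec<Fr> = (0..8u64).map(Fr::from).collect();
    twiddles[0] = Fr::ONE;
    let mut serial: Vec<Fr> = (1..=16u64).map(Fr::from).collect();
    let mut parallel = serial.clone();

    butterfly_2(&mut serial, &twiddles, 8);
    butterfly_2_parallel(&mut parallel, &twiddles, 8, 4);
    assert_eq!(serial, parallel);
}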

/// Radix-4 butterfly, combining four quarter-size sub-FFTs in place. The
/// last twiddle entry holds `j`, a primitive fourth root of unity.
fn butterfly_4<Scalar: Field, G: FftGroup<Scalar>>(
    out: &mut [G],
    twiddles: &[Scalar],
    stage_length: usize,
) {
    let j = twiddles[twiddles.len() - 1];
    let mut tw = 0;

    // First butterfly: all twiddle factors are one.
    {
        let i0 = 0;
        let i1 = stage_length;
        let i2 = stage_length * 2;
        let i3 = stage_length * 3;

        let z0 = out[i0];
        let z1 = out[i1];
        let z2 = out[i2];
        let z3 = out[i3];

        let t1 = z0 + &z2;
        let t2 = z1 + &z3;
        let t3 = z0 - &z2;
        let t4j = (z1 - &z3) * &j;

        out[i0] = t1 + &t2;
        out[i1] = t3 - &t4j;
        out[i2] = t1 - &t2;
        out[i3] = t3 + &t4j;

        tw += 3;
    }

    for k in 1..stage_length {
        let i0 = k;
        let i1 = k + stage_length;
        let i2 = k + stage_length * 2;
        let i3 = k + stage_length * 3;

        let z0 = out[i0];
        let z1 = out[i1] * &twiddles[tw];
        let z2 = out[i2] * &twiddles[tw + 1];
        let z3 = out[i3] * &twiddles[tw + 2];

        let t1 = z0 + &z2;
        let t2 = z1 + &z3;
        let t3 = z0 - &z2;
        let t4j = (z1 - &z3) * &j;

        out[i0] = t1 + &t2;
        out[i1] = t3 - &t4j;
        out[i2] = t1 - &t2;
        out[i3] = t3 + &t4j;

        tw += 3;
    }
}
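
// A hypothetical check that the radix-4 butterfly is a size-4 DFT (assumes
// bn256's Fr): with stage_length = 1 only the trailing twiddle slot is read,
// and it must hold j = omega^3 for a primitive 4th root of unity omega,
// matching the layout produced by `FFTData::new`.
#[test]
fn butterfly_4_is_a_size_4_dft() {
    use ff::PrimeField;
    use halo2curves::bn256::Fr;

    // Derive a primitive 4th root of unity from the field's 2-adic root.
    let omega = Fr::ROOT_OF_UNITY.pow_vartime([1u64 << (Fr::S - 2)]);
    let input = [
        Fr::from(1u64),
        Fr::from(2u64),
        Fr::from(3u64),
        Fr::from(4u64),
    ];

    let mut out = input;
    butterfly_4(&mut out, &[omega.pow_vartime([3u64])], 1);

    // Naive size-4 DFT: X_k = sum_m input[m] * omega^(m * k).
    for (k, x) in out.iter().enumerate() {
        let mut expected = Fr::ZERO;
        for (m, v) in input.iter().enumerate() {
            expected += *v * omega.pow_vartime([(m * k) as u64]);
        }
        assert_eq!(*x, expected);
    }
}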

/// Multithreaded version of [`butterfly_4`], splitting the four quarters
/// into chunks that are processed on separate threads.
fn butterfly_4_parallel<Scalar: Field, G: FftGroup<Scalar>>(
    out: &mut [G],
    twiddles: &[Scalar],
    _stage_length: usize,
    num_threads: usize,
) {
    let j = twiddles[twiddles.len() - 1];

    let n = out.len();
    let mut chunk = n / num_threads;
    if chunk < num_threads {
        chunk = n;
    }
    multicore::scope(|scope| {
        let (part_a, part_b) = out.split_at_mut(n / 2);
        let (part_aa, part_ab) = part_a.split_at_mut(n / 4);
        let (part_ba, part_bb) = part_b.split_at_mut(n / 4);
        for (i, (((part0, part1), part2), part3)) in part_aa
            .chunks_mut(chunk)
            .zip(part_ab.chunks_mut(chunk))
            .zip(part_ba.chunks_mut(chunk))
            .zip(part_bb.chunks_mut(chunk))
            .enumerate()
        {
            scope.spawn(move |_| {
                let offset = i * chunk;
                let mut tw = offset * 3;
                for k in 0..part1.len() {
                    let z0 = part0[k];
                    let z1 = part1[k] * &twiddles[tw];
                    let z2 = part2[k] * &twiddles[tw + 1];
                    let z3 = part3[k] * &twiddles[tw + 2];

                    let t1 = z0 + &z2;
                    let t2 = z1 + &z3;
                    let t3 = z0 - &z2;
                    let t4j = (z1 - &z3) * &j;

                    part0[k] = t1 + &t2;
                    part1[k] = t3 - &t4j;
                    part2[k] = t1 - &t2;
                    part3[k] = t3 + &t4j;

                    tw += 3;
                }
            });
        }
    });
}

/// Inner recursion of the FFT: first computes the `radix` sub-FFTs of this
/// stage into `data_out` (reading the input at the given offset and stride),
/// then applies the stage's butterfly, optionally across multiple threads.
#[allow(clippy::too_many_arguments)]
fn recursive_fft_inner<Scalar: Field, G: FftGroup<Scalar>>(
    data_in: &[G],
    data_out: &mut [G],
    twiddles: &[Vec<Scalar>],
    stages: &[FFTStage],
    in_offset: usize,
    stride: usize,
    level: usize,
    num_threads: usize,
) {
    let radix = stages[level].radix;
    let stage_length = stages[level].length;

    if num_threads > 1 {
        if stage_length == 1 {
            for i in 0..radix {
                data_out[i] = data_in[in_offset + i * stride];
            }
        } else {
            let num_threads_recursive = if num_threads >= radix {
                radix
            } else {
                num_threads
            };
            parallelize_count(data_out, num_threads_recursive, |data_out, i| {
                let num_threads_in_recursion = if num_threads < radix {
                    1
                } else {
                    (num_threads + i) / radix
                };
                recursive_fft_inner(
                    data_in,
                    data_out,
                    twiddles,
                    stages,
                    in_offset + i * stride,
                    stride * radix,
                    level + 1,
                    num_threads_in_recursion,
                )
            });
        }
        match radix {
            2 => butterfly_2_parallel(data_out, &twiddles[level], stage_length, num_threads),
            4 => butterfly_4_parallel(data_out, &twiddles[level], stage_length, num_threads),
            _ => unimplemented!("radix unsupported"),
        }
    } else {
        if stage_length == 1 {
            for i in 0..radix {
                data_out[i] = data_in[in_offset + i * stride];
            }
        } else {
            for i in 0..radix {
                recursive_fft_inner(
                    data_in,
                    &mut data_out[i * stage_length..(i + 1) * stage_length],
                    twiddles,
                    stages,
                    in_offset + i * stride,
                    stride * radix,
                    level + 1,
                    num_threads,
                );
            }
        }
        match radix {
            2 => butterfly_2(data_out, &twiddles[level], stage_length),
            4 => butterfly_4(data_out, &twiddles[level], stage_length),
            _ => unimplemented!("radix unsupported"),
        }
    }
}

/// Top-level recursive FFT: runs the recursion into a scratch buffer and
/// swaps the result back into `data_in`.
fn recursive_fft<Scalar: Field, G: FftGroup<Scalar>>(
    data: &FFTData<Scalar>,
    data_in: &mut Vec<G>,
    inverse: bool,
) {
    let num_threads = multicore::current_num_threads();
    let filler = data_in[0];
    let mut scratch = vec![filler; data_in.len()];
    recursive_fft_inner(
        data_in,
        &mut scratch,
        if inverse {
            &data.inv_twiddles
        } else {
            &data.f_twiddles
        },
        &data.stages,
        0,
        1,
        0,
        num_threads,
    );
    std::mem::swap(data_in, &mut scratch);
}

/// Runs `f` on up to `num_threads` chunks of `v`, passing each closure the
/// index of the chunk it operates on.
fn parallelize_count<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(
    v: &mut [T],
    num_threads: usize,
    f: F,
) {
    let n = v.len();
    let mut chunk = n / num_threads;
    if chunk < num_threads {
        chunk = n;
    }

    multicore::scope(|scope| {
        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
            let f = f.clone();
            scope.spawn(move |_| {
                f(v, chunk_num);
            });
        }
    });
}
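
// A small sketch of `parallelize_count`'s chunking contract (the test name
// is ours): each closure call receives a disjoint chunk together with its
// chunk index.
#[test]
fn parallelize_count_passes_chunk_indices() {
    let mut v = vec![0usize; 8];
    parallelize_count(&mut v, 2, |chunk, index| {
        for x in chunk.iter_mut() {
            *x = index;
        }
    });
    // 8 elements over 2 threads gives a chunk size of 4.
    assert_eq!(v, vec![0, 0, 0, 0, 1, 1, 1, 1]);
}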

/// Generic FFT entry point over precomputed [`FFTData`]. `_omega` and
/// `_log_n` are unused here: the twiddle factors come from `data`.
pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
    data_in: &mut [G],
    _omega: Scalar,
    _log_n: u32,
    data: &FFTData<Scalar>,
    inverse: bool,
) {
    let orig_len = data_in.len();
    let mut data_in_vec = data_in.to_vec();
    recursive_fft(data, &mut data_in_vec, inverse);
    data_in.copy_from_slice(&data_in_vec);
    assert_eq!(orig_len, data_in.len());
}
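
// End-to-end usage sketch (assumes bn256's Fr): a forward FFT followed by
// the unnormalized inverse FFT scales every coefficient by n, so dividing by
// n must recover the input. `omega` is derived from the field's 2-adic root
// of unity, as callers of `FFTData::new` are expected to do.
#[test]
fn fft_round_trip() {
    use ff::PrimeField;
    use halo2curves::bn256::Fr;

    let k = 4u32;
    let n = 1usize << k;
    let omega = Fr::ROOT_OF_UNITY.pow_vartime([1u64 << (Fr::S - k)]);
    let omega_inv = omega.invert().unwrap();
    let data = FFTData::new(n, omega, omega_inv);

    let original: Vec<Fr> = (0..n as u64).map(Fr::from).collect();
    let mut values = original.clone();
    fft(&mut values, omega, k, &data, false);
    fft(&mut values, omega, k, &data, true);

    // Undo the size factor introduced by the unnormalized inverse FFT.
    let n_inv = Fr::from(n as u64).invert().unwrap();
    for v in values.iter_mut() {
        *v *= n_inv;
    }
    assert_eq!(values, original);
}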