halo2_axiom/fft/
recursive.rs

1//! This contains the recursive FFT.
2
3use crate::{
4    arithmetic::{self, parallelize, FftGroup},
5    multicore,
6};
7
8pub use ff::Field;
9pub use halo2curves::{CurveAffine, CurveExt};
10
/// FFTStage
///
/// One level of the recursive FFT decomposition.
#[derive(Clone, Debug)]
pub struct FFTStage {
    // Butterfly radix at this level (2 or 4 are the only implemented kernels).
    radix: usize,
    // Length of each sub-transform below this level.
    length: usize,
}

/// FFT stages
///
/// Decomposes `size` into a sequence of stages. The radices in `radixes` are
/// applied first (in order); the remaining factor is filled with radix-2
/// stages.
///
/// # Panics
/// Panics if the size remaining after the requested radices is not a power
/// of two. (The previous implementation busy-waited in
/// `while n % p != 0 { … }` without ever making progress, hanging forever in
/// that case; failing fast is strictly better and identical for valid input.)
fn get_stages(size: usize, radixes: Vec<usize>) -> Vec<FFTStage> {
    let mut stages: Vec<FFTStage> = vec![];

    let mut n = size;

    // Use the specified radices
    for &radix in &radixes {
        n /= radix;
        stages.push(FFTStage { radix, length: n });
    }

    // Fill in the rest of the tree with radix-2 stages.
    while n > 1 {
        assert!(
            n % 2 == 0,
            "FFT size must reduce to a power of two after the requested radices"
        );
        n /= 2;
        stages.push(FFTStage {
            radix: 2,
            length: n / 2,
        });
        // NOTE: `length` is the size of each sub-transform *after* dividing
        // by the radix, so it must track the already-halved `n`.
        let fix = stages.len() - 1;
        stages[fix].length = n;
    }

    stages
}
51
/// FFTData
///
/// Precomputed, cache-friendly twiddle tables for the recursive FFT over a
/// domain of `n` points.
#[derive(Clone, Debug)]
pub struct FFTData<F: arithmetic::Field> {
    // Domain size (number of FFT points).
    n: usize,

    // The radix/length of each recursion level, outermost level first.
    stages: Vec<FFTStage>,

    // Per-stage twiddle factors for the forward transform (powers of omega).
    f_twiddles: Vec<Vec<F>>,
    // Per-stage twiddle factors for the inverse transform (powers of omega_inv).
    inv_twiddles: Vec<Vec<F>>,
    //scratch: Vec<F>,
}
63
64impl<F: arithmetic::Field> Default for FFTData<F> {
65    fn default() -> Self {
66        Self {
67            n: Default::default(),
68            stages: Default::default(),
69            f_twiddles: Default::default(),
70            inv_twiddles: Default::default(),
71        }
72    }
73}
74
impl<F: arithmetic::Field> FFTData<F> {
    /// Create FFT data
    ///
    /// Precomputes forward and inverse twiddle tables for a domain of size
    /// `n` with root of unity `omega` and its inverse `omega_inv`.
    /// NOTE(review): assumes `n` is a power of two — confirm with callers.
    pub fn new(n: usize, omega: F, omega_inv: F) -> Self {
        let stages = get_stages(n, vec![]);
        let mut f_twiddles = vec![];
        let mut inv_twiddles = vec![];
        let mut scratch = vec![F::ZERO; n];

        // Generate stage twiddles
        // First pass (inv == 0) fills `inv_twiddles`, second fills `f_twiddles`.
        for inv in 0..2 {
            let inverse = inv == 0;
            let o = if inverse { omega_inv } else { omega };
            let stage_twiddles = if inverse {
                &mut inv_twiddles
            } else {
                &mut f_twiddles
            };

            let twiddles = &mut scratch;

            // Twiddles
            // Fill `twiddles[i] = o^i` for i in 0..n, in parallel chunks; each
            // chunk seeds its running power from its start index.
            parallelize(twiddles, |twiddles, start| {
                let w_m = o;
                let mut w = o.pow_vartime([start as u64]);
                for value in twiddles.iter_mut() {
                    *value = w;
                    w *= w_m;
                }
            });

            // Re-order twiddles for cache friendliness
            let num_stages = stages.len();
            stage_twiddles.resize(num_stages, vec![]);
            for l in 0..num_stages {
                let radix = stages[l].radix;
                let stage_length = stages[l].length;

                // Each of the `stage_length` butterflies needs `radix - 1`
                // twiddles; one extra slot holds the rotation constant below.
                let num_twiddles = stage_length * (radix - 1);
                stage_twiddles[l].resize(num_twiddles + 1, F::ZERO);

                // Set j
                // o^(3n/4) — the constant the radix-4 butterfly reads from the
                // last slot of the table (its `j` rotation).
                stage_twiddles[l][num_twiddles] = twiddles[(twiddles.len() * 3) / 4];

                let stride = n / (stage_length * radix);
                // `tws[j]` tracks the exponent index for butterfly leg j + 1;
                // leg j + 1 advances by (j + 1) * stride per butterfly.
                let mut tws = vec![0usize; radix - 1];
                for i in 0..stage_length {
                    for j in 0..radix - 1 {
                        stage_twiddles[l][i * (radix - 1) + j] = twiddles[tws[j]];
                        tws[j] += (j + 1) * stride;
                    }
                }
            }
        }

        Self {
            n,
            stages,
            f_twiddles,
            inv_twiddles,
            //scratch,
        }
    }

    /// Return private field `n`
    pub fn get_n(&self) -> usize {
        self.n
    }
}
143
144/// Radix 2 butterfly
145fn butterfly_2<Scalar: Field, G: FftGroup<Scalar>>(
146    out: &mut [G],
147    twiddles: &[Scalar],
148    stage_length: usize,
149) {
150    let mut out_offset = 0;
151    let mut out_offset2 = stage_length;
152
153    let t = out[out_offset2];
154    out[out_offset2] = out[out_offset] - &t;
155    out[out_offset] += &t;
156    out_offset2 += 1;
157    out_offset += 1;
158
159    for twiddle in twiddles[1..stage_length].iter() {
160        let t = out[out_offset2] * twiddle;
161        out[out_offset2] = out[out_offset] - &t;
162        out[out_offset] += &t;
163        out_offset2 += 1;
164        out_offset += 1;
165    }
166}
167
168/// Radix 2 butterfly
169fn butterfly_2_parallel<Scalar: Field, G: FftGroup<Scalar>>(
170    out: &mut [G],
171    twiddles: &[Scalar],
172    _stage_length: usize,
173    num_threads: usize,
174) {
175    let n = out.len();
176    let mut chunk = n / num_threads;
177    if chunk < num_threads {
178        chunk = n;
179    }
180
181    multicore::scope(|scope| {
182        let (part_a, part_b) = out.split_at_mut(n / 2);
183        for (i, (part0, part1)) in part_a
184            .chunks_mut(chunk)
185            .zip(part_b.chunks_mut(chunk))
186            .enumerate()
187        {
188            scope.spawn(move |_| {
189                let offset = i * chunk;
190                for k in 0..part0.len() {
191                    let t = part1[k] * &twiddles[offset + k];
192                    part1[k] = part0[k] - &t;
193                    part0[k] += &t;
194                }
195            });
196        }
197    });
198}
199
200/// Radix 4 butterfly
201fn butterfly_4<Scalar: Field, G: FftGroup<Scalar>>(
202    out: &mut [G],
203    twiddles: &[Scalar],
204    stage_length: usize,
205) {
206    let j = twiddles[twiddles.len() - 1];
207    let mut tw = 0;
208
209    /* Case twiddle == one */
210    {
211        let i0 = 0;
212        let i1 = stage_length;
213        let i2 = stage_length * 2;
214        let i3 = stage_length * 3;
215
216        let z0 = out[i0];
217        let z1 = out[i1];
218        let z2 = out[i2];
219        let z3 = out[i3];
220
221        let t1 = z0 + &z2;
222        let t2 = z1 + &z3;
223        let t3 = z0 - &z2;
224        let t4j = (z1 - &z3) * &j;
225
226        out[i0] = t1 + &t2;
227        out[i1] = t3 - &t4j;
228        out[i2] = t1 - &t2;
229        out[i3] = t3 + &t4j;
230
231        tw += 3;
232    }
233
234    for k in 1..stage_length {
235        let i0 = k;
236        let i1 = k + stage_length;
237        let i2 = k + stage_length * 2;
238        let i3 = k + stage_length * 3;
239
240        let z0 = out[i0];
241        let z1 = out[i1] * &twiddles[tw];
242        let z2 = out[i2] * &twiddles[tw + 1];
243        let z3 = out[i3] * &twiddles[tw + 2];
244
245        let t1 = z0 + &z2;
246        let t2 = z1 + &z3;
247        let t3 = z0 - &z2;
248        let t4j = (z1 - &z3) * &j;
249
250        out[i0] = t1 + &t2;
251        out[i1] = t3 - &t4j;
252        out[i2] = t1 - &t2;
253        out[i3] = t3 + &t4j;
254
255        tw += 3;
256    }
257}
258
/// Radix 4 butterfly
///
/// Parallel variant: splits `out` into four equal quarters (the four
/// butterfly legs) and processes matching chunks of each quarter on
/// separate threads.
fn butterfly_4_parallel<Scalar: Field, G: FftGroup<Scalar>>(
    out: &mut [G],
    twiddles: &[Scalar],
    _stage_length: usize,
    num_threads: usize,
) {
    // Rotation constant stored by `FFTData::new` in the table's last slot
    // (o^(3n/4)); rotates the odd difference of the 4-point kernel.
    let j = twiddles[twiddles.len() - 1];

    let n = out.len();
    // Fall back to a single chunk when the input is too small to split.
    let mut chunk = n / num_threads;
    if chunk < num_threads {
        chunk = n;
    }
    multicore::scope(|scope| {
        //let mut parts: Vec<&mut [F]> = out.chunks_mut(4).collect();
        //out.chunks_mut(4).map(|c| c.chunks_mut(chunk)).fold(predicate)
        let (part_a, part_b) = out.split_at_mut(n / 2);
        let (part_aa, part_ab) = part_a.split_at_mut(n / 4);
        let (part_ba, part_bb) = part_b.split_at_mut(n / 4);
        for (i, (((part0, part1), part2), part3)) in part_aa
            .chunks_mut(chunk)
            .zip(part_ab.chunks_mut(chunk))
            .zip(part_ba.chunks_mut(chunk))
            .zip(part_bb.chunks_mut(chunk))
            .enumerate()
        {
            scope.spawn(move |_| {
                // Each butterfly consumes three twiddles, so chunk `i`
                // starts at twiddle index `3 * i * chunk`.
                let offset = i * chunk;
                let mut tw = offset * 3;
                for k in 0..part1.len() {
                    let z0 = part0[k];
                    let z1 = part1[k] * &twiddles[tw];
                    let z2 = part2[k] * &twiddles[tw + 1];
                    let z3 = part3[k] * &twiddles[tw + 2];

                    // 4-point DFT: pairwise sums/differences, with the odd
                    // difference rotated by `j`.
                    let t1 = z0 + &z2;
                    let t2 = z1 + &z3;
                    let t3 = z0 - &z2;
                    let t4j = (z1 - &z3) * &j;

                    part0[k] = t1 + &t2;
                    part1[k] = t3 - &t4j;
                    part2[k] = t1 - &t2;
                    part3[k] = t3 + &t4j;

                    tw += 3;
                }
            });
        }
    });
}
311
/// Inner recursion
///
/// Evaluates one level of the FFT: gathers the strided input into
/// `data_out`, recurses on each of the `radix` sub-transforms, then applies
/// this level's butterfly using `twiddles[level]`. Input is read at offsets
/// `in_offset + i * stride`; each recursion multiplies the stride by the
/// radix (decimation in time).
#[allow(clippy::too_many_arguments)]
fn recursive_fft_inner<Scalar: Field, G: FftGroup<Scalar>>(
    data_in: &[G],
    data_out: &mut [G],
    twiddles: &Vec<Vec<Scalar>>,
    stages: &Vec<FFTStage>,
    in_offset: usize,
    stride: usize,
    level: usize,
    num_threads: usize,
) {
    let radix = stages[level].radix;
    let stage_length = stages[level].length;

    if num_threads > 1 {
        if stage_length == 1 {
            // Base case: copy the (strided) inputs straight into place.
            for i in 0..radix {
                data_out[i] = data_in[in_offset + i * stride];
            }
        } else {
            // Run up to `radix` sub-transforms concurrently, each on its own
            // slice of `data_out`.
            let num_threads_recursive = if num_threads >= radix {
                radix
            } else {
                num_threads
            };
            parallelize_count(data_out, num_threads_recursive, |data_out, i| {
                // Split the thread budget across the sub-transforms.
                // NOTE(review): `(num_threads + i) / radix` hands slightly
                // uneven shares to later sub-transforms — presumably
                // intentional rounding; confirm.
                let num_threads_in_recursion = if num_threads < radix {
                    1
                } else {
                    (num_threads + i) / radix
                };
                recursive_fft_inner(
                    data_in,
                    data_out,
                    twiddles,
                    stages,
                    in_offset + i * stride,
                    stride * radix,
                    level + 1,
                    num_threads_in_recursion,
                )
            });
        }
        match radix {
            2 => butterfly_2_parallel(data_out, &twiddles[level], stage_length, num_threads),
            4 => butterfly_4_parallel(data_out, &twiddles[level], stage_length, num_threads),
            _ => unimplemented!("radix unsupported"),
        }
    } else {
        if stage_length == 1 {
            // Base case: copy the (strided) inputs straight into place.
            for i in 0..radix {
                data_out[i] = data_in[in_offset + i * stride];
            }
        } else {
            // Sequential recursion over the `radix` sub-transforms.
            for i in 0..radix {
                recursive_fft_inner(
                    data_in,
                    &mut data_out[i * stage_length..(i + 1) * stage_length],
                    twiddles,
                    stages,
                    in_offset + i * stride,
                    stride * radix,
                    level + 1,
                    num_threads,
                );
            }
        }
        match radix {
            2 => butterfly_2(data_out, &twiddles[level], stage_length),
            4 => butterfly_4(data_out, &twiddles[level], stage_length),
            _ => unimplemented!("radix unsupported"),
        }
    }
}
387
388/// Todo: Brechts impl starts here
389fn recursive_fft<Scalar: Field, G: FftGroup<Scalar>>(
390    data: &FFTData<Scalar>,
391    data_in: &mut Vec<G>,
392    inverse: bool,
393) {
394    let num_threads = multicore::current_num_threads();
395    //let start = start_measure(format!("recursive fft {} ({})", data_in.len(), num_threads), false);
396
397    // TODO: reuse scratch buffer between FFTs
398    //let start_mem = start_measure(format!("alloc"), false);
399    let filler = data_in[0];
400    let mut scratch = vec![filler; data_in.len()];
401    //stop_measure(start_mem);
402
403    recursive_fft_inner(
404        data_in,
405        &mut /*data.*/scratch,
406        if inverse {
407            &data.inv_twiddles
408        } else {
409            &data.f_twiddles
410        },
411        &data.stages,
412        0,
413        1,
414        0,
415        num_threads,
416    );
417    //let duration = stop_measure(start);
418
419    //let start = start_measure(format!("copy"), false);
420    // Will simply swap the vector's buffer, no data is actually copied
421    std::mem::swap(data_in, &mut /*data.*/scratch);
422    //stop_measure(start);
423}
424
425/// This simple utility function will parallelize an operation that is to be
426/// performed over a mutable slice.
427fn parallelize_count<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(
428    v: &mut [T],
429    num_threads: usize,
430    f: F,
431) {
432    let n = v.len();
433    let mut chunk = n / num_threads;
434    if chunk < num_threads {
435        chunk = n;
436    }
437
438    multicore::scope(|scope| {
439        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
440            let f = f.clone();
441            scope.spawn(move |_| {
442                f(v, chunk_num);
443            });
444        }
445    });
446}
447
448/// Generic adaptor
449pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
450    data_in: &mut [G],
451    _omega: Scalar,
452    _log_n: u32,
453    data: &FFTData<Scalar>,
454    inverse: bool,
455) {
456    let orig_len = data_in.len();
457    let mut data_in_vec = data_in.to_vec();
458    recursive_fft(data, &mut data_in_vec, inverse);
459    data_in.copy_from_slice(&data_in_vec);
460    assert_eq!(orig_len, data_in.len());
461}