halo2_axiom/fft/
parallel.rs

//! Parallel FFT implementation: a serial radix-2 transform plus a
//! multi-threaded split-radix variant backed by a sparse twiddle lookup table.
3
4use crate::arithmetic::{self, log2_floor, FftGroup};
5
6use crate::multicore;
7pub use ff::Field;
8pub use halo2curves::{CurveAffine, CurveExt};
9
10use super::recursive::FFTData;
11
/// Log2 size of the low-level (dense) part of the sparse twiddle lookup
/// table; twiddle exponents are split into `SPARSE_TWIDDLE_DEGREE` low bits
/// and the remaining high bits, each looked up separately.
pub const SPARSE_TWIDDLE_DEGREE: u32 = 10;
14
15/// Dispatcher
16fn best_fft_opt<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
17    let threads = multicore::current_num_threads();
18    let log_split = log2_floor(threads) as usize;
19    let n = a.len();
20    let sub_n = n >> log_split;
21    let split_m = 1 << log_split;
22
23    if sub_n >= split_m {
24        parallel_fft(a, omega, log_n);
25    } else {
26        serial_fft(a, omega, log_n);
27    }
28}
29
/// In-place radix-2 decimation-in-time FFT on a single thread.
///
/// `a.len()` must equal `2^log_n`; `omega` is assumed to be a `2^log_n`-th
/// root of unity (TODO confirm against callers). The input is first permuted
/// into bit-reversed order, then `log_n` butterfly passes of doubling
/// half-block size `m` are applied.
fn serial_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);

    // Bit-reversal permutation; swap each pair exactly once (only k < rk).
    for k in 0..n as usize {
        let rk = arithmetic::bitreverse(k, log_n as usize);
        if k < rk {
            a.swap(rk, k);
        }
    }

    let mut m = 1;
    for _ in 0..log_n {
        // Twiddle step for this pass: omega^(n / (2m)).
        let w_m: Scalar = omega.pow_vartime([u64::from(n / (2 * m)), 0, 0, 0]);

        let mut k = 0;
        while k < n {
            let mut w = Scalar::ONE;
            for j in 0..m {
                // Butterfly: (x, y) -> (x + w*y, x - w*y).
                let mut t = a[(k + j + m) as usize];
                t *= &w;
                a[(k + j + m) as usize] = a[(k + j) as usize];
                a[(k + j + m) as usize] -= &t;
                a[(k + j) as usize] += &t;
                w *= &w_m;
            }

            k += 2 * m;
        }

        m *= 2;
    }
}
63
/// Serial radix-2 FFT over an `a` that the caller has already placed in
/// bit-reversed order (see `split_radix_fft`), reading twiddle factors from
/// the sparse two-level lookup table instead of computing powers of omega.
///
/// `twiddle_scale` converts a local twiddle exponent into an exponent of the
/// global omega stored in `twiddle_lut`.
fn serial_split_fft<Scalar: Field, G: FftGroup<Scalar>>(
    a: &mut [G],
    twiddle_lut: &[Scalar],
    twiddle_scale: usize,
    log_n: u32,
) {
    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);

    let mut m = 1;
    for _ in 0..log_n {
        // Global exponent of this pass's twiddle step: scale * n/2, n/4, n/8, ...
        let omega_idx = twiddle_scale * n as usize / (2 * m as usize); // 1/2, 1/4, 1/8, ...
        // Split the exponent into the LUT's low/high parts and combine the
        // two stored factors (the high factor is implicitly 1 when 0).
        let low_idx = omega_idx % (1 << SPARSE_TWIDDLE_DEGREE);
        let high_idx = omega_idx >> SPARSE_TWIDDLE_DEGREE;
        let mut w_m = twiddle_lut[low_idx];
        if high_idx > 0 {
            w_m *= twiddle_lut[(1 << SPARSE_TWIDDLE_DEGREE) + high_idx];
        }

        let mut k = 0;
        while k < n {
            let mut w = Scalar::ONE;
            for j in 0..m {
                // Butterfly: (x, y) -> (x + w*y, x - w*y).
                let mut t = a[(k + j + m) as usize];
                t *= &w;
                a[(k + j + m) as usize] = a[(k + j) as usize];
                a[(k + j + m) as usize] -= &t;
                a[(k + j) as usize] += &t;
                w *= &w_m;
            }

            k += 2 * m;
        }

        m *= 2;
    }
}
101
/// One `split_m`-point FFT over the `sub_fft_offset`-th interleaved
/// sub-sequence of `a` (elements `a[i * sub_n + sub_fft_offset]`), writing
/// the twiddle-adjusted result into `tmp[..split_m]`.
fn split_radix_fft<Scalar: Field, G: FftGroup<Scalar>>(
    tmp: &mut [G],
    a: &[G],
    twiddle_lut: &[Scalar],
    n: usize,
    sub_fft_offset: usize,
    log_split: usize,
) {
    let split_m = 1 << log_split;
    let sub_n = n >> log_split;

    // We use an out-of-place bit-reversal here; split_m <= num_threads, so
    // the buffer space is small and it is good for data locality.
    // `tmp[0]` serves only as a filler value so `vec!` can initialize
    // (G carries no Default).
    let tmp_filler_val = tmp[0];
    let mut t1 = vec![tmp_filler_val; split_m];
    for i in 0..split_m {
        t1[arithmetic::bitreverse(i, log_split)] = a[i * sub_n + sub_fft_offset];
    }
    // The LUT holds powers of the global omega; scaling exponents by `sub_n`
    // selects the `split_m`-th roots of unity needed for this small FFT.
    serial_split_fft(&mut t1, twiddle_lut, sub_n, log_split as u32);

    // Look up omega^sub_fft_offset from the two-level sparse LUT (low and
    // high exponent bits are stored in separate table sections).
    let sparse_degree = SPARSE_TWIDDLE_DEGREE;
    let omega_idx = sub_fft_offset;
    let low_idx = omega_idx % (1 << sparse_degree);
    let high_idx = omega_idx >> sparse_degree;
    let mut omega = twiddle_lut[low_idx];
    if high_idx > 0 {
        omega *= twiddle_lut[(1 << sparse_degree) + high_idx];
    }
    // Apply the per-element twiddle omega^(sub_fft_offset * i) while copying
    // the result out to `tmp`.
    let mut w_m = Scalar::ONE;
    for i in 0..split_m {
        t1[i] *= &w_m;
        tmp[i] = t1[i];
        w_m *= omega;
    }
}
137
/// Precalculate twiddle factors.
///
/// When `sparse_degree > log_n`, returns the dense table
/// `[omega^0, omega^1, ..., omega^(2^log_n - 1)]`. Otherwise returns a
/// sparse two-level table: the first `2^sparse_degree` entries hold
/// `omega^i` (the low part of an exponent) and the remaining entries hold
/// `omega^(j << sparse_degree)` (the high part), so that
/// `omega^k = lut[k % 2^sd] * lut[2^sd + (k >> sd)]`.
/// Setting `with_last_level = false` halves the high-level table.
fn generate_twiddle_lookup_table<F: Field>(
    omega: F,
    log_n: u32,
    sparse_degree: u32,
    with_last_level: bool,
) -> Vec<F> {
    let without_last_level = !with_last_level;
    let is_lut_len_large = sparse_degree > log_n;

    // dense: all 2^log_n powers fit below the sparse threshold, so store
    // every power of omega directly.
    if is_lut_len_large {
        let mut twiddle_lut = vec![F::ZERO; (1 << log_n) as usize];
        parallelize(&mut twiddle_lut, |twiddle_lut, start| {
            // Each chunk seeds with omega^start and then steps by omega.
            let mut w_n = omega.pow_vartime([start as u64, 0, 0, 0]);
            for twiddle_lut in twiddle_lut.iter_mut() {
                *twiddle_lut = w_n;
                w_n *= omega;
            }
        });
        return twiddle_lut;
    }

    // sparse: low-level table of omega^i for i < 2^sparse_degree ...
    let low_degree_lut_len = 1 << sparse_degree;
    let high_degree_lut_len = 1 << (log_n - sparse_degree - without_last_level as u32);
    let mut twiddle_lut = vec![F::ZERO; low_degree_lut_len + high_degree_lut_len];
    parallelize(
        &mut twiddle_lut[..low_degree_lut_len],
        |twiddle_lut, start| {
            let mut w_n = omega.pow_vartime([start as u64, 0, 0, 0]);
            for twiddle_lut in twiddle_lut.iter_mut() {
                *twiddle_lut = w_n;
                w_n *= omega;
            }
        },
    );
    // ... followed by the high-level table of omega^(j * 2^sparse_degree).
    let high_degree_omega = omega.pow_vartime([(1 << sparse_degree) as u64, 0, 0, 0]);
    parallelize(
        &mut twiddle_lut[low_degree_lut_len..],
        |twiddle_lut, start| {
            let mut w_n = high_degree_omega.pow_vartime([start as u64, 0, 0, 0]);
            for twiddle_lut in twiddle_lut.iter_mut() {
                *twiddle_lut = w_n;
                w_n *= high_degree_omega;
            }
        },
    );
    twiddle_lut
}
188
/// The parallel implementation.
///
/// Four phases: (1) `split_radix_fft` transforms the `split_m` interleaved
/// sub-sequences across the thread pool, (2) the results are shuffled into
/// contiguous sub-FFT layout, (3) each contiguous `sub_n`-chunk gets a
/// serial FFT with `omega^split_m`, (4) the output is un-shuffled back into
/// natural order.
fn parallel_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
    let n = a.len();
    assert_eq!(n, 1 << log_n);

    let log_split = log2_floor(multicore::current_num_threads()) as usize;
    let split_m = 1 << log_split;
    let sub_n = n >> log_split;
    let twiddle_lut = generate_twiddle_lookup_table(omega, log_n, SPARSE_TWIDDLE_DEGREE, true);

    // split fft: each worker owns one sub_n-sized slice of `tmp` and runs
    // `split_m`-point FFTs over the interleaved input positions.
    // `a[0]` is only a filler value so `vec!` can initialize (G: no Default).
    let tmp_filler_val = a[0];
    let mut tmp = vec![tmp_filler_val; n];
    multicore::scope(|scope| {
        let a = &*a;
        let twiddle_lut = &*twiddle_lut;
        for (chunk_idx, tmp) in tmp.chunks_mut(sub_n).enumerate() {
            scope.spawn(move |_| {
                let split_fft_offset = (chunk_idx * sub_n) >> log_split;
                for (i, tmp) in tmp.chunks_mut(split_m).enumerate() {
                    let split_fft_offset = split_fft_offset + i;
                    split_radix_fft(tmp, a, twiddle_lut, n, split_fft_offset, log_split);
                }
            });
        }
    });

    // shuffle: transpose `tmp`'s (offset, split-index) layout into `split_m`
    // contiguous sub-FFT inputs in `a`.
    parallelize(a, |a, start| {
        for (idx, a) in a.iter_mut().enumerate() {
            let idx = start + idx;
            let i = idx / sub_n;
            let j = idx % sub_n;
            *a = tmp[j * split_m + i];
        }
    });

    // sub fft: independent serial FFTs of size sub_n using omega^split_m.
    let new_omega = omega.pow_vartime([split_m as u64, 0, 0, 0]);
    multicore::scope(|scope| {
        for a in a.chunks_mut(sub_n) {
            scope.spawn(move |_| {
                serial_fft(a, new_omega, log_n - log_split as u32);
            });
        }
    });

    // copy & unshuffle: stage the sub-FFT outputs in `tmp`, then interleave
    // them back into natural order in `a`.
    let mask = (1 << log_split) - 1;
    parallelize(&mut tmp, |tmp, start| {
        for (idx, tmp) in tmp.iter_mut().enumerate() {
            let idx = start + idx;
            *tmp = a[idx];
        }
    });
    parallelize(a, |a, start| {
        for (idx, a) in a.iter_mut().enumerate() {
            let idx = start + idx;
            *a = tmp[sub_n * (idx & mask) + (idx >> log_split)];
        }
    });
}
251
252/// This simple utility function will parallelize an operation that is to be
253/// performed over a mutable slice.
254fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
255    let n = v.len();
256    let num_threads = multicore::current_num_threads();
257    let mut chunk = n / num_threads;
258    if chunk < num_threads {
259        chunk = n;
260    }
261
262    multicore::scope(|scope| {
263        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
264            let f = f.clone();
265            scope.spawn(move |_| {
266                let start = chunk_num * chunk;
267                f(v, start);
268            });
269        }
270    });
271}
272
273/// Generic adaptor
274pub fn fft<Scalar: Field, G: FftGroup<Scalar>>(
275    data_in: &mut [G],
276    omega: Scalar,
277    log_n: u32,
278    _data: &FFTData<Scalar>,
279    _inverse: bool,
280) {
281    best_fft_opt(data_in, omega, log_n)
282}