halo2curves/
fft.rs
1use ff::Field;
2use group::{GroupOpsOwned, ScalarMulOwned};
3
4pub use crate::{CurveAffine, CurveExt};
5
6pub trait FftGroup<Scalar: Field>:
10 Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
11{
12}
13
14impl<T, Scalar> FftGroup<Scalar> for T
15where
16 Scalar: Field,
17 T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
18{
19}
20
21pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
32 fn bitreverse(mut n: usize, l: usize) -> usize {
33 let mut r = 0;
34 for _ in 0..l {
35 r = (r << 1) | (n & 1);
36 n >>= 1;
37 }
38 r
39 }
40
41 let threads = rayon::current_num_threads();
42 let log_threads = threads.ilog2();
43 let n = a.len();
44 assert_eq!(n, 1 << log_n);
45
46 for k in 0..n {
47 let rk = bitreverse(k, log_n as usize);
48 if k < rk {
49 a.swap(rk, k);
50 }
51 }
52
53 let twiddles: Vec<_> = (0..(n / 2))
55 .scan(Scalar::ONE, |w, _| {
56 let tw = *w;
57 *w *= ω
58 Some(tw)
59 })
60 .collect();
61
62 if log_n <= log_threads {
63 let mut chunk = 2_usize;
64 let mut twiddle_chunk = n / 2;
65 for _ in 0..log_n {
66 a.chunks_mut(chunk).for_each(|coeffs| {
67 let (left, right) = coeffs.split_at_mut(chunk / 2);
68
69 let (a, left) = left.split_at_mut(1);
71 let (b, right) = right.split_at_mut(1);
72 let t = b[0];
73 b[0] = a[0];
74 a[0] += &t;
75 b[0] -= &t;
76
77 left.iter_mut()
78 .zip(right.iter_mut())
79 .enumerate()
80 .for_each(|(i, (a, b))| {
81 let mut t = *b;
82 t *= &twiddles[(i + 1) * twiddle_chunk];
83 *b = *a;
84 *a += &t;
85 *b -= &t;
86 });
87 });
88 chunk *= 2;
89 twiddle_chunk /= 2;
90 }
91 } else {
92 recursive_butterfly_arithmetic(a, n, 1, &twiddles)
93 }
94}
95
96pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
98 a: &mut [G],
99 n: usize,
100 twiddle_chunk: usize,
101 twiddles: &[Scalar],
102) {
103 if n == 2 {
104 let t = a[1];
105 a[1] = a[0];
106 a[0] += &t;
107 a[1] -= &t;
108 } else {
109 let (left, right) = a.split_at_mut(n / 2);
110 rayon::join(
111 || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
112 || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
113 );
114
115 let (a, left) = left.split_at_mut(1);
117 let (b, right) = right.split_at_mut(1);
118 let t = b[0];
119 b[0] = a[0];
120 a[0] += &t;
121 b[0] -= &t;
122
123 left.iter_mut()
124 .zip(right.iter_mut())
125 .enumerate()
126 .for_each(|(i, (a, b))| {
127 let mut t = *b;
128 t *= &twiddles[(i + 1) * twiddle_chunk];
129 *b = *a;
130 *a += &t;
131 *b -= &t;
132 });
133 }
134}