halo2curves_axiom/
fft.rs
1pub use crate::{CurveAffine, CurveExt};
2use ff::Field;
3use group::{GroupOpsOwned, ScalarMulOwned};
4
5pub trait FftGroup<Scalar: Field>:
9 Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
10{
11}
12
13impl<T, Scalar> FftGroup<Scalar> for T
14where
15 Scalar: Field,
16 T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
17{
18}
19
20pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
31 fn bitreverse(mut n: usize, l: usize) -> usize {
32 let mut r = 0;
33 for _ in 0..l {
34 r = (r << 1) | (n & 1);
35 n >>= 1;
36 }
37 r
38 }
39
40 let threads = rayon::current_num_threads();
41 let log_threads = threads.ilog2();
42 let n = a.len();
43 assert_eq!(n, 1 << log_n);
44
45 for k in 0..n {
46 let rk = bitreverse(k, log_n as usize);
47 if k < rk {
48 a.swap(rk, k);
49 }
50 }
51
52 let twiddles: Vec<_> = (0..(n / 2))
54 .scan(Scalar::ONE, |w, _| {
55 let tw = *w;
56 *w *= ω
57 Some(tw)
58 })
59 .collect();
60
61 if log_n <= log_threads {
62 let mut chunk = 2_usize;
63 let mut twiddle_chunk = n / 2;
64 for _ in 0..log_n {
65 a.chunks_mut(chunk).for_each(|coeffs| {
66 let (left, right) = coeffs.split_at_mut(chunk / 2);
67
68 let (a, left) = left.split_at_mut(1);
70 let (b, right) = right.split_at_mut(1);
71 let t = b[0];
72 b[0] = a[0];
73 a[0] += &t;
74 b[0] -= &t;
75
76 left.iter_mut()
77 .zip(right.iter_mut())
78 .enumerate()
79 .for_each(|(i, (a, b))| {
80 let mut t = *b;
81 t *= &twiddles[(i + 1) * twiddle_chunk];
82 *b = *a;
83 *a += &t;
84 *b -= &t;
85 });
86 });
87 chunk *= 2;
88 twiddle_chunk /= 2;
89 }
90 } else {
91 recursive_butterfly_arithmetic(a, n, 1, &twiddles)
92 }
93}
94
95pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
97 a: &mut [G],
98 n: usize,
99 twiddle_chunk: usize,
100 twiddles: &[Scalar],
101) {
102 if n == 2 {
103 let t = a[1];
104 a[1] = a[0];
105 a[0] += &t;
106 a[1] -= &t;
107 } else {
108 let (left, right) = a.split_at_mut(n / 2);
109 rayon::join(
110 || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
111 || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
112 );
113
114 let (a, left) = left.split_at_mut(1);
116 let (b, right) = right.split_at_mut(1);
117 let t = b[0];
118 b[0] = a[0];
119 a[0] += &t;
120 b[0] -= &t;
121
122 left.iter_mut()
123 .zip(right.iter_mut())
124 .enumerate()
125 .for_each(|(i, (a, b))| {
126 let mut t = *b;
127 t *= &twiddles[(i + 1) * twiddle_chunk];
128 *b = *a;
129 *a += &t;
130 *b -= &t;
131 });
132 }
133}