1extern crate alloc;
8
9use alloc::vec::Vec;
10
11use itertools::izip;
12use p3_field::{Field, FieldAlgebra, PackedFieldPow2, PackedValue, TwoAdicField};
13use p3_util::log2_strict_usize;
14
15use crate::utils::monty_reduce;
16use crate::{FieldParameters, MontyField31, TwoAdicData};
17
18impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
19 pub fn roots_of_unity_table(n: usize) -> Vec<Vec<Self>> {
27 let lg_n = log2_strict_usize(n);
28 let gen = Self::two_adic_generator(lg_n);
29 let half_n = 1 << (lg_n - 1);
30 let nth_roots: Vec<_> = gen.powers().take(half_n).collect();
32
33 (0..(lg_n - 1))
34 .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
35 .collect()
36 }
37}
38
39#[inline(always)]
40fn forward_butterfly<T: FieldAlgebra + Copy>(x: T, y: T, roots: T) -> (T, T) {
41 let t = x - y;
42 (x + y, t * roots)
43}
44
45#[inline(always)]
46fn forward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
47 x: T,
48 y: T,
49 roots: T,
50) -> (T, T) {
51 let (x, y) = x.interleave(y, HALF_RADIX);
52 let (x, y) = forward_butterfly(x, y, roots);
53 x.interleave(y, HALF_RADIX)
54}
55
56#[inline]
57fn forward_pass_packed<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
58 let packed_roots = T::pack_slice(roots);
59 let n = input.len();
60 let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
61
62 izip!(xs, ys, packed_roots)
63 .for_each(|(x, y, &roots)| (*x, *y) = forward_butterfly(*x, *y, roots));
64}
65
66#[inline]
67fn forward_iterative_layer_1<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
68 let packed_roots = T::pack_slice(roots);
69 let n = input.len();
70 let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
71 let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
72 let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
73
74 izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
75 (*x, *y) = forward_butterfly(*x, *y, root);
76 (*z, *w) = forward_butterfly(*z, *w, root);
77 });
78}
79
80#[inline]
81fn forward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
82 input: &mut [T],
83 roots: &[T::Scalar],
84) {
85 let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
88
89 input.chunks_exact_mut(2).for_each(|pair| {
90 let (x, y) = forward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
91 pair[0] = x;
92 pair[1] = y;
93 });
94}
95
96#[inline]
97fn forward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
98 input.chunks_exact_mut(2).for_each(|pair| {
99 let x = pair[0];
100 let y = pair[1];
101 let (mut x, y) = x.interleave(y, 1);
102 let t = x - y; x += y;
104 let (x, y) = x.interleave(t, 1);
105 pair[0] = x;
106 pair[1] = y;
107 });
108}
109
110impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
111 #[inline]
112 fn forward_iterative_layer(
113 packed_input: &mut [<Self as Field>::Packing],
114 roots: &[Self],
115 m: usize,
116 ) {
117 debug_assert_eq!(roots.len(), m);
118 let packed_roots = <Self as Field>::Packing::pack_slice(roots);
119
120 let packed_m = m / <Self as Field>::Packing::WIDTH;
122 packed_input
123 .chunks_exact_mut(2 * packed_m)
124 .for_each(|layer_chunk| {
125 let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };
126
127 izip!(xs, ys, packed_roots)
128 .for_each(|(x, y, &root)| (*x, *y) = forward_butterfly(*x, *y, root));
129 });
130 }
131
132 #[inline]
133 fn forward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
134 if <Self as Field>::Packing::WIDTH >= 16 {
141 forward_iterative_packed::<8, _>(input, MP::ROOTS_16.as_ref());
142 } else {
143 Self::forward_iterative_layer(input, MP::ROOTS_16.as_ref(), 8);
144 }
145
146 if <Self as Field>::Packing::WIDTH >= 8 {
148 forward_iterative_packed::<4, _>(input, MP::ROOTS_8.as_ref());
149 } else {
150 Self::forward_iterative_layer(input, MP::ROOTS_8.as_ref(), 4);
151 }
152
153 let roots4 = [MP::ROOTS_8.as_ref()[0], MP::ROOTS_8.as_ref()[2]];
155 if <Self as Field>::Packing::WIDTH >= 4 {
156 forward_iterative_packed::<2, _>(input, &roots4);
157 } else {
158 Self::forward_iterative_layer(input, &roots4, 2);
159 }
160
161 forward_iterative_packed_radix_2(input);
163 }
164
165 #[inline]
167 fn forward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
168 assert!(packed_input.len() >= 2);
169 let packing_width = <Self as Field>::Packing::WIDTH;
170 let n = packed_input.len() * packing_width;
171 let lg_n = log2_strict_usize(n);
172
173 const LAST_LOOP_LAYER: usize = 4;
177
178 const NUM_SPECIALISATIONS: usize = 2;
180
181 assert!(lg_n >= LAST_LOOP_LAYER + NUM_SPECIALISATIONS);
184
185 forward_pass_packed(packed_input, &root_table[0]); forward_iterative_layer_1(packed_input, &root_table[1]); for lg_m in (LAST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS)).rev() {
191 let s = lg_n - lg_m - 1;
192 let m = 1 << lg_m;
193
194 let roots = &root_table[s];
195 debug_assert_eq!(roots.len(), m);
196
197 Self::forward_iterative_layer(packed_input, roots, m);
198 }
199
200 Self::forward_iterative_packed_radix_16(packed_input);
202 }
203
204 #[inline(always)]
205 fn forward_butterfly(x: Self, y: Self, w: Self) -> (Self, Self) {
206 let t = MP::PRIME + x.value - y.value;
207 (
208 x + y,
209 Self::new_monty(monty_reduce::<MP>(t as u64 * w.value as u64)),
210 )
211 }
212
213 #[inline]
214 fn forward_pass(input: &mut [Self], roots: &[Self]) {
215 let half_n = input.len() / 2;
216 assert_eq!(roots.len(), half_n);
217
218 let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };
220
221 let s = xs[0] + ys[0];
222 let t = xs[0] - ys[0];
223 xs[0] = s;
224 ys[0] = t;
225
226 izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
227 (*x, *y) = Self::forward_butterfly(*x, *y, root);
228 });
229 }
230
231 #[inline(always)]
232 fn forward_2(a: &mut [Self]) {
233 assert_eq!(a.len(), 2);
234
235 let s = a[0] + a[1];
236 let t = a[0] - a[1];
237 a[0] = s;
238 a[1] = t;
239 }
240
241 #[inline(always)]
242 fn forward_4(a: &mut [Self]) {
243 assert_eq!(a.len(), 4);
244
245 let t1 = MP::PRIME + a[1].value - a[3].value;
247 let t3 = MontyField31::new_monty(monty_reduce::<MP>(
248 t1 as u64 * MP::ROOTS_8.as_ref()[2].value as u64,
249 ));
250 let t5 = a[1] + a[3];
251 let t4 = a[0] + a[2];
252 let t2 = a[0] - a[2];
253
254 a[0] = t4 + t5;
256 a[1] = t4 - t5;
257 a[2] = t2 + t3;
258 a[3] = t2 - t3;
259 }
260
261 #[inline(always)]
262 fn forward_8(a: &mut [Self]) {
263 assert_eq!(a.len(), 8);
264
265 Self::forward_pass(a, MP::ROOTS_8.as_ref());
266
267 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
269 Self::forward_4(a0);
270 Self::forward_4(a1);
271 }
272
273 #[inline(always)]
274 fn forward_16(a: &mut [Self]) {
275 assert_eq!(a.len(), 16);
276
277 Self::forward_pass(a, MP::ROOTS_16.as_ref());
278
279 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
281 Self::forward_8(a0);
282 Self::forward_8(a1);
283 }
284
285 #[inline(always)]
286 fn forward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
287 assert_eq!(a.len(), 32);
288
289 Self::forward_pass(a, &root_table[0]);
290
291 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
293 Self::forward_16(a0);
294 Self::forward_16(a1);
295 }
296
297 #[inline]
299 fn forward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
300 const ITERATIVE_FFT_THRESHOLD: usize = 1024;
301
302 let n = input.len() * <Self as Field>::Packing::WIDTH;
303 if n <= ITERATIVE_FFT_THRESHOLD {
304 Self::forward_iterative(input, root_table);
305 } else {
306 assert_eq!(n, 1 << (root_table.len() + 1));
307 forward_pass_packed(input, &root_table[0]);
308
309 let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };
311
312 Self::forward_fft_recur(a0, &root_table[1..]);
313 Self::forward_fft_recur(a1, &root_table[1..]);
314 }
315 }
316
317 #[inline]
318 pub fn forward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
319 let n = input.len();
320 if n == 1 {
321 return;
322 }
323 assert_eq!(n, 1 << (root_table.len() + 1));
324 match n {
325 32 => Self::forward_32(input, root_table),
326 16 => Self::forward_16(input),
327 8 => Self::forward_8(input),
328 4 => Self::forward_4(input),
329 2 => Self::forward_2(input),
330 _ => {
331 let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
332 Self::forward_fft_recur(packed_input, root_table)
333 }
334 }
335 }
336}