1extern crate alloc;
8use alloc::vec::Vec;
9
10use itertools::izip;
11use p3_field::{Field, FieldAlgebra, PackedFieldPow2, PackedValue};
12use p3_util::log2_strict_usize;
13
14use crate::utils::monty_reduce;
15use crate::{FieldParameters, MontyField31, TwoAdicData};
16
17#[inline(always)]
18fn backward_butterfly<T: FieldAlgebra + Copy>(x: T, y: T, roots: T) -> (T, T) {
19 let t = y * roots;
20 (x + t, x - t)
21}
22
23#[inline(always)]
24fn backward_butterfly_interleaved<const HALF_RADIX: usize, T: PackedFieldPow2>(
25 x: T,
26 y: T,
27 roots: T,
28) -> (T, T) {
29 let (x, y) = x.interleave(y, HALF_RADIX);
30 let (x, y) = backward_butterfly(x, y, roots);
31 x.interleave(y, HALF_RADIX)
32}
33
34#[inline]
35fn backward_pass_packed<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
36 let packed_roots = T::pack_slice(roots);
37 let n = input.len();
38 let (xs, ys) = unsafe { input.split_at_mut_unchecked(n / 2) };
39
40 izip!(xs, ys, packed_roots)
41 .for_each(|(x, y, &roots)| (*x, *y) = backward_butterfly(*x, *y, roots));
42}
43
44#[inline]
45fn backward_iterative_layer_1<T: PackedFieldPow2>(input: &mut [T], roots: &[T::Scalar]) {
46 let packed_roots = T::pack_slice(roots);
47 let n = input.len();
48 let (top_half, bottom_half) = unsafe { input.split_at_mut_unchecked(n / 2) };
49 let (xs, ys) = unsafe { top_half.split_at_mut_unchecked(n / 4) };
50 let (zs, ws) = unsafe { bottom_half.split_at_mut_unchecked(n / 4) };
51
52 izip!(xs, ys, zs, ws, packed_roots).for_each(|(x, y, z, w, &root)| {
53 (*x, *y) = backward_butterfly(*x, *y, root);
54 (*z, *w) = backward_butterfly(*z, *w, root);
55 });
56}
57
58#[inline]
59fn backward_iterative_packed<const HALF_RADIX: usize, T: PackedFieldPow2>(
60 input: &mut [T],
61 roots: &[T::Scalar],
62) {
63 let roots = T::from_fn(|i| roots[i % HALF_RADIX]);
66
67 input.chunks_exact_mut(2).for_each(|pair| {
68 let (x, y) = backward_butterfly_interleaved::<HALF_RADIX, _>(pair[0], pair[1], roots);
69 pair[0] = x;
70 pair[1] = y;
71 });
72}
73
74#[inline]
75fn backward_iterative_packed_radix_2<T: PackedFieldPow2>(input: &mut [T]) {
76 input.chunks_exact_mut(2).for_each(|pair| {
77 let x = pair[0];
78 let y = pair[1];
79 let (mut x, y) = x.interleave(y, 1);
80 let t = x - y; x += y;
82 let (x, y) = x.interleave(t, 1);
83 pair[0] = x;
84 pair[1] = y;
85 });
86}
87
88impl<MP: FieldParameters + TwoAdicData> MontyField31<MP> {
89 #[inline]
91 fn backward_iterative_layer(
92 packed_input: &mut [<Self as Field>::Packing],
93 roots: &[Self],
94 m: usize,
95 ) {
96 debug_assert_eq!(roots.len(), m);
97 let packed_roots = <Self as Field>::Packing::pack_slice(roots);
98
99 let packed_m = m / <Self as Field>::Packing::WIDTH;
101 packed_input
102 .chunks_exact_mut(2 * packed_m)
103 .for_each(|layer_chunk| {
104 let (xs, ys) = unsafe { layer_chunk.split_at_mut_unchecked(packed_m) };
105
106 izip!(xs, ys, packed_roots)
107 .for_each(|(x, y, &root)| (*x, *y) = backward_butterfly(*x, *y, root));
108 });
109 }
110
111 #[inline]
112 fn backward_iterative_packed_radix_16(input: &mut [<Self as Field>::Packing]) {
113 backward_iterative_packed_radix_2(input);
120
121 let roots4 = [MP::INV_ROOTS_8.as_ref()[0], MP::INV_ROOTS_8.as_ref()[2]];
123 if <Self as Field>::Packing::WIDTH >= 4 {
124 backward_iterative_packed::<2, _>(input, &roots4);
125 } else {
126 Self::backward_iterative_layer(input, &roots4, 2);
127 }
128
129 if <Self as Field>::Packing::WIDTH >= 8 {
131 backward_iterative_packed::<4, _>(input, MP::INV_ROOTS_8.as_ref());
132 } else {
133 Self::backward_iterative_layer(input, MP::INV_ROOTS_8.as_ref(), 4);
134 }
135
136 if <Self as Field>::Packing::WIDTH >= 16 {
138 backward_iterative_packed::<8, _>(input, MP::INV_ROOTS_16.as_ref());
139 } else {
140 Self::backward_iterative_layer(input, MP::INV_ROOTS_16.as_ref(), 8);
141 }
142 }
143
144 fn backward_iterative(packed_input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
145 assert!(packed_input.len() >= 2);
146 let packing_width = <Self as Field>::Packing::WIDTH;
147 let n = packed_input.len() * packing_width;
148 let lg_n = log2_strict_usize(n);
149
150 const FIRST_LOOP_LAYER: usize = 4;
154
155 const NUM_SPECIALISATIONS: usize = 2;
157
158 assert!(lg_n >= FIRST_LOOP_LAYER + NUM_SPECIALISATIONS);
161
162 Self::backward_iterative_packed_radix_16(packed_input);
163
164 for lg_m in FIRST_LOOP_LAYER..(lg_n - NUM_SPECIALISATIONS) {
165 let s = lg_n - lg_m - 1;
166 let m = 1 << lg_m;
167
168 let roots = &root_table[s];
169 debug_assert_eq!(roots.len(), m);
170
171 Self::backward_iterative_layer(packed_input, roots, m);
172 }
173 backward_iterative_layer_1(packed_input, &root_table[1]); backward_pass_packed(packed_input, &root_table[0]); }
177
178 #[inline]
179 fn backward_pass(input: &mut [Self], roots: &[Self]) {
180 let half_n = input.len() / 2;
181 assert_eq!(roots.len(), half_n);
182
183 let (xs, ys) = unsafe { input.split_at_mut_unchecked(half_n) };
185
186 let s = xs[0] + ys[0];
187 let t = xs[0] - ys[0];
188 xs[0] = s;
189 ys[0] = t;
190
191 izip!(&mut xs[1..], &mut ys[1..], &roots[1..]).for_each(|(x, y, &root)| {
192 (*x, *y) = backward_butterfly(*x, *y, root);
193 });
194 }
195
196 #[inline(always)]
197 fn backward_2(a: &mut [Self]) {
198 assert_eq!(a.len(), 2);
199
200 let s = a[0] + a[1];
201 let t = a[0] - a[1];
202 a[0] = s;
203 a[1] = t;
204 }
205
206 #[inline(always)]
207 fn backward_4(a: &mut [Self]) {
208 assert_eq!(a.len(), 4);
209
210 let a0 = a[0];
212 let a2 = a[1];
213 let a1 = a[2];
214 let a3 = a[3];
215
216 let t1 = MP::PRIME + a1.value - a3.value;
218 let t3 = MontyField31::new_monty(monty_reduce::<MP>(
219 t1 as u64 * MP::INV_ROOTS_8.as_ref()[2].value as u64,
220 ));
221 let t5 = a1 + a3;
222 let t4 = a0 + a2;
223 let t2 = a0 - a2;
224
225 a[0] = t4 + t5;
226 a[1] = t2 + t3;
227 a[2] = t4 - t5;
228 a[3] = t2 - t3;
229 }
230
231 #[inline(always)]
232 fn backward_8(a: &mut [Self]) {
233 assert_eq!(a.len(), 8);
234
235 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
237 Self::backward_4(a0);
238 Self::backward_4(a1);
239
240 Self::backward_pass(a, MP::INV_ROOTS_8.as_ref());
241 }
242
243 #[inline(always)]
244 fn backward_16(a: &mut [Self]) {
245 assert_eq!(a.len(), 16);
246
247 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
249 Self::backward_8(a0);
250 Self::backward_8(a1);
251
252 Self::backward_pass(a, MP::INV_ROOTS_16.as_ref());
253 }
254
255 #[inline(always)]
256 fn backward_32(a: &mut [Self], root_table: &[Vec<Self>]) {
257 assert_eq!(a.len(), 32);
258
259 let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) };
261 Self::backward_16(a0);
262 Self::backward_16(a1);
263
264 Self::backward_pass(a, &root_table[0]);
265 }
266
267 #[inline]
270 fn backward_fft_recur(input: &mut [<Self as Field>::Packing], root_table: &[Vec<Self>]) {
271 const ITERATIVE_FFT_THRESHOLD: usize = 1024;
272
273 let n = input.len() * <Self as Field>::Packing::WIDTH;
274 if n <= ITERATIVE_FFT_THRESHOLD {
275 Self::backward_iterative(input, root_table);
276 } else {
277 assert_eq!(n, 1 << (root_table.len() + 1));
278
279 let (a0, a1) = unsafe { input.split_at_mut_unchecked(input.len() / 2) };
281 Self::backward_fft_recur(a0, &root_table[1..]);
282 Self::backward_fft_recur(a1, &root_table[1..]);
283
284 backward_pass_packed(input, &root_table[0]);
285 }
286 }
287
288 #[inline]
289 pub fn backward_fft(input: &mut [Self], root_table: &[Vec<Self>]) {
290 let n = input.len();
291 if n == 1 {
292 return;
293 }
294
295 assert_eq!(n, 1 << (root_table.len() + 1));
296 match n {
297 32 => Self::backward_32(input, root_table),
298 16 => Self::backward_16(input),
299 8 => Self::backward_8(input),
300 4 => Self::backward_4(input),
301 2 => Self::backward_2(input),
302 _ => {
303 let packed_input = <Self as Field>::Packing::pack_slice_mut(input);
304 Self::backward_fft_recur(packed_input, root_table)
305 }
306 }
307 }
308}