use alloc::vec::Vec;
use itertools::Itertools;
use p3_field::TwoAdicField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::*;
use p3_util::{log2_strict_usize, reverse_slice_index_bits};
use tracing::instrument;
#[instrument(skip_all, level = "debug")]
pub fn fold_even_odd<F: TwoAdicField>(poly: Vec<F>, beta: F) -> Vec<F> {
let m = RowMajorMatrix::new(poly, 2);
let g_inv = F::two_adic_generator(log2_strict_usize(m.height()) + 1).inverse();
let one_half = F::TWO.inverse();
let half_beta = beta * one_half;
let mut powers = g_inv
.shifted_powers(half_beta)
.take(m.height())
.collect_vec();
reverse_slice_index_bits(&mut powers);
m.par_rows()
.zip(powers)
.map(|(mut row, power)| {
let (r0, r1) = row.next_tuple().unwrap();
(one_half + power) * r0 + (one_half - power) * r1
})
.collect()
}
#[cfg(test)]
mod tests {
use itertools::izip;
use p3_baby_bear::BabyBear;
use p3_dft::{Radix2Dit, TwoAdicSubgroupDft};
use rand::{thread_rng, Rng};
use super::*;
#[test]
fn test_fold_even_odd() {
type F = BabyBear;
let mut rng = thread_rng();
let log_n = 10;
let n = 1 << log_n;
let coeffs = (0..n).map(|_| rng.gen::<F>()).collect::<Vec<_>>();
let dft = Radix2Dit::default();
let evals = dft.dft(coeffs.clone());
let even_coeffs = coeffs.iter().cloned().step_by(2).collect_vec();
let even_evals = dft.dft(even_coeffs);
let odd_coeffs = coeffs.iter().cloned().skip(1).step_by(2).collect_vec();
let odd_evals = dft.dft(odd_coeffs);
let beta = rng.gen::<F>();
let expected = izip!(even_evals, odd_evals)
.map(|(even, odd)| even + beta * odd)
.collect::<Vec<_>>();
let mut folded = evals;
reverse_slice_index_bits(&mut folded);
folded = fold_even_odd(folded, beta);
reverse_slice_index_bits(&mut folded);
assert_eq!(expected, folded);
}
}