p3_fri/fold_even_odd.rs

1use alloc::vec::Vec;
2
3use itertools::Itertools;
4use p3_field::TwoAdicField;
5use p3_matrix::dense::RowMajorMatrix;
6use p3_matrix::Matrix;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{log2_strict_usize, reverse_slice_index_bits};
9use tracing::instrument;
10
11/// Fold a polynomial
12/// ```ignore
13/// p(x) = p_even(x^2) + x p_odd(x^2)
14/// ```
15/// into
16/// ```ignore
17/// p_even(x) + beta p_odd(x)
18/// ```
19/// Expects input to be bit-reversed evaluations.
20#[instrument(skip_all, level = "debug")]
21pub fn fold_even_odd<F: TwoAdicField>(poly: Vec<F>, beta: F) -> Vec<F> {
22    // We use the fact that
23    //     p_e(x^2) = (p(x) + p(-x)) / 2
24    //     p_o(x^2) = (p(x) - p(-x)) / (2 x)
25    // that is,
26    //     p_e(g^(2i)) = (p(g^i) + p(g^(n/2 + i))) / 2
27    //     p_o(g^(2i)) = (p(g^i) - p(g^(n/2 + i))) / (2 g^i)
28    // so
29    //     result(g^(2i)) = p_e(g^(2i)) + beta p_o(g^(2i))
30    //                    = (1/2 + beta/2 g_inv^i) p(g^i)
31    //                    + (1/2 - beta/2 g_inv^i) p(g^(n/2 + i))
32    let m = RowMajorMatrix::new(poly, 2);
33    let g_inv = F::two_adic_generator(log2_strict_usize(m.height()) + 1).inverse();
34    let one_half = F::TWO.inverse();
35    let half_beta = beta * one_half;
36
37    // TODO: vectorize this (after we have packed extension fields)
38
39    // beta/2 times successive powers of g_inv
40    let mut powers = g_inv
41        .shifted_powers(half_beta)
42        .take(m.height())
43        .collect_vec();
44    reverse_slice_index_bits(&mut powers);
45
46    m.par_rows()
47        .zip(powers)
48        .map(|(mut row, power)| {
49            let (r0, r1) = row.next_tuple().unwrap();
50            (one_half + power) * r0 + (one_half - power) * r1
51        })
52        .collect()
53}
54
#[cfg(test)]
mod tests {
    use itertools::izip;
    use p3_baby_bear::BabyBear;
    use p3_dft::{Radix2Dit, TwoAdicSubgroupDft};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Checks `fold_even_odd` against a direct computation. The expected
    /// result is the DFT of the even-index coefficients plus `beta` times
    /// the DFT of the odd-index coefficients.
    ///
    /// A fixed seed keeps any failure reproducible. Every size is exercised,
    /// from the smallest foldable input (n = 2, a single pair) up to n = 1024.
    #[test]
    fn test_fold_even_odd() {
        type F = BabyBear;

        let mut rng = StdRng::seed_from_u64(0);
        let dft = Radix2Dit::default();

        for log_n in 1..=10 {
            let n = 1 << log_n;
            let coeffs = (0..n).map(|_| rng.gen::<F>()).collect::<Vec<_>>();

            let evals = dft.dft(coeffs.clone());

            let even_coeffs = coeffs.iter().copied().step_by(2).collect_vec();
            let even_evals = dft.dft(even_coeffs);

            let odd_coeffs = coeffs.iter().copied().skip(1).step_by(2).collect_vec();
            let odd_evals = dft.dft(odd_coeffs);

            let beta = rng.gen::<F>();
            let expected = izip!(even_evals, odd_evals)
                .map(|(even, odd)| even + beta * odd)
                .collect::<Vec<_>>();

            // fold_even_odd takes and returns in bitrev order.
            let mut folded = evals;
            reverse_slice_index_bits(&mut folded);
            folded = fold_even_odd(folded, beta);
            reverse_slice_index_bits(&mut folded);

            assert_eq!(expected, folded, "mismatch for log_n = {log_n}");
        }
    }
}
96}