// p3_mds/integrated_coset_mds.rs
use alloc::vec::Vec;

use p3_field::{AbstractField, Powers, TwoAdicField};
use p3_symmetric::Permutation;
use p3_util::{log2_strict_usize, reverse_slice_index_bits};

use crate::butterflies::{dif_butterfly, dit_butterfly, twiddle_free_butterfly};
use crate::MdsPermutation;

/// Like `CosetMds`, with a few differences:
/// - (Bit reversed, a la Bowers) DIF + DIT rather than DIT + DIF
/// - We skip bit reversals of the inputs and outputs
/// - We don't weight by `1/N`, since this doesn't affect the MDS property
/// - We integrate the coset shifts into the twiddle factors of the second pass
///   (`fft_twiddles`, which is consumed by the DIT layers)
///
/// Both twiddle tables are built once by the `Default` impl and then only read by
/// `permute` / `permute_mut`.
#[derive(Clone, Debug)]
pub struct IntegratedCosetMds<F, const N: usize> {
    /// Twiddles for the first pass (`bowers_g_layer`): the first `N / 2` powers of the
    /// inverse of `F::two_adic_generator(log2(N))`, reordered by
    /// `reverse_slice_index_bits`. The same table is handed to every layer of that pass.
    ifft_twiddles: Vec<F>,
    /// Twiddles for the second pass (`bowers_g_t_layer`), one table per layer.
    /// Table `layer` holds `N >> (layer + 1)` values taken from a `Powers` iterator with
    /// base `root.exp_power_of_2(layer)` and starting value
    /// `coset_shift.exp_power_of_2(layer)`, reordered by `reverse_slice_index_bits`.
    fft_twiddles: Vec<Vec<F>>,
}

impl<F: TwoAdicField, const N: usize> Default for IntegratedCosetMds<F, N> {
    /// Precomputes the twiddle tables used by `permute_mut` for a width-`N` state.
    ///
    /// `N` is passed through `log2_strict_usize`, so it is presumably required to be a
    /// power of two (NOTE(review): confirm the failure mode in `p3_util`).
    fn default() -> Self {
        let log_n = log2_strict_usize(N);
        let omega = F::two_adic_generator(log_n);
        let shift = F::GENERATOR;

        // First-pass table: N/2 successive powers of omega^-1, then bit-reversed so the
        // layers can index it by block number.
        let mut ifft_twiddles = omega.inverse().powers().take(N / 2).collect::<Vec<F>>();
        reverse_slice_index_bits(&mut ifft_twiddles);

        // Second-pass tables: one per layer. Both the base and the starting value of the
        // power sequence are raised to 2^layer; the starting value carries the coset shift.
        let mut fft_twiddles: Vec<Vec<F>> = Vec::with_capacity(log_n);
        for layer in 0..log_n {
            let sequence = Powers {
                base: omega.exp_power_of_2(layer),
                current: shift.exp_power_of_2(layer),
            };
            let mut layer_twiddles: Vec<F> = sequence.take(N >> (layer + 1)).collect();
            reverse_slice_index_bits(&mut layer_twiddles);
            fft_twiddles.push(layer_twiddles);
        }

        Self {
            ifft_twiddles,
            fft_twiddles,
        }
    }
}

impl<AF: AbstractField, const N: usize> Permutation<[AF; N]> for IntegratedCosetMds<AF::F, N> {
    /// By-value variant: runs `permute_mut` on the argument and hands it back.
    fn permute(&self, mut input: [AF; N]) -> [AF; N] {
        self.permute_mut(&mut input);
        input
    }

    /// Applies the transform to `values` in place.
    ///
    /// Two sweeps over the layers are performed: an ascending sweep of
    /// `bowers_g_layer` using the shared `ifft_twiddles` table, followed by a
    /// descending sweep of `bowers_g_t_layer` using the per-layer `fft_twiddles`.
    fn permute_mut(&self, values: &mut [AF; N]) {
        let num_layers = log2_strict_usize(N);

        // Bit-reversed DIF, aka Bowers G: layers 0, 1, ..., num_layers - 1.
        let inverse_twiddles: &[AF::F] = &self.ifft_twiddles;
        for log_half_block in 0..num_layers {
            bowers_g_layer(values, log_half_block, inverse_twiddles);
        }

        // Bit-reversed DIT, aka Bowers G^T: layers num_layers - 1, ..., 1, 0, each with
        // its own twiddle table.
        for (log_half_block, layer_twiddles) in
            self.fft_twiddles[..num_layers].iter().enumerate().rev()
        {
            bowers_g_t_layer(values, log_half_block, layer_twiddles);
        }
    }
}

// Opts `IntegratedCosetMds` into the crate's `MdsPermutation` trait. The body is empty:
// this impl supplies no items of its own.
impl<AF: AbstractField, const N: usize> MdsPermutation<AF, N> for IntegratedCosetMds<AF::F, N> {}

/// Runs one layer of the first (DIF) pass over `values`, in place.
///
/// The array is viewed as `N >> (log_half_block_size + 1)` consecutive blocks of
/// `2 << log_half_block_size` elements. Inside each block, element `i` of the first
/// half is paired with element `i` of the second half. Block `b` uses `twiddles[b]`;
/// block 0 is handled with the twiddle-free butterfly instead of reading `twiddles[0]`.
///
/// `twiddles` must be non-empty, because the slice `twiddles[1..]` is always taken.
#[inline]
fn bowers_g_layer<AF: AbstractField, const N: usize>(
    values: &mut [AF; N],
    log_half_block_size: usize,
    twiddles: &[AF::F],
) {
    let half = 1 << log_half_block_size;
    let log_block = log_half_block_size + 1;
    let block_count = N >> log_block;

    // Block 0: its twiddle factor is 1, so skip the multiplication entirely.
    for offset in 0..half {
        twiddle_free_butterfly(values, offset, offset + half);
    }

    // Remaining blocks, each paired with the twiddle at the same index.
    for (&twiddle, block) in twiddles[1..].iter().zip(1..block_count) {
        let base = block << log_block;
        for offset in 0..half {
            let hi = base + offset;
            dif_butterfly(values, hi, hi + half, twiddle);
        }
    }
}

/// Runs one layer of the second (DIT) pass over `values`, in place.
///
/// The array is viewed as `N >> (log_half_block_size + 1)` consecutive blocks of
/// `2 << log_half_block_size` elements. Block `b` applies `dit_butterfly` with
/// `twiddles[b]` to each pair made of element `i` of its first half and element `i`
/// of its second half. Unlike `bowers_g_layer`, block 0 gets no special treatment.
#[inline]
fn bowers_g_t_layer<AF: AbstractField, const N: usize>(
    values: &mut [AF; N],
    log_half_block_size: usize,
    twiddles: &[AF::F],
) {
    let half = 1 << log_half_block_size;
    let log_block = log_half_block_size + 1;
    let block_count = N >> log_block;

    // At most `block_count` blocks are processed, one per available twiddle.
    for (block, &twiddle) in twiddles.iter().take(block_count).enumerate() {
        let base = block << log_block;
        for offset in 0..half {
            let hi = base + offset;
            dit_butterfly(values, hi, hi + half, twiddle);
        }
    }
}

#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_dft::{NaiveDft, TwoAdicSubgroupDft};
    use p3_field::{AbstractField, Field};
    use p3_symmetric::Permutation;
    use p3_util::reverse_slice_index_bits;
    use rand::{thread_rng, Rng};

    use crate::integrated_coset_mds::IntegratedCosetMds;

    type F = BabyBear;
    const N: usize = 16;

    /// Checks the permutation against a naive coset LDE of the same random input.
    ///
    /// The reference is bit-reversed on the way in and on the way out, and scaled by
    /// `N`, mirroring the skipped bit reversals and the omitted `1/N` weighting of
    /// `IntegratedCosetMds`.
    #[test]
    fn matches_naive() {
        let mut actual: [F; N] = thread_rng().gen();

        // Reference result, computed from a copy of the input before it is permuted.
        let mut reference_input = actual.to_vec();
        reverse_slice_index_bits(&mut reference_input);
        let mut expected = NaiveDft.coset_lde(reference_input, 0, F::GENERATOR);
        reverse_slice_index_bits(&mut expected);
        let scale = F::from_canonical_usize(N);
        for x in expected.iter_mut() {
            *x *= scale;
        }

        IntegratedCosetMds::default().permute_mut(&mut actual);
        assert_eq!(expected, actual);
    }
}