p3_mds/util.rs
1use alloc::vec::Vec;
2use core::array;
3use core::ops::{AddAssign, Mul};
4
5use p3_dft::TwoAdicSubgroupDft;
6use p3_field::{FieldAlgebra, TwoAdicField};
7
8// NB: These are all MDS for M31, BabyBear and Goldilocks
9// const MATRIX_CIRC_MDS_8_2EXP: [u64; 8] = [1, 1, 2, 1, 8, 32, 4, 256];
10// const MATRIX_CIRC_MDS_8_SML: [u64; 8] = [4, 1, 2, 9, 10, 5, 1, 1];
// Much smaller: [1, 1, -1, 2, 3, 8, 2, -3], but we would need to deal with the negative entries.
12
13// const MATRIX_CIRC_MDS_12_2EXP: [u64; 12] = [1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024];
14// const MATRIX_CIRC_MDS_12_SML: [u64; 12] = [9, 7, 4, 1, 16, 2, 256, 128, 3, 32, 1, 1];
15// const MATRIX_CIRC_MDS_12_SML: [u64; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10];
16
17// Trying to maximise the # of 1's in the vector.
18// Not clear exactly what we should be optimising here but that seems reasonable.
19// const MATRIX_CIRC_MDS_16_SML: [u64; 16] =
20// [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3];
21// 1, 1, 51, 52, 11, 63, 1, 2, 1, 2, 15, 67, 2, 22, 13, 3
22// [1, 1, 2, 1, 8, 32, 2, 65, 77, 8, 91, 31, 3, 65, 32, 7];
23
/// Computes the dot product `u[0] * v[0] + ... + u[N - 1] * v[N - 1]`.
///
/// The first product seeds the accumulator, so `T` does not need an additive
/// identity; the remaining products are added in increasing index order.
///
/// # Panics
///
/// Panics if `N == 0`, because there is no first product to start from. It is
/// hard to imagine this case coming up in practice.
#[inline(always)]
pub fn dot_product<T, const N: usize>(u: [T; N], v: [T; N]) -> T
where
    T: Copy + AddAssign + Mul<Output = T>,
{
    debug_assert_ne!(N, 0);
    let mut acc = u[0] * v[0];
    // Entry 0 is already in the accumulator, so skip it.
    for (&a, &b) in u.iter().zip(v.iter()).skip(1) {
        acc += a * b;
    }
    acc
}
37
38/// Given the first row `circ_matrix` of an NxN circulant matrix, say
39/// C, return the product `C*input`.
40///
41/// NB: This function is a naive implementation of the n²
42/// evaluation. It is a placeholder until we have FFT implementations
43/// for all combinations of field and size.
44pub fn apply_circulant<FA: FieldAlgebra, const N: usize>(
45 circ_matrix: &[u64; N],
46 input: [FA; N],
47) -> [FA; N] {
48 let mut matrix: [FA; N] = circ_matrix.map(FA::from_canonical_u64);
49
50 let mut output = array::from_fn(|_| FA::ZERO);
51 for out_i in output.iter_mut().take(N - 1) {
52 *out_i = FA::dot_product(&matrix, &input);
53 matrix.rotate_right(1);
54 }
55 output[N - 1] = FA::dot_product(&matrix, &input);
56 output
57}
58
/// Given the first row of a circulant matrix, return the first column.
///
/// For example, if `v = [0, 1, 2, 3, 4, 5]` then `output = [0, 5, 4, 3, 2, 1]`,
/// i.e. the first element is the same and the other elements are reversed.
/// An empty array (`N == 0`) is returned unchanged.
///
/// This is useful to prepare a circulant matrix for input to an FFT
/// algorithm, which expects the first column of the matrix rather
/// than the first row (as we normally store them).
///
/// NB: The algorithm is written with a plain indexed loop so that this
/// function can be declared `const`, and that is the intended context
/// for use.
pub const fn first_row_to_first_col<const N: usize, T: Copy>(v: &[T; N]) -> [T; N] {
    // Seed the output with a copy of the input. `T: Copy` makes this legal in
    // a `const fn`, entry 0 is already correct, and nothing indexes `v[0]`,
    // so an empty array does not panic.
    let mut output = *v;
    let mut i = 1;
    while i < N {
        // Entry `i` of the column is entry `N - i` of the row.
        output[i] = v[N - i];
        i += 1;
    }
    output
}
84
85/// Use the convolution theorem to calculate the product of the given
86/// circulant matrix and the given vector.
87///
88/// The circulant matrix must be specified by its first *column*, not its first row. If you have
89/// the row as an array, you can obtain the column with `first_row_to_first_col()`.
90#[inline]
91pub fn apply_circulant_fft<F: TwoAdicField, const N: usize, FFT: TwoAdicSubgroupDft<F>>(
92 fft: FFT,
93 column: [u64; N],
94 input: &[F; N],
95) -> [F; N] {
96 let column = column.map(F::from_canonical_u64).to_vec();
97 let matrix = fft.dft(column);
98 let input = fft.dft(input.to_vec());
99
100 // point-wise product
101 let product = matrix
102 .iter()
103 .zip(input)
104 .map(|(&x, y)| x * y)
105 .collect::<Vec<_>>();
106
107 let output = fft.idft(product);
108 output.try_into().unwrap()
109}
110
#[cfg(test)]
mod tests {
    use super::first_row_to_first_col;

    /// The first entry stays in place and the remaining entries are reversed.
    #[test]
    fn rotation() {
        let first_row = [0, 1, 2, 3, 4, 5];
        let expected_col = [0, 5, 4, 3, 2, 1];

        let first_col = first_row_to_first_col(&first_row);

        assert_eq!(first_col, expected_col);
    }
}