p3_matrix/
extension.rs

1use alloc::vec::Vec;
2use core::iter;
3use core::marker::PhantomData;
4use core::ops::Deref;
5
6use p3_field::{ExtensionField, Field};
7
8use crate::Matrix;
9
/// A read-only view of a matrix over the extension field `EF` as a matrix over the base field `F`.
///
/// Each extension element of a row is expanded, in place, into its `EF::D` base-field
/// coordinates. The view therefore keeps the inner matrix's height and is `EF::D` times as wide.
/// The wrapped matrix is stored as-is; `F` and `EF` are tracked only at the type level.
#[derive(Debug)]
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);
14
15impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
16    pub fn new(inner: Inner) -> Self {
17        Self(inner, PhantomData)
18    }
19}
20
21impl<F, EF, Inner> Deref for FlatMatrixView<F, EF, Inner> {
22    type Target = Inner;
23
24    fn deref(&self) -> &Self::Target {
25        &self.0
26    }
27}
28
29impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
30where
31    F: Field,
32    EF: ExtensionField<F>,
33    Inner: Matrix<EF>,
34{
35    fn width(&self) -> usize {
36        self.0.width() * EF::D
37    }
38
39    fn height(&self) -> usize {
40        self.0.height()
41    }
42
43    type Row<'a>
44        = FlatIter<F, Inner::Row<'a>>
45    where
46        Self: 'a;
47
48    fn row(&self, r: usize) -> Self::Row<'_> {
49        FlatIter {
50            inner: self.0.row(r).peekable(),
51            idx: 0,
52            _phantom: PhantomData,
53        }
54    }
55
56    fn row_slice(&self, r: usize) -> impl Deref<Target = [F]> {
57        self.0
58            .row_slice(r)
59            .iter()
60            .flat_map(|val| val.as_base_slice())
61            .copied()
62            .collect::<Vec<_>>()
63    }
64}
65
/// Iterator over one row of a [`FlatMatrixView`], yielding the base-field coordinates of each
/// extension element in turn.
///
/// `inner` is peekable so the extension element currently being flattened can be read
/// repeatedly without consuming it. `idx` counts how many of its coordinates have already been
/// yielded.
pub struct FlatIter<F, I: Iterator> {
    idx: usize,
    inner: iter::Peekable<I>,
    _phantom: PhantomData<F>,
}
71
72impl<F, EF, I> Iterator for FlatIter<F, I>
73where
74    F: Field,
75    EF: ExtensionField<F>,
76    I: Iterator<Item = EF>,
77{
78    type Item = F;
79    fn next(&mut self) -> Option<Self::Item> {
80        if self.idx == EF::D {
81            self.idx = 0;
82            self.inner.next();
83        }
84        let value = self.inner.peek()?.as_base_slice()[self.idx];
85        self.idx += 1;
86        Some(value)
87    }
88}
89
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_field::extension::Complex;
    use p3_field::{FieldAlgebra, FieldExtensionAlgebra};
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::dense::RowMajorMatrix;
    type F = Mersenne31;
    type EF = Complex<Mersenne31>;

    #[test]
    fn flat_matrix() {
        let values = vec![
            EF::from_base_fn(|i| F::from_canonical_usize(i + 10)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 20)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 30)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 40)),
        ];
        let ext = RowMajorMatrix::<EF>::new(values, 2);
        let flat = FlatMatrixView::<F, EF, _>::new(ext);
        assert_eq!(
            &*flat.row_slice(0),
            &[10, 11, 20, 21].map(F::from_canonical_usize)
        );
        assert_eq!(
            &*flat.row_slice(1),
            &[30, 31, 40, 41].map(F::from_canonical_usize)
        );
    }

    /// The lazy `row` iterator must agree with the eagerly built `row_slice`, and the view's
    /// dimensions must reflect the horizontal expansion by `EF::D`.
    #[test]
    fn flat_row_iter_matches_row_slice() {
        let values = (1..=6)
            .map(|k| EF::from_base_fn(|i| F::from_canonical_usize(10 * k + i)))
            .collect::<Vec<_>>();
        let ext = RowMajorMatrix::<EF>::new(values, 3);
        let flat = FlatMatrixView::<F, EF, _>::new(ext);

        assert_eq!(flat.width(), 6);
        assert_eq!(flat.height(), 2);
        for r in 0..flat.height() {
            let lazy = flat.row(r).collect::<Vec<_>>();
            assert_eq!(lazy.as_slice(), &*flat.row_slice(r));
        }
        assert_eq!(
            flat.row(1).collect::<Vec<_>>(),
            [40, 41, 50, 51, 60, 61].map(F::from_canonical_usize)
        );
    }
}