p3_matrix/
extension.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use alloc::vec::Vec;
use core::iter;
use core::marker::PhantomData;
use core::ops::Deref;

use p3_field::{ExtensionField, Field};

use crate::Matrix;

/// Flattens a matrix of extension field elements to one of base field elements. The flattening is
/// done horizontally, resulting in a wider matrix.
///
/// Each `EF` entry of the wrapped matrix is expanded into the base field elements exposed by
/// `as_base_slice`, so the view reports a width of `inner.width() * EF::D` and the same height
/// as the wrapped matrix.
///
/// The `PhantomData` field only records the `F` and `EF` type parameters; it holds no data.
#[derive(Debug)]
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);

impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
    pub fn new(inner: Inner) -> Self {
        Self(inner, PhantomData)
    }
    pub fn inner_ref(&self) -> &Inner {
        &self.0
    }
}

/// Presents the wrapped `Matrix<EF>` as a `Matrix<F>` by expanding every extension field element
/// into its base field elements, left to right.
impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
where
    F: Field,
    EF: ExtensionField<F>,
    Inner: Matrix<EF>,
{
    /// Width in base field elements: every inner column contributes `EF::D` columns.
    fn width(&self) -> usize {
        self.0.width() * EF::D
    }

    /// Height of the view; flattening is horizontal, so this is the inner matrix's height.
    fn height(&self) -> usize {
        self.0.height()
    }

    type Row<'a>
        = FlatIter<F, Inner::Row<'a>>
    where
        Self: 'a;

    /// Returns a lazy iterator over the base field elements of row `r`.
    ///
    /// No allocation is performed; elements are read from the inner row on demand.
    fn row(&self, r: usize) -> Self::Row<'_> {
        FlatIter {
            inner: self.0.row(r).peekable(),
            idx: 0,
            _phantom: PhantomData,
        }
    }

    /// Returns row `r` flattened into a freshly allocated, contiguous buffer.
    ///
    /// The flattened data does not exist contiguously in the inner matrix, so it is copied into
    /// an owned `Vec<F>`.
    fn row_slice(&self, r: usize) -> impl Deref<Target = [F]> {
        let ext_row = self.0.row_slice(r);
        // Reserve the full flattened length up front: a `flat_map`-based `collect` reports a
        // lower size hint of zero and would regrow the vector several times.
        let mut flat: Vec<F> = Vec::with_capacity(ext_row.len() * EF::D);
        for val in ext_row.iter() {
            flat.extend_from_slice(val.as_base_slice());
        }
        flat
    }
}

/// Iterator adapter that yields, in order, the base field elements of every extension field
/// element produced by an inner iterator.
pub struct FlatIter<F, I: Iterator> {
    /// The underlying iterator; peekable so the current element can be read repeatedly
    /// without being consumed.
    inner: iter::Peekable<I>,
    /// Index, within the current element's base slice, of the next value to yield.
    idx: usize,
    /// Records the base field type `F`, which is not otherwise stored.
    _phantom: PhantomData<F>,
}

/// Yields the base field elements of each `EF` item of the inner iterator, one at a time.
impl<F, EF, I> Iterator for FlatIter<F, I>
where
    F: Field,
    EF: ExtensionField<F>,
    I: Iterator<Item = EF>,
{
    type Item = F;

    /// Returns the next base field element, or `None` once the inner iterator is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // The current extension element has been fully emitted: drop it and start over on the
        // following one.
        if self.idx == EF::D {
            self.idx = 0;
            self.inner.next();
        }
        // Peek rather than consume so the same element can be indexed `EF::D` times.
        let value = self.inner.peek()?.as_base_slice()[self.idx];
        self.idx += 1;
        Some(value)
    }

    /// Reports the number of base field elements left, derived from the inner iterator's hint.
    ///
    /// The peekable inner iterator still counts the element currently being emitted, so the
    /// `idx` values already yielded from it are subtracted.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lo, hi) = self.inner.size_hint();
        let lower = lo.saturating_mul(EF::D).saturating_sub(self.idx);
        // An overflowing upper bound is reported as unknown.
        let upper = hi
            .and_then(|h| h.checked_mul(EF::D))
            .map(|total| total.saturating_sub(self.idx));
        (lower, upper)
    }
}

#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_field::extension::Complex;
    use p3_field::{AbstractExtensionField, AbstractField};
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::dense::RowMajorMatrix;
    type F = Mersenne31;
    type EF = Complex<Mersenne31>;

    /// A 2x2 matrix of degree-2 extension elements must flatten to a 2x4 base field matrix,
    /// with the eager (`row_slice`) and lazy (`row`) accessors agreeing on every row.
    #[test]
    fn flat_matrix() {
        let values = vec![
            EF::from_base_fn(|i| F::from_canonical_usize(i + 10)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 20)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 30)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 40)),
        ];
        let ext = RowMajorMatrix::<EF>::new(values, 2);
        let flat = FlatMatrixView::<F, EF, _>::new(ext);

        // Two extension columns, each contributing two base field elements.
        assert_eq!(flat.width(), 4);
        assert_eq!(flat.height(), 2);

        assert_eq!(
            &*flat.row_slice(0),
            &[10, 11, 20, 21].map(F::from_canonical_usize)
        );
        assert_eq!(
            &*flat.row_slice(1),
            &[30, 31, 40, 41].map(F::from_canonical_usize)
        );

        // The lazy row iterator must produce the same elements as the materialized slice.
        assert_eq!(
            flat.row(0).collect::<Vec<_>>(),
            [10, 11, 20, 21].map(F::from_canonical_usize).to_vec()
        );
        assert_eq!(
            flat.row(1).collect::<Vec<_>>(),
            [30, 31, 40, 41].map(F::from_canonical_usize).to_vec()
        );
    }
}