p3_matrix/
row_index_mapped.rs
1use core::ops::Deref;
2
3use p3_field::PackedValue;
4
5use crate::dense::RowMajorMatrix;
6use crate::Matrix;
7
8pub trait RowIndexMap: Send + Sync {
10 fn height(&self) -> usize;
11 fn map_row_index(&self, r: usize) -> usize;
12
13 fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
16 &self,
17 inner: Inner,
18 ) -> RowMajorMatrix<T> {
19 RowMajorMatrix::new(
20 (0..self.height())
21 .flat_map(|r| inner.row(self.map_row_index(r)))
22 .collect(),
23 inner.width(),
24 )
25 }
26}
27
/// Pairs an inner value with an index map through which its rows are addressed.
///
/// Both fields are public, so a view is built with a plain struct literal.
#[derive(Clone, Copy, Debug)]
pub struct RowIndexMappedView<IndexMap, Inner> {
    /// The mapping applied to a row index before `inner` is consulted.
    pub index_map: IndexMap,
    /// The wrapped value whose rows are re-indexed.
    pub inner: Inner,
}
33
34impl<T: Send + Sync, IndexMap: RowIndexMap, Inner: Matrix<T>> Matrix<T>
35 for RowIndexMappedView<IndexMap, Inner>
36{
37 fn width(&self) -> usize {
38 self.inner.width()
39 }
40 fn height(&self) -> usize {
41 self.index_map.height()
42 }
43
44 fn get(&self, r: usize, c: usize) -> T {
45 self.inner.get(self.index_map.map_row_index(r), c)
46 }
47
48 type Row<'a>
49 = Inner::Row<'a>
50 where
51 Self: 'a;
52
53 fn row(&self, r: usize) -> Self::Row<'_> {
56 self.inner.row(self.index_map.map_row_index(r))
57 }
58
59 fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
60 self.inner.row_slice(self.index_map.map_row_index(r))
61 }
62
63 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
64 where
65 Self: Sized,
66 T: Clone,
67 {
68 self.index_map.to_row_major_matrix(self.inner)
70 }
71
72 fn horizontally_packed_row<'a, P>(
73 &'a self,
74 r: usize,
75 ) -> (
76 impl Iterator<Item = P> + Send + Sync,
77 impl Iterator<Item = T> + Send + Sync,
78 )
79 where
80 P: PackedValue<Value = T>,
81 T: Clone + 'a,
82 {
83 self.inner
84 .horizontally_packed_row(self.index_map.map_row_index(r))
85 }
86
87 fn padded_horizontally_packed_row<'a, P>(
88 &'a self,
89 r: usize,
90 ) -> impl Iterator<Item = P> + Send + Sync
91 where
92 P: PackedValue<Value = T>,
93 T: Clone + Default + 'a,
94 {
95 self.inner
96 .padded_horizontally_packed_row(self.index_map.map_row_index(r))
97 }
98}