// p3_matrix/stack.rs

use core::iter::Chain;
use core::ops::Deref;

use crate::Matrix;

/// Two matrices stacked vertically: every row of `first`, followed by every row of `second`.
#[derive(Clone, Copy, Debug)]
pub struct VerticalPair<First, Second> {
    /// The upper matrix; supplies the rows with the smallest indices.
    pub first: First,
    /// The lower matrix; supplies the rows that come after all rows of `first`.
    pub second: Second,
}

/// Two matrices placed side by side: the columns of `first`, followed by the columns of `second`.
#[derive(Clone, Copy, Debug)]
pub struct HorizontalPair<First, Second> {
    /// The left matrix; supplies the columns with the smallest indices.
    pub first: First,
    /// The right matrix; supplies the columns that come after all columns of `first`.
    pub second: Second,
}

20impl<First, Second> VerticalPair<First, Second> {
21    pub fn new<T>(first: First, second: Second) -> Self
22    where
23        T: Send + Sync,
24        First: Matrix<T>,
25        Second: Matrix<T>,
26    {
27        assert_eq!(first.width(), second.width());
28        Self { first, second }
29    }
30}
31
32impl<First, Second> HorizontalPair<First, Second> {
33    pub fn new<T>(first: First, second: Second) -> Self
34    where
35        T: Send + Sync,
36        First: Matrix<T>,
37        Second: Matrix<T>,
38    {
39        assert_eq!(first.height(), second.height());
40        Self { first, second }
41    }
42}
43
44impl<T: Send + Sync, First: Matrix<T>, Second: Matrix<T>> Matrix<T>
45    for VerticalPair<First, Second>
46{
47    fn width(&self) -> usize {
48        self.first.width()
49    }
50
51    fn height(&self) -> usize {
52        self.first.height() + self.second.height()
53    }
54
55    fn get(&self, r: usize, c: usize) -> T {
56        if r < self.first.height() {
57            self.first.get(r, c)
58        } else {
59            self.second.get(r - self.first.height(), c)
60        }
61    }
62
63    type Row<'a>
64        = EitherRow<First::Row<'a>, Second::Row<'a>>
65    where
66        Self: 'a;
67
68    fn row(&self, r: usize) -> Self::Row<'_> {
69        if r < self.first.height() {
70            EitherRow::Left(self.first.row(r))
71        } else {
72            EitherRow::Right(self.second.row(r - self.first.height()))
73        }
74    }
75
76    fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
77        if r < self.first.height() {
78            EitherRow::Left(self.first.row_slice(r))
79        } else {
80            EitherRow::Right(self.second.row_slice(r - self.first.height()))
81        }
82    }
83}
84
85impl<T: Send + Sync, First: Matrix<T>, Second: Matrix<T>> Matrix<T>
86    for HorizontalPair<First, Second>
87{
88    fn width(&self) -> usize {
89        self.first.width() + self.second.width()
90    }
91
92    fn height(&self) -> usize {
93        self.first.height()
94    }
95
96    fn get(&self, r: usize, c: usize) -> T {
97        if c < self.first.width() {
98            self.first.get(r, c)
99        } else {
100            self.second.get(r, c - self.first.width())
101        }
102    }
103
104    type Row<'a>
105        = Chain<First::Row<'a>, Second::Row<'a>>
106    where
107        Self: 'a;
108
109    fn row(&self, r: usize) -> Self::Row<'_> {
110        self.first.row(r).chain(self.second.row(r))
111    }
112}
113
/// One of two alternatives, used by [`VerticalPair`] for both its row iterator and its row slice.
///
/// `VerticalPair` wraps whatever `first` returns in `Left` and whatever `second` returns in
/// `Right`. The enum is an [`Iterator`] when both variants iterate over the same item type, and
/// it derefs to `[T]` when both variants deref to the same slice type.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// The left alternative (in `VerticalPair`, a row of the `first` matrix).
    Left(L),
    /// The right alternative (in `VerticalPair`, a row of the `second` matrix).
    Right(R),
}

impl<T, L, R> Iterator for EitherRow<L, R>
where
    L: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            EitherRow::Left(l) => l.next(),
            EitherRow::Right(r) => r.next(),
        }
    }

    // Forward the wrapped iterator's bounds. The default `(0, None)` would stop `collect`,
    // `extend` and `Chain` (used for horizontally stacked rows) from sizing buffers up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            EitherRow::Left(l) => l.size_hint(),
            EitherRow::Right(r) => r.size_hint(),
        }
    }
}

impl<T, L, R> Deref for EitherRow<L, R>
where
    L: Deref<Target = [T]>,
    R: Deref<Target = [T]>,
{
    type Target = [T];

    /// Borrows the slice held by whichever variant is active.
    fn deref(&self) -> &Self::Target {
        match self {
            EitherRow::Left(l) => l,
            EitherRow::Right(r) => r,
        }
    }
}