p3_matrix/
lib.rs

1//! Matrix library.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::ops::Deref;
10
11use itertools::{izip, Itertools};
12use p3_field::{
13    dot_product, ExtensionField, Field, FieldAlgebra, FieldExtensionAlgebra, PackedValue,
14};
15use p3_maybe_rayon::prelude::*;
16use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
17use tracing::instrument;
18
19use crate::dense::RowMajorMatrix;
20
21pub mod bitrev;
22pub mod dense;
23pub mod extension;
24pub mod horizontally_truncated;
25pub mod mul;
26pub mod row_index_mapped;
27pub mod sparse;
28pub mod stack;
29pub mod strided;
30pub mod util;
31
/// Width and height of a matrix.
///
/// Both the [`Debug`] and [`Display`] renderings use the compact `WIDTHxHEIGHT` form,
/// e.g. a 3-column, 5-row matrix prints as `3x5`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Dimensions {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}

impl Display for Dimensions {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let Self { width, height } = *self;
        write!(f, "{}x{}", width, height)
    }
}

impl Debug for Dimensions {
    // Debug output is intentionally identical to Display: `WIDTHxHEIGHT`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Display::fmt(self, f)
    }
}
49
50pub trait Matrix<T: Send + Sync>: Send + Sync {
51    fn width(&self) -> usize;
52    fn height(&self) -> usize;
53
54    fn dimensions(&self) -> Dimensions {
55        Dimensions {
56            width: self.width(),
57            height: self.height(),
58        }
59    }
60
61    fn get(&self, r: usize, c: usize) -> T {
62        self.row(r).nth(c).unwrap()
63    }
64
65    type Row<'a>: Iterator<Item = T> + Send + Sync
66    where
67        Self: 'a;
68
69    fn row(&self, r: usize) -> Self::Row<'_>;
70
71    fn rows(&self) -> impl Iterator<Item = Self::Row<'_>> {
72        (0..self.height()).map(move |r| self.row(r))
73    }
74
75    fn par_rows(&self) -> impl IndexedParallelIterator<Item = Self::Row<'_>> {
76        (0..self.height()).into_par_iter().map(move |r| self.row(r))
77    }
78
79    // Opaque return type implicitly captures &'_ self
80    fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
81        self.row(r).collect_vec()
82    }
83
84    fn first_row(&self) -> Self::Row<'_> {
85        self.row(0)
86    }
87
88    fn last_row(&self) -> Self::Row<'_> {
89        self.row(self.height() - 1)
90    }
91
92    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
93    where
94        Self: Sized,
95        T: Clone,
96    {
97        RowMajorMatrix::new(
98            (0..self.height()).flat_map(|r| self.row(r)).collect(),
99            self.width(),
100        )
101    }
102
103    fn horizontally_packed_row<'a, P>(
104        &'a self,
105        r: usize,
106    ) -> (
107        impl Iterator<Item = P> + Send + Sync,
108        impl Iterator<Item = T> + Send + Sync,
109    )
110    where
111        P: PackedValue<Value = T>,
112        T: Clone + 'a,
113    {
114        let num_packed = self.width() / P::WIDTH;
115        let packed = (0..num_packed).map(move |c| P::from_fn(|i| self.get(r, P::WIDTH * c + i)));
116        let sfx = (num_packed * P::WIDTH..self.width()).map(move |c| self.get(r, c));
117        (packed, sfx)
118    }
119
120    /// Zero padded.
121    fn padded_horizontally_packed_row<'a, P>(
122        &'a self,
123        r: usize,
124    ) -> impl Iterator<Item = P> + Send + Sync
125    where
126        P: PackedValue<Value = T>,
127        T: Clone + Default + 'a,
128    {
129        let mut row_iter = self.row(r);
130        let num_elems = self.width().div_ceil(P::WIDTH);
131        // array::from_fn currently always calls in order, but it's not clear whether that's guaranteed.
132        (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
133    }
134
135    fn par_horizontally_packed_rows<'a, P>(
136        &'a self,
137    ) -> impl IndexedParallelIterator<
138        Item = (
139            impl Iterator<Item = P> + Send + Sync,
140            impl Iterator<Item = T> + Send + Sync,
141        ),
142    >
143    where
144        P: PackedValue<Value = T>,
145        T: Clone + 'a,
146    {
147        (0..self.height())
148            .into_par_iter()
149            .map(|r| self.horizontally_packed_row(r))
150    }
151
152    fn par_padded_horizontally_packed_rows<'a, P>(
153        &'a self,
154    ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
155    where
156        P: PackedValue<Value = T>,
157        T: Clone + Default + 'a,
158    {
159        (0..self.height())
160            .into_par_iter()
161            .map(|r| self.padded_horizontally_packed_row(r))
162    }
163
164    /// Pack together a collection of adjacent rows from the matrix.
165    ///
166    /// Returns an iterator whose i'th element is packing of the i'th element of the
167    /// rows r through r + P::WIDTH - 1. If we exceed the height of the matrix,
168    /// wrap around and include initial rows.
169    #[inline]
170    fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
171    where
172        T: Copy,
173        P: PackedValue<Value = T>,
174    {
175        let rows = (0..(P::WIDTH))
176            .map(|c| self.row_slice((r + c) % self.height()))
177            .collect_vec();
178        (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
179    }
180
181    /// Pack together a collection of rows and "next" rows from the matrix.
182    ///
183    /// Returns a vector corresponding to 2 packed rows. The i'th element of the first
184    /// row contains the packing of the i'th element of the rows r through r + P::WIDTH - 1.
185    /// The i'th element of the second row contains the packing of the i'th element of the
186    /// rows r + step through r + step + P::WIDTH - 1. If at some point we exceed the
187    /// height of the matrix, wrap around and include initial rows.
188    #[inline]
189    fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
190    where
191        T: Copy,
192        P: PackedValue<Value = T>,
193    {
194        // Whilst it would appear that this can be replaced by two calls to vertically_packed_row
195        // tests seem to indicate that combining them in the same function is slightly faster.
196        // It's probably allowing the compiler to make some optimizations on the fly.
197
198        let rows = (0..P::WIDTH)
199            .map(|c| self.row_slice((r + c) % self.height()))
200            .collect_vec();
201
202        let next_rows = (0..P::WIDTH)
203            .map(|c| self.row_slice((r + c + step) % self.height()))
204            .collect_vec();
205
206        (0..self.width())
207            .map(|c| P::from_fn(|i| rows[i][c]))
208            .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
209            .collect_vec()
210    }
211
212    fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
213    where
214        Self: Sized,
215    {
216        VerticallyStridedRowIndexMap::new_view(self, stride, offset)
217    }
218
219    /// Compute Mᵀv, aka premultiply this matrix by the given vector,
220    /// aka scale each row by the corresponding entry in `v` and take the sum across rows.
221    /// `v` can be a vector of extension elements.
222    #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
223    fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
224    where
225        T: Field,
226        EF: ExtensionField<T>,
227    {
228        let packed_width = self.width().div_ceil(T::Packing::WIDTH);
229
230        let packed_result = self
231            .par_padded_horizontally_packed_rows::<T::Packing>()
232            .zip(v)
233            .par_fold_reduce(
234                || EF::ExtensionPacking::zero_vec(packed_width),
235                |mut acc, (row, &scale)| {
236                    let scale = EF::ExtensionPacking::from_base_fn(|i| {
237                        T::Packing::from(scale.as_base_slice()[i])
238                    });
239                    izip!(&mut acc, row).for_each(|(l, r)| *l += scale * r);
240                    acc
241                },
242                |mut acc_l, acc_r| {
243                    izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
244                    acc_l
245                },
246            );
247
248        packed_result
249            .into_iter()
250            .flat_map(|p| {
251                (0..T::Packing::WIDTH)
252                    .map(move |i| EF::from_base_fn(|j| p.as_base_slice()[j].as_slice()[i]))
253            })
254            .take(self.width())
255            .collect()
256    }
257
258    /// Multiply this matrix by the vector of powers of `base`, which is an extension element.
259    fn dot_ext_powers<EF>(&self, base: EF) -> impl IndexedParallelIterator<Item = EF>
260    where
261        T: Field,
262        EF: ExtensionField<T>,
263    {
264        let powers_packed = base
265            .ext_powers_packed()
266            .take(self.width().next_multiple_of(T::Packing::WIDTH))
267            .collect_vec();
268        self.par_padded_horizontally_packed_rows::<T::Packing>()
269            .map(move |row_packed| {
270                let packed_sum_of_packed: EF::ExtensionPacking =
271                    dot_product(powers_packed.iter().copied(), row_packed);
272                let sum_of_packed: EF = EF::from_base_fn(|i| {
273                    packed_sum_of_packed.as_base_slice()[i]
274                        .as_slice()
275                        .iter()
276                        .copied()
277                        .sum()
278                });
279                sum_of_packed
280            })
281    }
282}
283
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use itertools::izip;
    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::FieldAlgebra;
    use rand::thread_rng;

    use super::*;

    #[test]
    fn test_columnwise_dot_product() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let m = RowMajorMatrix::<F>::rand(&mut thread_rng(), 1 << 8, 1 << 4);
        let v = RowMajorMatrix::<EF>::rand(&mut thread_rng(), 1 << 8, 1).values;

        let mut expected = vec![EF::ZERO; m.width()];
        for (row, &scale) in izip!(m.rows(), &v) {
            for (l, r) in izip!(&mut expected, row) {
                *l += scale * r;
            }
        }

        assert_eq!(m.columnwise_dot_product(&v), expected);
    }

    #[test]
    fn test_dot_ext_powers() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let mut rng = thread_rng();
        // 13 columns is not a multiple of any packing width > 1, so the zero-padded last
        // packed element is exercised.
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 4, 13);
        let base = RowMajorMatrix::<EF>::rand(&mut rng, 1, 1).values[0];

        // Naive reference: row i maps to sum_j M[i][j] * base^j.
        let expected: Vec<EF> = m
            .rows()
            .map(|row| izip!(row, base.powers()).map(|(x, p)| p * x).sum())
            .collect();

        let actual: Vec<EF> = m.dot_ext_powers(base).collect();
        assert_eq!(actual, expected);
    }

    // Mock implementation for testing purposes
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        type Row<'a> = alloc::vec::IntoIter<u32>;

        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        fn row(&self, r: usize) -> Self::Row<'_> {
            self.data[r].clone().into_iter()
        }
    }

    /// 3x3 mock matrix holding 1..=9 in row-major order.
    fn mock_3x3() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!(dims.width, 3);
        assert_eq!(dims.height, 5);
        assert_eq!(format!("{:?}", dims), "3x5");
        assert_eq!(format!("{}", dims), "3x5");
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = mock_3x3();
        assert_eq!(matrix.width(), 3);
        assert_eq!(matrix.height(), 3);
        assert_eq!(
            matrix.dimensions(),
            Dimensions {
                width: 3,
                height: 3
            }
        );
    }

    #[test]
    fn test_first_row() {
        let matrix = mock_3x3();
        let mut first_row = matrix.first_row();
        assert_eq!(first_row.next(), Some(1));
        assert_eq!(first_row.next(), Some(2));
        assert_eq!(first_row.next(), Some(3));
        assert_eq!(first_row.next(), None);
    }

    #[test]
    fn test_last_row() {
        let matrix = mock_3x3();
        let mut last_row = matrix.last_row();
        assert_eq!(last_row.next(), Some(7));
        assert_eq!(last_row.next(), Some(8));
        assert_eq!(last_row.next(), Some(9));
        assert_eq!(last_row.next(), None);
    }

    #[test]
    fn test_row_slice() {
        let matrix = mock_3x3();
        let row_slice = matrix.row_slice(1);
        assert_eq!(row_slice.deref(), &[4, 5, 6]);
    }

    #[test]
    fn test_to_row_major_matrix() {
        let matrix = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        };
        let row_major = matrix.to_row_major_matrix();
        assert_eq!(row_major.values, vec![1, 2, 3, 4]);
        assert_eq!(row_major.width, 2);
    }

    #[test]
    fn test_matrix_get() {
        let matrix = mock_3x3();
        assert_eq!(matrix.get(0, 0), 1);
        assert_eq!(matrix.get(1, 2), 6);
        assert_eq!(matrix.get(2, 1), 8);
    }

    #[test]
    fn test_matrix_row_iteration() {
        let matrix = mock_3x3();

        let mut row_iter = matrix.row(1);
        assert_eq!(row_iter.next(), Some(4));
        assert_eq!(row_iter.next(), Some(5));
        assert_eq!(row_iter.next(), Some(6));
        assert_eq!(row_iter.next(), None);
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = mock_3x3();

        let all_rows: Vec<Vec<u32>> = matrix.rows().map(|row| row.collect()).collect();
        assert_eq!(all_rows, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }
}