p3_matrix/
sparse.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::iter;
4use core::ops::Range;
5
6use rand::distributions::{Distribution, Standard};
7use rand::Rng;
8
9use crate::Matrix;
10
/// A sparse matrix stored in the compressed sparse row (CSR) format.
///
/// Row `r` is the slice `nonzero_values[row_indices[r]..row_indices[r + 1]]`. Any column that is
/// not listed in that slice is implicitly `T::default()`.
#[derive(Debug)]
pub struct CsrMatrix<T> {
    /// Number of columns of the matrix. Every stored column index must be less than this.
    width: usize,

    /// The stored `(col, coefficient)` pairs, grouped by row: all pairs of row 0 come first, then
    /// those of row 1, and so on.
    nonzero_values: Vec<(usize, T)>,

    /// Offsets into `nonzero_values`. The `i`th entry is the index of the first pair belonging to
    /// the `i`th row. A final trailing entry marks the end of the last row, so this vector holds
    /// `height + 1` non-decreasing offsets.
    row_indices: Vec<usize>,
}
23
24impl<T: Clone + Default + Send + Sync> CsrMatrix<T> {
25    fn row_index_range(&self, r: usize) -> Range<usize> {
26        debug_assert!(r < self.height());
27        self.row_indices[r]..self.row_indices[r + 1]
28    }
29
30    #[must_use]
31    pub fn sparse_row(&self, r: usize) -> &[(usize, T)] {
32        &self.nonzero_values[self.row_index_range(r)]
33    }
34
35    pub fn sparse_row_mut(&mut self, r: usize) -> &mut [(usize, T)] {
36        let range = self.row_index_range(r);
37        &mut self.nonzero_values[range]
38    }
39
40    pub fn rand_fixed_row_weight<R: Rng>(
41        rng: &mut R,
42        rows: usize,
43        cols: usize,
44        row_weight: usize,
45    ) -> Self
46    where
47        T: Default,
48        Standard: Distribution<T>,
49    {
50        let nonzero_values = iter::repeat_with(|| (rng.gen_range(0..cols), rng.gen()))
51            .take(rows * row_weight)
52            .collect();
53        let row_indices = (0..=rows).map(|r| r * row_weight).collect();
54        Self {
55            width: cols,
56            nonzero_values,
57            row_indices,
58        }
59    }
60}
61
62impl<T: Clone + Default + Send + Sync> Matrix<T> for CsrMatrix<T> {
63    fn width(&self) -> usize {
64        self.width
65    }
66
67    fn height(&self) -> usize {
68        self.row_indices.len() - 1
69    }
70
71    fn get(&self, r: usize, c: usize) -> T {
72        self.sparse_row(r)
73            .iter()
74            .find(|(col, _)| *col == c)
75            .map(|(_, val)| val.clone())
76            .unwrap_or_default()
77    }
78
79    type Row<'a>
80        = <Vec<T> as IntoIterator>::IntoIter
81    where
82        Self: 'a;
83
84    fn row(&self, r: usize) -> Self::Row<'_> {
85        let mut row = vec![T::default(); self.width()];
86        for (c, v) in self.sparse_row(r) {
87            row[*c] = v.clone();
88        }
89        row.into_iter()
90    }
91}