// p3_matrix/sparse.rs
1use alloc::vec;
2use alloc::vec::Vec;
3use core::iter;
4use core::ops::Range;
5
6use rand::distributions::{Distribution, Standard};
7use rand::Rng;
8
9use crate::Matrix;
10
/// A sparse matrix stored in compressed sparse row (CSR) form.
///
/// The stored entries of row `r` are `nonzero_values[row_indices[r]..row_indices[r + 1]]`.
/// So `row_indices` holds one more element than the matrix has rows.
#[derive(Debug)]
pub struct CsrMatrix<T> {
    /// Number of columns; dense rows produced by `row` have exactly this many entries.
    width: usize,

    /// Stored entries as `(column, value)` pairs, laid out contiguously row by row.
    ///
    /// NOTE(review): despite the name, the visible constructor (`rand_fixed_row_weight`)
    /// does not exclude zero values or repeated columns within a row.
    nonzero_values: Vec<(usize, T)>,

    /// Row offsets into `nonzero_values`: row `r` occupies `row_indices[r]..row_indices[r + 1]`.
    /// Offsets are assumed non-decreasing; slicing panics otherwise.
    row_indices: Vec<usize>,
}
23
24impl<T: Clone + Default + Send + Sync> CsrMatrix<T> {
25 fn row_index_range(&self, r: usize) -> Range<usize> {
26 debug_assert!(r < self.height());
27 self.row_indices[r]..self.row_indices[r + 1]
28 }
29
30 #[must_use]
31 pub fn sparse_row(&self, r: usize) -> &[(usize, T)] {
32 &self.nonzero_values[self.row_index_range(r)]
33 }
34
35 pub fn sparse_row_mut(&mut self, r: usize) -> &mut [(usize, T)] {
36 let range = self.row_index_range(r);
37 &mut self.nonzero_values[range]
38 }
39
40 pub fn rand_fixed_row_weight<R: Rng>(
41 rng: &mut R,
42 rows: usize,
43 cols: usize,
44 row_weight: usize,
45 ) -> Self
46 where
47 T: Default,
48 Standard: Distribution<T>,
49 {
50 let nonzero_values = iter::repeat_with(|| (rng.gen_range(0..cols), rng.gen()))
51 .take(rows * row_weight)
52 .collect();
53 let row_indices = (0..=rows).map(|r| r * row_weight).collect();
54 Self {
55 width: cols,
56 nonzero_values,
57 row_indices,
58 }
59 }
60}
61
62impl<T: Clone + Default + Send + Sync> Matrix<T> for CsrMatrix<T> {
63 fn width(&self) -> usize {
64 self.width
65 }
66
67 fn height(&self) -> usize {
68 self.row_indices.len() - 1
69 }
70
71 fn get(&self, r: usize, c: usize) -> T {
72 self.sparse_row(r)
73 .iter()
74 .find(|(col, _)| *col == c)
75 .map(|(_, val)| val.clone())
76 .unwrap_or_default()
77 }
78
79 type Row<'a>
80 = <Vec<T> as IntoIterator>::IntoIter
81 where
82 Self: 'a;
83
84 fn row(&self, r: usize) -> Self::Row<'_> {
85 let mut row = vec![T::default(); self.width()];
86 for (c, v) in self.sparse_row(r) {
87 row[*c] = v.clone();
88 }
89 row.into_iter()
90 }
91}