1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::ops::Deref;
10
11use itertools::{izip, Itertools};
12use p3_field::{
13 dot_product, ExtensionField, Field, FieldAlgebra, FieldExtensionAlgebra, PackedValue,
14};
15use p3_maybe_rayon::prelude::*;
16use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
17use tracing::instrument;
18
19use crate::dense::RowMajorMatrix;
20
21pub mod bitrev;
22pub mod dense;
23pub mod extension;
24pub mod horizontally_truncated;
25pub mod mul;
26pub mod row_index_mapped;
27pub mod sparse;
28pub mod stack;
29pub mod strided;
30pub mod util;
31
/// The shape of a matrix: its number of columns (`width`) and rows (`height`).
///
/// Both `Debug` and `Display` render it compactly as `{width}x{height}`, e.g. `3x5`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Dimensions {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}

impl Display for Dimensions {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let Self { width, height } = *self;
        write!(f, "{width}x{height}")
    }
}

impl Debug for Dimensions {
    /// Uses the same compact `WxH` form as `Display`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Display::fmt(self, f)
    }
}
49
50pub trait Matrix<T: Send + Sync>: Send + Sync {
51 fn width(&self) -> usize;
52 fn height(&self) -> usize;
53
54 fn dimensions(&self) -> Dimensions {
55 Dimensions {
56 width: self.width(),
57 height: self.height(),
58 }
59 }
60
61 fn get(&self, r: usize, c: usize) -> T {
62 self.row(r).nth(c).unwrap()
63 }
64
65 type Row<'a>: Iterator<Item = T> + Send + Sync
66 where
67 Self: 'a;
68
69 fn row(&self, r: usize) -> Self::Row<'_>;
70
71 fn rows(&self) -> impl Iterator<Item = Self::Row<'_>> {
72 (0..self.height()).map(move |r| self.row(r))
73 }
74
75 fn par_rows(&self) -> impl IndexedParallelIterator<Item = Self::Row<'_>> {
76 (0..self.height()).into_par_iter().map(move |r| self.row(r))
77 }
78
79 fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
81 self.row(r).collect_vec()
82 }
83
84 fn first_row(&self) -> Self::Row<'_> {
85 self.row(0)
86 }
87
88 fn last_row(&self) -> Self::Row<'_> {
89 self.row(self.height() - 1)
90 }
91
92 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
93 where
94 Self: Sized,
95 T: Clone,
96 {
97 RowMajorMatrix::new(
98 (0..self.height()).flat_map(|r| self.row(r)).collect(),
99 self.width(),
100 )
101 }
102
103 fn horizontally_packed_row<'a, P>(
104 &'a self,
105 r: usize,
106 ) -> (
107 impl Iterator<Item = P> + Send + Sync,
108 impl Iterator<Item = T> + Send + Sync,
109 )
110 where
111 P: PackedValue<Value = T>,
112 T: Clone + 'a,
113 {
114 let num_packed = self.width() / P::WIDTH;
115 let packed = (0..num_packed).map(move |c| P::from_fn(|i| self.get(r, P::WIDTH * c + i)));
116 let sfx = (num_packed * P::WIDTH..self.width()).map(move |c| self.get(r, c));
117 (packed, sfx)
118 }
119
120 fn padded_horizontally_packed_row<'a, P>(
122 &'a self,
123 r: usize,
124 ) -> impl Iterator<Item = P> + Send + Sync
125 where
126 P: PackedValue<Value = T>,
127 T: Clone + Default + 'a,
128 {
129 let mut row_iter = self.row(r);
130 let num_elems = self.width().div_ceil(P::WIDTH);
131 (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
133 }
134
135 fn par_horizontally_packed_rows<'a, P>(
136 &'a self,
137 ) -> impl IndexedParallelIterator<
138 Item = (
139 impl Iterator<Item = P> + Send + Sync,
140 impl Iterator<Item = T> + Send + Sync,
141 ),
142 >
143 where
144 P: PackedValue<Value = T>,
145 T: Clone + 'a,
146 {
147 (0..self.height())
148 .into_par_iter()
149 .map(|r| self.horizontally_packed_row(r))
150 }
151
152 fn par_padded_horizontally_packed_rows<'a, P>(
153 &'a self,
154 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
155 where
156 P: PackedValue<Value = T>,
157 T: Clone + Default + 'a,
158 {
159 (0..self.height())
160 .into_par_iter()
161 .map(|r| self.padded_horizontally_packed_row(r))
162 }
163
164 #[inline]
170 fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
171 where
172 T: Copy,
173 P: PackedValue<Value = T>,
174 {
175 let rows = (0..(P::WIDTH))
176 .map(|c| self.row_slice((r + c) % self.height()))
177 .collect_vec();
178 (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
179 }
180
181 #[inline]
189 fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
190 where
191 T: Copy,
192 P: PackedValue<Value = T>,
193 {
194 let rows = (0..P::WIDTH)
199 .map(|c| self.row_slice((r + c) % self.height()))
200 .collect_vec();
201
202 let next_rows = (0..P::WIDTH)
203 .map(|c| self.row_slice((r + c + step) % self.height()))
204 .collect_vec();
205
206 (0..self.width())
207 .map(|c| P::from_fn(|i| rows[i][c]))
208 .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
209 .collect_vec()
210 }
211
212 fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
213 where
214 Self: Sized,
215 {
216 VerticallyStridedRowIndexMap::new_view(self, stride, offset)
217 }
218
219 #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
223 fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
224 where
225 T: Field,
226 EF: ExtensionField<T>,
227 {
228 let packed_width = self.width().div_ceil(T::Packing::WIDTH);
229
230 let packed_result = self
231 .par_padded_horizontally_packed_rows::<T::Packing>()
232 .zip(v)
233 .par_fold_reduce(
234 || EF::ExtensionPacking::zero_vec(packed_width),
235 |mut acc, (row, &scale)| {
236 let scale = EF::ExtensionPacking::from_base_fn(|i| {
237 T::Packing::from(scale.as_base_slice()[i])
238 });
239 izip!(&mut acc, row).for_each(|(l, r)| *l += scale * r);
240 acc
241 },
242 |mut acc_l, acc_r| {
243 izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
244 acc_l
245 },
246 );
247
248 packed_result
249 .into_iter()
250 .flat_map(|p| {
251 (0..T::Packing::WIDTH)
252 .map(move |i| EF::from_base_fn(|j| p.as_base_slice()[j].as_slice()[i]))
253 })
254 .take(self.width())
255 .collect()
256 }
257
258 fn dot_ext_powers<EF>(&self, base: EF) -> impl IndexedParallelIterator<Item = EF>
260 where
261 T: Field,
262 EF: ExtensionField<T>,
263 {
264 let powers_packed = base
265 .ext_powers_packed()
266 .take(self.width().next_multiple_of(T::Packing::WIDTH))
267 .collect_vec();
268 self.par_padded_horizontally_packed_rows::<T::Packing>()
269 .map(move |row_packed| {
270 let packed_sum_of_packed: EF::ExtensionPacking =
271 dot_product(powers_packed.iter().copied(), row_packed);
272 let sum_of_packed: EF = EF::from_base_fn(|i| {
273 packed_sum_of_packed.as_base_slice()[i]
274 .as_slice()
275 .iter()
276 .copied()
277 .sum()
278 });
279 sum_of_packed
280 })
281 }
282}
283
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use itertools::izip;
    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{Field, FieldAlgebra, PackedValue};
    use rand::thread_rng;

    use super::*;

    #[test]
    fn test_columnwise_dot_product() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let m = RowMajorMatrix::<F>::rand(&mut thread_rng(), 1 << 8, 1 << 4);
        let v = RowMajorMatrix::<EF>::rand(&mut thread_rng(), 1 << 8, 1).values;

        let mut expected = vec![EF::ZERO; m.width()];
        for (row, &scale) in izip!(m.rows(), &v) {
            for (l, r) in izip!(&mut expected, row) {
                *l += scale * r;
            }
        }

        assert_eq!(m.columnwise_dot_product(&v), expected);
    }

    /// Minimal `Matrix<u32>` backed by nested vectors, used to exercise the trait's
    /// default methods.
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        type Row<'a> = alloc::vec::IntoIter<u32>;

        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        fn row(&self, r: usize) -> Self::Row<'_> {
            self.data[r].clone().into_iter()
        }
    }

    /// Builds the 3x3 fixture holding `1..=9` in row-major order.
    fn mock_3x3() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    /// Wraps another matrix and forwards only the required `Matrix` methods, so calls
    /// on the wrapper exercise the trait's default implementations.
    struct DefaultsOnly<M>(M);

    impl<T, M> Matrix<T> for DefaultsOnly<M>
    where
        T: Send + Sync,
        M: Matrix<T>,
    {
        type Row<'a> = M::Row<'a> where Self: 'a;

        fn width(&self) -> usize {
            self.0.width()
        }

        fn height(&self) -> usize {
            self.0.height()
        }

        fn row(&self, r: usize) -> Self::Row<'_> {
            self.0.row(r)
        }
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!(dims.width, 3);
        assert_eq!(dims.height, 5);
        assert_eq!(format!("{:?}", dims), "3x5");
        assert_eq!(format!("{}", dims), "3x5");
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = mock_3x3();
        assert_eq!(matrix.width(), 3);
        assert_eq!(matrix.height(), 3);
        assert_eq!(
            matrix.dimensions(),
            Dimensions {
                width: 3,
                height: 3
            }
        );
    }

    #[test]
    fn test_first_row() {
        let matrix = mock_3x3();
        let mut first_row = matrix.first_row();
        assert_eq!(first_row.next(), Some(1));
        assert_eq!(first_row.next(), Some(2));
        assert_eq!(first_row.next(), Some(3));
        assert_eq!(first_row.next(), None);
    }

    #[test]
    fn test_last_row() {
        let matrix = mock_3x3();
        let mut last_row = matrix.last_row();
        assert_eq!(last_row.next(), Some(7));
        assert_eq!(last_row.next(), Some(8));
        assert_eq!(last_row.next(), Some(9));
        assert_eq!(last_row.next(), None);
    }

    #[test]
    #[should_panic]
    fn test_last_row_of_empty_matrix_panics() {
        let matrix = MockMatrix {
            data: vec![],
            width: 0,
            height: 0,
        };
        let _ = matrix.last_row();
    }

    #[test]
    fn test_row_slice() {
        let matrix = mock_3x3();
        let row_slice = matrix.row_slice(1);
        assert_eq!(row_slice.deref(), &[4, 5, 6]);
    }

    #[test]
    fn test_to_row_major_matrix() {
        let matrix = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        };
        let row_major = matrix.to_row_major_matrix();
        assert_eq!(row_major.values, vec![1, 2, 3, 4]);
        assert_eq!(row_major.width, 2);
    }

    #[test]
    fn test_matrix_get() {
        let matrix = mock_3x3();
        assert_eq!(matrix.get(0, 0), 1);
        assert_eq!(matrix.get(1, 2), 6);
        assert_eq!(matrix.get(2, 1), 8);
    }

    #[test]
    fn test_matrix_row_iteration() {
        let matrix = mock_3x3();

        let mut row_iter = matrix.row(1);
        assert_eq!(row_iter.next(), Some(4));
        assert_eq!(row_iter.next(), Some(5));
        assert_eq!(row_iter.next(), Some(6));
        assert_eq!(row_iter.next(), None);
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = mock_3x3();

        let all_rows: Vec<Vec<u32>> = matrix.rows().map(|row| row.collect()).collect();
        assert_eq!(all_rows, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }

    #[test]
    fn test_horizontally_packed_row_round_trips() {
        type F = BabyBear;
        type P = <F as Field>::Packing;

        // Width 13 leaves a non-empty unpacked suffix for every packing width > 1.
        let wrapped = DefaultsOnly(RowMajorMatrix::<F>::rand(&mut thread_rng(), 3, 13));
        let expected: Vec<F> = wrapped.0.row(1).collect();

        let (packed, sfx) = wrapped.horizontally_packed_row::<P>(1);
        let flat: Vec<F> = packed
            .flat_map(|p| p.as_slice().to_vec())
            .chain(sfx)
            .collect();
        assert_eq!(flat, expected);

        let padded: Vec<F> = wrapped
            .padded_horizontally_packed_row::<P>(1)
            .flat_map(|p| p.as_slice().to_vec())
            .collect();
        let width = wrapped.width();
        assert_eq!(padded.len(), width.div_ceil(P::WIDTH) * P::WIDTH);
        assert_eq!(&padded[..width], expected.as_slice());
        assert!(padded[width..].iter().all(|&x| x == F::ZERO));
    }
}