1use alloc::borrow::Cow;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::borrow::{Borrow, BorrowMut};
5use core::marker::PhantomData;
6use core::ops::Deref;
7use core::{iter, slice};
8
9use p3_field::{scale_slice_in_place, ExtensionField, Field, PackedValue};
10use p3_maybe_rayon::prelude::*;
11use rand::distributions::{Distribution, Standard};
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14use tracing::instrument;
15
16use crate::Matrix;
17
/// A dense matrix whose entries live in a single flat buffer.
///
/// Rows are stored back to back (row-major): the entry at row `r`, column `c`
/// is found at index `r * width + c` of `values`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DenseMatrix<T, V = Vec<T>> {
    /// Flat backing storage holding the entries row after row.
    pub values: V,
    /// Number of columns, i.e. the length of every row.
    pub width: usize,
    /// Ties the element type `T` to the struct without storing a `T` directly.
    _phantom: PhantomData<T>,
}
25
/// A dense matrix that owns its entries in a `Vec<T>`.
pub type RowMajorMatrix<T> = DenseMatrix<T, Vec<T>>;
/// A dense matrix backed by a borrowed, read-only slice.
pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
/// A dense matrix backed by a borrowed, mutable slice.
pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
/// A dense matrix backed by a `Cow`, i.e. either borrowed or owned entries.
pub type RowMajorMatrixCow<'a, T> = DenseMatrix<T, Cow<'a, [T]>>;
30
/// Backing storage usable by a dense matrix: anything that can be viewed as a
/// `[T]` slice, is `Send + Sync`, and can be turned into an owned `Vec<T>`.
pub trait DenseStorage<T>: Borrow<[T]> + Send + Sync {
    /// Consumes the storage and returns its elements as an owned vector.
    fn to_vec(self) -> Vec<T>;
}

impl<T: Clone + Send + Sync> DenseStorage<T> for Vec<T> {
    /// Already an owned vector, so it is handed back unchanged.
    fn to_vec(self) -> Vec<T> {
        self
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for &[T] {
    /// Clones every element of the borrowed slice into a new vector.
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for &mut [T] {
    /// Clones every element of the mutably borrowed slice into a new vector.
    fn to_vec(self) -> Vec<T> {
        Vec::from(&*self)
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for Cow<'_, [T]> {
    /// Reuses the vector when the `Cow` is owned; clones the slice otherwise.
    fn to_vec(self) -> Vec<T> {
        match self {
            Cow::Borrowed(slice) => <[T]>::to_vec(slice),
            Cow::Owned(owned) => owned,
        }
    }
}
55
56impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
57 #[must_use]
60 pub fn default(width: usize, height: usize) -> Self {
61 Self::new(vec![T::default(); width * height], width)
62 }
63}
64
65impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
    /// Wraps `values` as a matrix whose rows are `width` entries long.
    ///
    /// Debug builds assert that the buffer length is a multiple of `width`; a
    /// `width` of zero is accepted for any buffer. Release builds skip the check.
    #[must_use]
    pub fn new(values: S, width: usize) -> Self {
        debug_assert!(width == 0 || values.borrow().len() % width == 0);
        Self {
            values,
            width,
            _phantom: PhantomData,
        }
    }
75
76 #[must_use]
77 pub fn new_row(values: S) -> Self {
78 let width = values.borrow().len();
79 Self::new(values, width)
80 }
81
82 #[must_use]
83 pub fn new_col(values: S) -> Self {
84 Self::new(values, 1)
85 }
86
87 pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
88 RowMajorMatrixView::new(self.values.borrow(), self.width)
89 }
90
91 pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
92 where
93 S: BorrowMut<[T]>,
94 {
95 RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
96 }
97
98 pub fn copy_from<S2>(&mut self, source: &DenseMatrix<T, S2>)
99 where
100 T: Copy,
101 S: BorrowMut<[T]>,
102 S2: DenseStorage<T>,
103 {
104 assert_eq!(self.dimensions(), source.dimensions());
105 self.par_rows_mut()
108 .zip(source.par_row_slices())
109 .for_each(|(dst, src)| {
110 dst.copy_from_slice(src);
111 });
112 }
113
114 pub fn flatten_to_base<F: Field>(&self) -> RowMajorMatrix<F>
115 where
116 T: ExtensionField<F>,
117 {
118 let width = self.width * T::D;
119 let values = self
120 .values
121 .borrow()
122 .iter()
123 .flat_map(|x| x.as_base_slice().iter().copied())
124 .collect();
125 RowMajorMatrix::new(values, width)
126 }
127
128 pub fn row_slices(&self) -> impl Iterator<Item = &[T]> {
129 self.values.borrow().chunks_exact(self.width)
130 }
131
132 pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
133 where
134 T: Sync,
135 {
136 self.values.borrow().par_chunks_exact(self.width)
137 }
138
139 pub fn row_mut(&mut self, r: usize) -> &mut [T]
140 where
141 S: BorrowMut<[T]>,
142 {
143 &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
144 }
145
146 pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
147 where
148 S: BorrowMut<[T]>,
149 {
150 self.values.borrow_mut().chunks_exact_mut(self.width)
151 }
152
153 pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
154 where
155 T: 'a + Send,
156 S: BorrowMut<[T]>,
157 {
158 self.values.borrow_mut().par_chunks_exact_mut(self.width)
159 }
160
161 pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
162 where
163 P: PackedValue<Value = T>,
164 S: BorrowMut<[T]>,
165 {
166 P::pack_slice_with_suffix_mut(self.row_mut(r))
167 }
168
169 pub fn scale_row(&mut self, r: usize, scale: T)
170 where
171 T: Field,
172 S: BorrowMut<[T]>,
173 {
174 scale_slice_in_place(scale, self.row_mut(r));
175 }
176
177 pub fn scale(&mut self, scale: T)
178 where
179 T: Field,
180 S: BorrowMut<[T]>,
181 {
182 scale_slice_in_place(scale, self.values.borrow_mut());
183 }
184
185 pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<T>, RowMajorMatrixView<T>) {
186 let (lo, hi) = self.values.borrow().split_at(r * self.width);
187 (
188 DenseMatrix::new(lo, self.width),
189 DenseMatrix::new(hi, self.width),
190 )
191 }
192
193 pub fn split_rows_mut(
194 &mut self,
195 r: usize,
196 ) -> (RowMajorMatrixViewMut<T>, RowMajorMatrixViewMut<T>)
197 where
198 S: BorrowMut<[T]>,
199 {
200 let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
201 (
202 DenseMatrix::new(lo, self.width),
203 DenseMatrix::new(hi, self.width),
204 )
205 }
206
207 pub fn par_row_chunks(
208 &self,
209 chunk_rows: usize,
210 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
211 where
212 T: Send,
213 {
214 self.values
215 .borrow()
216 .par_chunks(self.width * chunk_rows)
217 .map(|slice| RowMajorMatrixView::new(slice, self.width))
218 }
219
220 pub fn par_row_chunks_exact(
221 &self,
222 chunk_rows: usize,
223 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
224 where
225 T: Send,
226 {
227 self.values
228 .borrow()
229 .par_chunks_exact(self.width * chunk_rows)
230 .map(|slice| RowMajorMatrixView::new(slice, self.width))
231 }
232
233 pub fn par_row_chunks_mut(
234 &mut self,
235 chunk_rows: usize,
236 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
237 where
238 T: Send,
239 S: BorrowMut<[T]>,
240 {
241 self.values
242 .borrow_mut()
243 .par_chunks_mut(self.width * chunk_rows)
244 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
245 }
246
247 pub fn row_chunks_exact_mut(
248 &mut self,
249 chunk_rows: usize,
250 ) -> impl Iterator<Item = RowMajorMatrixViewMut<T>>
251 where
252 T: Send,
253 S: BorrowMut<[T]>,
254 {
255 self.values
256 .borrow_mut()
257 .chunks_exact_mut(self.width * chunk_rows)
258 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
259 }
260
261 pub fn par_row_chunks_exact_mut(
262 &mut self,
263 chunk_rows: usize,
264 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
265 where
266 T: Send,
267 S: BorrowMut<[T]>,
268 {
269 self.values
270 .borrow_mut()
271 .par_chunks_exact_mut(self.width * chunk_rows)
272 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
273 }
274
275 pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
276 where
277 S: BorrowMut<[T]>,
278 {
279 debug_assert_ne!(row_1, row_2);
280 let start_1 = row_1 * self.width;
281 let start_2 = row_2 * self.width;
282 let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
283 (&mut lo[start_1..][..self.width], &mut hi[..self.width])
284 }
285
286 #[allow(clippy::type_complexity)]
287 pub fn packed_row_pair_mut<P>(
288 &mut self,
289 row_1: usize,
290 row_2: usize,
291 ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
292 where
293 S: BorrowMut<[T]>,
294 P: PackedValue<Value = T>,
295 {
296 let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
297 (
298 P::pack_slice_with_suffix_mut(slice_1),
299 P::pack_slice_with_suffix_mut(slice_2),
300 )
301 }
302
303 #[instrument(level = "debug", skip_all)]
306 pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
307 where
308 T: Field,
309 {
310 if added_bits == 0 {
311 return self.to_row_major_matrix();
312 }
313
314 let w = self.width;
324 let mut padded =
325 RowMajorMatrix::new(T::zero_vec(self.values.borrow().len() << added_bits), w);
326 padded
327 .par_row_chunks_exact_mut(1 << added_bits)
328 .zip(self.par_row_slices())
329 .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
330
331 padded
332 }
333}
334
335impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
336 #[inline]
337 fn width(&self) -> usize {
338 self.width
339 }
340
341 #[inline]
342 fn height(&self) -> usize {
343 if self.width == 0 {
344 0
345 } else {
346 self.values.borrow().len() / self.width
347 }
348 }
349
350 #[inline]
351 fn get(&self, r: usize, c: usize) -> T {
352 self.values.borrow()[r * self.width + c].clone()
353 }
354
355 type Row<'a>
356 = iter::Cloned<slice::Iter<'a, T>>
357 where
358 Self: 'a;
359
360 #[inline]
361 fn row(&self, r: usize) -> Self::Row<'_> {
362 self.values.borrow()[r * self.width..(r + 1) * self.width]
363 .iter()
364 .cloned()
365 }
366
367 #[inline]
368 fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
369 &self.values.borrow()[r * self.width..(r + 1) * self.width]
370 }
371
372 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
373 where
374 Self: Sized,
375 T: Clone,
376 {
377 RowMajorMatrix::new(self.values.to_vec(), self.width)
378 }
379
380 #[inline]
381 fn horizontally_packed_row<'a, P>(
382 &'a self,
383 r: usize,
384 ) -> (
385 impl Iterator<Item = P> + Send + Sync,
386 impl Iterator<Item = T> + Send + Sync,
387 )
388 where
389 P: PackedValue<Value = T>,
390 T: Clone + 'a,
391 {
392 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
393 let (packed, sfx) = P::pack_slice_with_suffix(buf);
394 (packed.iter().cloned(), sfx.iter().cloned())
395 }
396
397 #[inline]
398 fn padded_horizontally_packed_row<'a, P>(
399 &'a self,
400 r: usize,
401 ) -> impl Iterator<Item = P> + Send + Sync
402 where
403 P: PackedValue<Value = T>,
404 T: Clone + Default + 'a,
405 {
406 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
407 let (packed, sfx) = P::pack_slice_with_suffix(buf);
408 packed.iter().cloned().chain(iter::once(P::from_fn(|i| {
409 sfx.get(i).cloned().unwrap_or_default()
410 })))
411 }
412}
413
414impl<T: Clone + Default + Send + Sync> DenseMatrix<T, Vec<T>> {
415 pub fn as_cow<'a>(self) -> RowMajorMatrixCow<'a, T> {
416 RowMajorMatrixCow::new(Cow::Owned(self.values), self.width)
417 }
418
419 pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
420 where
421 Standard: Distribution<T>,
422 {
423 let values = rng.sample_iter(Standard).take(rows * cols).collect();
424 Self::new(values, cols)
425 }
426
427 pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
428 where
429 T: Field,
430 Standard: Distribution<T>,
431 {
432 let values = rng
433 .sample_iter(Standard)
434 .filter(|x| !x.is_zero())
435 .take(rows * cols)
436 .collect();
437 Self::new(values, cols)
438 }
439
440 pub fn pad_to_height(&mut self, new_height: usize, fill: T) {
441 assert!(new_height >= self.height());
442 self.values.resize(self.width * new_height, fill);
443 }
444}
445
446impl<T: Copy + Default + Send + Sync> DenseMatrix<T, Vec<T>> {
447 pub fn transpose(&self) -> Self {
448 let nelts = self.height() * self.width();
449 let mut values = vec![T::default(); nelts];
450 transpose::transpose(&self.values, &mut values, self.width(), self.height());
451 Self::new(values, self.height())
452 }
453
454 pub fn transpose_into(&self, other: &mut Self) {
455 assert_eq!(self.height(), other.width());
456 assert_eq!(other.height(), self.width());
457 transpose::transpose(&self.values, &mut other.values, self.width(), self.height());
458 }
459}
460
461impl<'a, T: Clone + Default + Send + Sync> DenseMatrix<T, &'a [T]> {
462 pub fn as_cow(self) -> RowMajorMatrixCow<'a, T> {
463 RowMajorMatrixCow::new(Cow::Borrowed(self.values), self.width)
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matrix with `width` columns holding `1..=width * height` in
    /// row-major order.
    fn sequential_matrix(width: usize, height: usize) -> RowMajorMatrix<usize> {
        RowMajorMatrix::new((1..=width * height).collect(), width)
    }

    /// Asserts that `transposed` has swapped dimensions and that every entry
    /// `(row, col)` of `original` appears at `(col, row)` of `transposed`.
    fn assert_is_transpose(
        original: &RowMajorMatrix<usize>,
        transposed: &RowMajorMatrix<usize>,
    ) {
        let (width, height) = (original.width(), original.height());
        assert_eq!(transposed.width(), height);
        assert_eq!(transposed.height(), width);

        for col_index in 0..width {
            for row_index in 0..height {
                assert_eq!(
                    original.values[row_index * width + col_index],
                    transposed.values[col_index * height + row_index]
                );
            }
        }
    }

    #[test]
    fn test_transpose_square_matrix() {
        const WIDTH: usize = 3;
        const HEIGHT: usize = 3;

        let transposed = sequential_matrix(WIDTH, HEIGHT).transpose();
        let expected = RowMajorMatrix::new(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], HEIGHT);
        assert_eq!(transposed, expected);
    }

    #[test]
    fn test_transpose_row_matrix() {
        // Despite the name, the input is a single column (width 1, 30 rows);
        // its transpose is a single row holding the same values in order.
        const WIDTH: usize = 1;
        const HEIGHT: usize = 30;

        let matrix = sequential_matrix(WIDTH, HEIGHT);
        let transposed = matrix.transpose();
        let expected = RowMajorMatrix::new(matrix.values, HEIGHT);
        assert_eq!(transposed, expected);
    }

    #[test]
    fn test_transpose_rectangular_matrix() {
        const WIDTH: usize = 5;
        const HEIGHT: usize = 6;

        let transposed = sequential_matrix(WIDTH, HEIGHT).transpose();
        let expected_values = vec![
            1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19, 24, 29,
            5, 10, 15, 20, 25, 30,
        ];
        let expected = RowMajorMatrix::new(expected_values, HEIGHT);
        assert_eq!(transposed, expected);
    }

    #[test]
    fn test_transpose_larger_rectangular_matrix() {
        const WIDTH: usize = 256;
        const HEIGHT: usize = 512;

        let matrix = sequential_matrix(WIDTH, HEIGHT);
        // `transpose` borrows the matrix, so no clone is needed to keep it around.
        let transposed = matrix.transpose();
        assert_is_transpose(&matrix, &transposed);
    }

    #[test]
    fn test_transpose_very_large_rectangular_matrix() {
        const WIDTH: usize = 1024;
        const HEIGHT: usize = 1024;

        let matrix = sequential_matrix(WIDTH, HEIGHT);
        // `transpose` borrows the matrix, so no clone is needed to keep it around.
        let transposed = matrix.transpose();
        assert_is_transpose(&matrix, &transposed);
    }
}