p3_matrix/
dense.rs

1use alloc::borrow::Cow;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::borrow::{Borrow, BorrowMut};
5use core::marker::PhantomData;
6use core::ops::Deref;
7use core::{iter, slice};
8
9use p3_field::{scale_slice_in_place, ExtensionField, Field, PackedValue};
10use p3_maybe_rayon::prelude::*;
11use rand::distributions::{Distribution, Standard};
12use rand::Rng;
13use serde::{Deserialize, Serialize};
14use tracing::instrument;
15
16use crate::Matrix;
17
/// A dense matrix stored in row-major form.
///
/// `T` is the element type and `V` is the storage that holds the elements (a `Vec<T>` unless
/// another storage is named). The elements of one row sit next to each other in `values`, and
/// consecutive rows follow one another.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DenseMatrix<T, V = Vec<T>> {
    /// Backing storage containing the elements, one row after another.
    pub values: V,
    /// Number of elements in each row, i.e. the number of columns.
    pub width: usize,
    /// Zero-sized marker that records the element type `T`, which `V` itself need not mention.
    _phantom: PhantomData<T>,
}
25
/// A dense row-major matrix that owns its elements in a `Vec<T>`.
pub type RowMajorMatrix<T> = DenseMatrix<T, Vec<T>>;
/// A dense row-major matrix backed by a shared slice borrowed for `'a`.
pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
/// A dense row-major matrix backed by a mutable slice borrowed for `'a`.
pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
/// A dense row-major matrix backed by a `Cow`, i.e. either a borrowed slice or an owned `Vec<T>`.
pub type RowMajorMatrixCow<'a, T> = DenseMatrix<T, Cow<'a, [T]>>;
30
/// Storage that can back a [`DenseMatrix`]: anything that can be viewed as a `[T]` and shared
/// across threads.
pub trait DenseStorage<T>: Borrow<[T]> + Send + Sync {
    /// Consume the storage and return its elements as an owned `Vec<T>`.
    fn to_vec(self) -> Vec<T>;
}

// `Cow` has no `IntoOwned`-style trait to bound on, so a single blanket impl is not possible;
// each supported storage gets its own impl instead.

/// An owned vector is already in the target representation.
impl<T: Clone + Send + Sync> DenseStorage<T> for Vec<T> {
    fn to_vec(self) -> Vec<T> {
        self
    }
}

/// A shared slice is cloned element by element into a fresh vector.
impl<T: Clone + Send + Sync> DenseStorage<T> for &[T] {
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

/// A mutable slice is cloned element by element into a fresh vector.
impl<T: Clone + Send + Sync> DenseStorage<T> for &mut [T] {
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

/// A `Cow` hands over its vector when it owns one and clones the slice otherwise.
impl<T: Clone + Send + Sync> DenseStorage<T> for Cow<'_, [T]> {
    fn to_vec(self) -> Vec<T> {
        match self {
            Cow::Owned(owned) => owned,
            Cow::Borrowed(borrowed) => <[T]>::to_vec(borrowed),
        }
    }
}
55
56impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
57    /// Create a new dense matrix of the given dimensions, backed by a `Vec`, and filled with
58    /// default values.
59    #[must_use]
60    pub fn default(width: usize, height: usize) -> Self {
61        Self::new(vec![T::default(); width * height], width)
62    }
63}
64
65impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
66    #[must_use]
67    pub fn new(values: S, width: usize) -> Self {
68        debug_assert!(width == 0 || values.borrow().len() % width == 0);
69        Self {
70            values,
71            width,
72            _phantom: PhantomData,
73        }
74    }
75
76    #[must_use]
77    pub fn new_row(values: S) -> Self {
78        let width = values.borrow().len();
79        Self::new(values, width)
80    }
81
82    #[must_use]
83    pub fn new_col(values: S) -> Self {
84        Self::new(values, 1)
85    }
86
87    pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
88        RowMajorMatrixView::new(self.values.borrow(), self.width)
89    }
90
91    pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
92    where
93        S: BorrowMut<[T]>,
94    {
95        RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
96    }
97
98    pub fn copy_from<S2>(&mut self, source: &DenseMatrix<T, S2>)
99    where
100        T: Copy,
101        S: BorrowMut<[T]>,
102        S2: DenseStorage<T>,
103    {
104        assert_eq!(self.dimensions(), source.dimensions());
105        // Equivalent to:
106        // self.values.borrow_mut().copy_from_slice(source.values.borrow());
107        self.par_rows_mut()
108            .zip(source.par_row_slices())
109            .for_each(|(dst, src)| {
110                dst.copy_from_slice(src);
111            });
112    }
113
114    pub fn flatten_to_base<F: Field>(&self) -> RowMajorMatrix<F>
115    where
116        T: ExtensionField<F>,
117    {
118        let width = self.width * T::D;
119        let values = self
120            .values
121            .borrow()
122            .iter()
123            .flat_map(|x| x.as_base_slice().iter().copied())
124            .collect();
125        RowMajorMatrix::new(values, width)
126    }
127
128    pub fn row_slices(&self) -> impl Iterator<Item = &[T]> {
129        self.values.borrow().chunks_exact(self.width)
130    }
131
132    pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
133    where
134        T: Sync,
135    {
136        self.values.borrow().par_chunks_exact(self.width)
137    }
138
139    pub fn row_mut(&mut self, r: usize) -> &mut [T]
140    where
141        S: BorrowMut<[T]>,
142    {
143        &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
144    }
145
146    pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
147    where
148        S: BorrowMut<[T]>,
149    {
150        self.values.borrow_mut().chunks_exact_mut(self.width)
151    }
152
153    pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
154    where
155        T: 'a + Send,
156        S: BorrowMut<[T]>,
157    {
158        self.values.borrow_mut().par_chunks_exact_mut(self.width)
159    }
160
161    pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
162    where
163        P: PackedValue<Value = T>,
164        S: BorrowMut<[T]>,
165    {
166        P::pack_slice_with_suffix_mut(self.row_mut(r))
167    }
168
169    pub fn scale_row(&mut self, r: usize, scale: T)
170    where
171        T: Field,
172        S: BorrowMut<[T]>,
173    {
174        scale_slice_in_place(scale, self.row_mut(r));
175    }
176
177    pub fn scale(&mut self, scale: T)
178    where
179        T: Field,
180        S: BorrowMut<[T]>,
181    {
182        scale_slice_in_place(scale, self.values.borrow_mut());
183    }
184
185    pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<T>, RowMajorMatrixView<T>) {
186        let (lo, hi) = self.values.borrow().split_at(r * self.width);
187        (
188            DenseMatrix::new(lo, self.width),
189            DenseMatrix::new(hi, self.width),
190        )
191    }
192
193    pub fn split_rows_mut(
194        &mut self,
195        r: usize,
196    ) -> (RowMajorMatrixViewMut<T>, RowMajorMatrixViewMut<T>)
197    where
198        S: BorrowMut<[T]>,
199    {
200        let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
201        (
202            DenseMatrix::new(lo, self.width),
203            DenseMatrix::new(hi, self.width),
204        )
205    }
206
207    pub fn par_row_chunks(
208        &self,
209        chunk_rows: usize,
210    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
211    where
212        T: Send,
213    {
214        self.values
215            .borrow()
216            .par_chunks(self.width * chunk_rows)
217            .map(|slice| RowMajorMatrixView::new(slice, self.width))
218    }
219
220    pub fn par_row_chunks_exact(
221        &self,
222        chunk_rows: usize,
223    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
224    where
225        T: Send,
226    {
227        self.values
228            .borrow()
229            .par_chunks_exact(self.width * chunk_rows)
230            .map(|slice| RowMajorMatrixView::new(slice, self.width))
231    }
232
233    pub fn par_row_chunks_mut(
234        &mut self,
235        chunk_rows: usize,
236    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
237    where
238        T: Send,
239        S: BorrowMut<[T]>,
240    {
241        self.values
242            .borrow_mut()
243            .par_chunks_mut(self.width * chunk_rows)
244            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
245    }
246
247    pub fn row_chunks_exact_mut(
248        &mut self,
249        chunk_rows: usize,
250    ) -> impl Iterator<Item = RowMajorMatrixViewMut<T>>
251    where
252        T: Send,
253        S: BorrowMut<[T]>,
254    {
255        self.values
256            .borrow_mut()
257            .chunks_exact_mut(self.width * chunk_rows)
258            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
259    }
260
261    pub fn par_row_chunks_exact_mut(
262        &mut self,
263        chunk_rows: usize,
264    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
265    where
266        T: Send,
267        S: BorrowMut<[T]>,
268    {
269        self.values
270            .borrow_mut()
271            .par_chunks_exact_mut(self.width * chunk_rows)
272            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
273    }
274
275    pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
276    where
277        S: BorrowMut<[T]>,
278    {
279        debug_assert_ne!(row_1, row_2);
280        let start_1 = row_1 * self.width;
281        let start_2 = row_2 * self.width;
282        let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
283        (&mut lo[start_1..][..self.width], &mut hi[..self.width])
284    }
285
286    #[allow(clippy::type_complexity)]
287    pub fn packed_row_pair_mut<P>(
288        &mut self,
289        row_1: usize,
290        row_2: usize,
291    ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
292    where
293        S: BorrowMut<[T]>,
294        P: PackedValue<Value = T>,
295    {
296        let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
297        (
298            P::pack_slice_with_suffix_mut(slice_1),
299            P::pack_slice_with_suffix_mut(slice_2),
300        )
301    }
302
303    /// Append zeros to the "end" of the given matrix, except that the matrix is in bit-reversed order,
304    /// so in actuality we're interleaving zero rows.
305    #[instrument(level = "debug", skip_all)]
306    pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
307    where
308        T: Field,
309    {
310        if added_bits == 0 {
311            return self.to_row_major_matrix();
312        }
313
314        // This is equivalent to:
315        //     reverse_matrix_index_bits(mat);
316        //     mat
317        //         .values
318        //         .resize(mat.values.len() << added_bits, F::ZERO);
319        //     reverse_matrix_index_bits(mat);
320        // But rather than implement it with bit reversals, we directly construct the resulting matrix,
321        // whose rows are zero except for rows whose low `added_bits` bits are zero.
322
323        let w = self.width;
324        let mut padded =
325            RowMajorMatrix::new(T::zero_vec(self.values.borrow().len() << added_bits), w);
326        padded
327            .par_row_chunks_exact_mut(1 << added_bits)
328            .zip(self.par_row_slices())
329            .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
330
331        padded
332    }
333}
334
335impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
336    #[inline]
337    fn width(&self) -> usize {
338        self.width
339    }
340
341    #[inline]
342    fn height(&self) -> usize {
343        if self.width == 0 {
344            0
345        } else {
346            self.values.borrow().len() / self.width
347        }
348    }
349
350    #[inline]
351    fn get(&self, r: usize, c: usize) -> T {
352        self.values.borrow()[r * self.width + c].clone()
353    }
354
355    type Row<'a>
356        = iter::Cloned<slice::Iter<'a, T>>
357    where
358        Self: 'a;
359
360    #[inline]
361    fn row(&self, r: usize) -> Self::Row<'_> {
362        self.values.borrow()[r * self.width..(r + 1) * self.width]
363            .iter()
364            .cloned()
365    }
366
367    #[inline]
368    fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
369        &self.values.borrow()[r * self.width..(r + 1) * self.width]
370    }
371
372    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
373    where
374        Self: Sized,
375        T: Clone,
376    {
377        RowMajorMatrix::new(self.values.to_vec(), self.width)
378    }
379
380    #[inline]
381    fn horizontally_packed_row<'a, P>(
382        &'a self,
383        r: usize,
384    ) -> (
385        impl Iterator<Item = P> + Send + Sync,
386        impl Iterator<Item = T> + Send + Sync,
387    )
388    where
389        P: PackedValue<Value = T>,
390        T: Clone + 'a,
391    {
392        let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
393        let (packed, sfx) = P::pack_slice_with_suffix(buf);
394        (packed.iter().cloned(), sfx.iter().cloned())
395    }
396
397    #[inline]
398    fn padded_horizontally_packed_row<'a, P>(
399        &'a self,
400        r: usize,
401    ) -> impl Iterator<Item = P> + Send + Sync
402    where
403        P: PackedValue<Value = T>,
404        T: Clone + Default + 'a,
405    {
406        let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
407        let (packed, sfx) = P::pack_slice_with_suffix(buf);
408        packed.iter().cloned().chain(iter::once(P::from_fn(|i| {
409            sfx.get(i).cloned().unwrap_or_default()
410        })))
411    }
412}
413
414impl<T: Clone + Default + Send + Sync> DenseMatrix<T, Vec<T>> {
415    pub fn as_cow<'a>(self) -> RowMajorMatrixCow<'a, T> {
416        RowMajorMatrixCow::new(Cow::Owned(self.values), self.width)
417    }
418
419    pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
420    where
421        Standard: Distribution<T>,
422    {
423        let values = rng.sample_iter(Standard).take(rows * cols).collect();
424        Self::new(values, cols)
425    }
426
427    pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
428    where
429        T: Field,
430        Standard: Distribution<T>,
431    {
432        let values = rng
433            .sample_iter(Standard)
434            .filter(|x| !x.is_zero())
435            .take(rows * cols)
436            .collect();
437        Self::new(values, cols)
438    }
439
440    pub fn pad_to_height(&mut self, new_height: usize, fill: T) {
441        assert!(new_height >= self.height());
442        self.values.resize(self.width * new_height, fill);
443    }
444}
445
446impl<T: Copy + Default + Send + Sync> DenseMatrix<T, Vec<T>> {
447    pub fn transpose(&self) -> Self {
448        let nelts = self.height() * self.width();
449        let mut values = vec![T::default(); nelts];
450        transpose::transpose(&self.values, &mut values, self.width(), self.height());
451        Self::new(values, self.height())
452    }
453
454    pub fn transpose_into(&self, other: &mut Self) {
455        assert_eq!(self.height(), other.width());
456        assert_eq!(other.height(), self.width());
457        transpose::transpose(&self.values, &mut other.values, self.width(), self.height());
458    }
459}
460
461impl<'a, T: Clone + Default + Send + Sync> DenseMatrix<T, &'a [T]> {
462    pub fn as_cow(self) -> RowMajorMatrixCow<'a, T> {
463        RowMajorMatrixCow::new(Cow::Borrowed(self.values), self.width)
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every entry `(row, col)` of `matrix` (which is `width` wide and `height`
    /// tall) equals entry `(col, row)` of `transposed`.
    fn assert_entries_transposed(
        matrix: &RowMajorMatrix<usize>,
        transposed: &RowMajorMatrix<usize>,
        width: usize,
        height: usize,
    ) {
        for col_index in 0..width {
            for row_index in 0..height {
                assert_eq!(
                    matrix.values[row_index * width + col_index],
                    transposed.values[col_index * height + row_index]
                );
            }
        }
    }

    /// A 3 x 3 matrix transposes to the expected hand-written values.
    #[test]
    fn test_transpose_square_matrix() {
        const START_INDEX: usize = 1;
        const WIDTH: usize = 3;
        const HEIGHT: usize = 3;
        const VALUE_LEN: usize = WIDTH * HEIGHT;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        let transposed = matrix.transpose();
        let should_be_transposed_values = vec![1, 4, 7, 2, 5, 8, 3, 6, 9];
        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
        assert_eq!(transposed, should_be_transposed);
    }

    /// A single-column matrix (width 1, height 30) transposes to a single row holding the same
    /// values in the same order.
    #[test]
    fn test_transpose_row_matrix() {
        const START_INDEX: usize = 1;
        const WIDTH: usize = 1;
        const HEIGHT: usize = 30;
        const VALUE_LEN: usize = WIDTH * HEIGHT;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values.clone(), WIDTH);
        let transposed = matrix.transpose();
        let should_be_transposed = RowMajorMatrix::new(matrix_values, HEIGHT);
        assert_eq!(transposed, should_be_transposed);
    }

    /// A 6 x 5 (rows x columns) matrix transposes to the expected hand-written values.
    #[test]
    fn test_transpose_rectangular_matrix() {
        const START_INDEX: usize = 1;
        const WIDTH: usize = 5;
        const HEIGHT: usize = 6;
        const VALUE_LEN: usize = WIDTH * HEIGHT;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        let transposed = matrix.transpose();
        let should_be_transposed_values = vec![
            1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19, 24, 29,
            5, 10, 15, 20, 25, 30,
        ];
        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
        assert_eq!(transposed, should_be_transposed);
    }

    /// A 512 x 256 (rows x columns) matrix is checked entry by entry against its transpose.
    #[test]
    fn test_transpose_larger_rectangular_matrix() {
        const START_INDEX: usize = 1;
        const WIDTH: usize = 256;
        const HEIGHT: usize = 512;
        const VALUE_LEN: usize = WIDTH * HEIGHT; // 131072

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        // `transpose` borrows the matrix, so no clone is needed to keep using it afterwards.
        let transposed = matrix.transpose();

        assert_eq!(transposed.width(), HEIGHT);
        assert_eq!(transposed.height(), WIDTH);

        assert_entries_transposed(&matrix, &transposed, WIDTH, HEIGHT);
    }

    /// A 1024 x 1024 matrix is checked entry by entry against its transpose.
    #[test]
    fn test_transpose_very_large_rectangular_matrix() {
        const START_INDEX: usize = 1;
        const WIDTH: usize = 1024;
        const HEIGHT: usize = 1024;
        const VALUE_LEN: usize = WIDTH * HEIGHT; // 1048576

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        // `transpose` borrows the matrix, so no clone is needed to keep using it afterwards.
        let transposed = matrix.transpose();

        assert_eq!(transposed.width(), HEIGHT);
        assert_eq!(transposed.height(), WIDTH);

        assert_entries_transposed(&matrix, &transposed, WIDTH, HEIGHT);
    }
}