p3_util/
transpose.rs

1use core::ptr::swap;
2
/// Base-2 logarithm of the largest block side length handled by the non-recursive base cases.
///
/// Blocks whose side is at most `1 << LB_BLOCK_SIZE` (i.e. 8) elements are transposed by the
/// element-by-element `*_small` helpers; larger blocks are split into quadrants first.
const LB_BLOCK_SIZE: usize = 3;

/// Transpose a small square matrix in place.
///
/// The matrix is `1 << lb_size` by `1 << lb_size` and sits on the diagonal of a larger row-major
/// matrix whose rows are `1 << lb_stride` elements long. Its elements are
/// `M[i, j] == arr[((i + x) << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`.
/// The transposition swaps `M[i, j]` and `M[j, i]`; diagonal elements are left untouched.
///
/// # Safety
///
/// The caller must ensure that `((i + x) << lb_stride) + j + x` is a valid index in `arr` for all
/// `0 <= i, j < 1 << lb_size`, and that `lb_size <= lb_stride` so the rows of the block do not
/// overlap.
unsafe fn transpose_in_place_square_small<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
) {
    // Derive a single raw base pointer up front. Calling `arr.get_unchecked_mut` twice per swap
    // reborrows the entire slice as `&mut [T]` for each call; under the Stacked Borrows aliasing
    // model the second reborrow invalidates the pointer produced by the first one before `swap`
    // dereferences it. Offsetting one base pointer involves no intervening reborrows.
    let base = arr.as_mut_ptr();
    for i in x + 1..x + (1 << lb_size) {
        // Only visit the strict lower triangle (`j < i`) so each off-diagonal pair is swapped once.
        for j in x..i {
            // SAFETY: `i` and `j` both lie in `x..x + (1 << lb_size)`, so by the caller's contract
            // both offsets are in bounds of `arr`, and `base` was derived from `arr` directly.
            unsafe {
                swap(base.add(i + (j << lb_stride)), base.add((i << lb_stride) + j));
            }
        }
    }
}

/// Transpose two small square blocks into each other's position.
///
/// Both blocks are `1 << lb_size` by `1 << lb_size` inside a row-major matrix whose rows are
/// `1 << lb_stride` elements long. Block `A` starts at row `x`, column `y`; block `B` is its
/// mirror image across the main diagonal, starting at row `y`, column `x`:
///
/// * `A[i, j] == arr[((i + x) << lb_stride) + j + y]`
/// * `B[i, j] == arr[((i + y) << lb_stride) + j + x]`
///
/// for `0 <= i, j < 1 << lb_size`. The function swaps `A[i, j]` with `B[j, i]`, so afterwards
/// each block holds the transpose of the other's former contents.
///
/// # Safety
///
/// The caller must ensure that `((i + x) << lb_stride) + j + y` and `((i + y) << lb_stride) + j + x`
/// are valid indices in `arr` for all `0 <= i, j < 1 << lb_size`, and that
/// `lb_size <= lb_stride` so the rows of a block do not overlap.
///
/// For a correct result (not for memory safety) the ranges `x..x + (1 << lb_size)` and
/// `y..y + (1 << lb_size)` must also be disjoint; otherwise some element pairs are swapped twice
/// and end up back where they started.
unsafe fn transpose_swap_square_small<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
    y: usize,
) {
    // Derive a single raw base pointer up front. Calling `arr.get_unchecked_mut` twice per swap
    // reborrows the entire slice as `&mut [T]` for each call; under the Stacked Borrows aliasing
    // model the second reborrow invalidates the pointer produced by the first one before `swap`
    // dereferences it. Offsetting one base pointer involves no intervening reborrows.
    let base = arr.as_mut_ptr();
    for i in x..x + (1 << lb_size) {
        for j in y..y + (1 << lb_size) {
            // SAFETY: `i` lies in `x..x + (1 << lb_size)` and `j` in `y..y + (1 << lb_size)`, so
            // by the caller's contract both offsets are in bounds of `arr`.
            unsafe {
                swap(base.add(i + (j << lb_stride)), base.add((i << lb_stride) + j));
            }
        }
    }
}

55/// Transpose square matrices and swap
56/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy
57/// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]`
58/// for `0 <= i, j < 1 << lb_size. The transposition swaps `M0[i, j]` and `M1[j, i]`.
59///
60/// SAFETY:
61/// Make sure that `(i + x << lb_stride) + j + y` and `i + x + (j + y << lb_stride)` are valid
62/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to
63/// prevent overlap.
64unsafe fn transpose_swap_square<T>(
65    arr: &mut [T],
66    lb_stride: usize,
67    lb_size: usize,
68    x: usize,
69    y: usize,
70) {
71    if lb_size <= LB_BLOCK_SIZE {
72        transpose_swap_square_small(arr, lb_stride, lb_size, x, y);
73    } else {
74        let lb_block_size = lb_size - 1;
75        let block_size = 1 << lb_block_size;
76        transpose_swap_square(arr, lb_stride, lb_block_size, x, y);
77        transpose_swap_square(arr, lb_stride, lb_block_size, x + block_size, y);
78        transpose_swap_square(arr, lb_stride, lb_block_size, x, y + block_size);
79        transpose_swap_square(
80            arr,
81            lb_stride,
82            lb_block_size,
83            x + block_size,
84            y + block_size,
85        );
86    }
87}
88
89/// Transpose square matrix in-place
90/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies
91/// `M[i, j] == arr[(i + x << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The transposition
92/// swaps `M[i, j]` and `M[j, i]`.
93///
94/// SAFETY:
95/// Make sure that `(i + x << lb_stride) + j + x` is a valid index in `arr` for all
96/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap.
97pub(crate) unsafe fn transpose_in_place_square<T>(
98    arr: &mut [T],
99    lb_stride: usize,
100    lb_size: usize,
101    x: usize,
102) {
103    if lb_size <= LB_BLOCK_SIZE {
104        transpose_in_place_square_small(arr, lb_stride, lb_size, x);
105    } else {
106        let lb_block_size = lb_size - 1;
107        let block_size = 1 << lb_block_size;
108        transpose_in_place_square(arr, lb_stride, lb_block_size, x);
109        transpose_swap_square(arr, lb_stride, lb_block_size, x, x + block_size);
110        transpose_in_place_square(arr, lb_stride, lb_block_size, x + block_size);
111    }
112}