p3_util/transpose.rs
1use core::ptr::swap;
2
/// Base-2 logarithm of the largest block side length handled directly by the `*_small`
/// kernels. The recursive transposes stop splitting once `lb_size <= LB_BLOCK_SIZE`, i.e. at
/// blocks of at most `1 << 3 == 8` rows/columns.
const LB_BLOCK_SIZE: usize = 3;
4
/// Transpose a square matrix in place (non-recursive kernel).
///
/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies
/// `M[i, j] == arr[((i + x) << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`, i.e. it is
/// the diagonal block, with top-left corner at row `x`, column `x`, of a row-major matrix whose
/// rows have length `1 << lb_stride`. The transposition swaps `M[i, j]` and `M[j, i]`.
///
/// All element accesses go through one raw pointer taken from `arr` up front. Calling
/// `get_unchecked_mut` twice per swap would reborrow the whole slice mutably a second time,
/// which invalidates the pointer produced by the first call under Stacked Borrows.
///
/// # Safety
///
/// Make sure that `((i + x) << lb_stride) + j + x` is a valid index in `arr` for all
/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap.
unsafe fn transpose_in_place_square_small<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
) {
    let base = arr.as_mut_ptr();
    for i in x + 1..x + (1 << lb_size) {
        // Visit only the strictly lower triangle so each off-diagonal pair is swapped exactly
        // once and the diagonal is left untouched.
        for j in x..i {
            // SAFETY: the caller guarantees both indices are in bounds of `arr`, so both
            // offsets stay inside the slice's allocation. `ptr::swap` permits the two
            // locations to overlap.
            unsafe {
                let below = base.add(i + (j << lb_stride));
                let above = base.add((i << lb_stride) + j);
                swap(below, above);
            }
        }
    }
}
28
/// Transpose two square matrices into each other (non-recursive kernel).
///
/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy
/// `M0[i, j] == arr[((i + x) << lb_stride) + j + y]` and
/// `M1[i, j] == arr[((i + y) << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`, i.e. the
/// blocks at (row `x`, column `y`) and (row `y`, column `x`) of a row-major matrix whose rows
/// have length `1 << lb_stride`. The transposition swaps `M0[i, j]` and `M1[j, i]`.
///
/// All element accesses go through one raw pointer taken from `arr` up front. Calling
/// `get_unchecked_mut` twice per swap would reborrow the whole slice mutably a second time,
/// which invalidates the pointer produced by the first call under Stacked Borrows.
///
/// # Safety
///
/// Make sure that `((i + x) << lb_stride) + j + y` and `((i + y) << lb_stride) + j + x` are
/// valid indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that
/// `lb_size <= lb_stride` to prevent overlap.
unsafe fn transpose_swap_square_small<T>(
    arr: &mut [T],
    lb_stride: usize,
    lb_size: usize,
    x: usize,
    y: usize,
) {
    let base = arr.as_mut_ptr();
    for i in x..x + (1 << lb_size) {
        for j in y..y + (1 << lb_size) {
            // SAFETY: the caller guarantees both indices are in bounds of `arr`, so both
            // offsets stay inside the slice's allocation. `ptr::swap` permits the two
            // locations to overlap.
            unsafe {
                let in_m1 = base.add(i + (j << lb_stride));
                let in_m0 = base.add((i << lb_stride) + j);
                swap(in_m1, in_m0);
            }
        }
    }
}
54
55/// Transpose square matrices and swap
56/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy
57/// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]`
58/// for `0 <= i, j < 1 << lb_size. The transposition swaps `M0[i, j]` and `M1[j, i]`.
59///
60/// SAFETY:
61/// Make sure that `(i + x << lb_stride) + j + y` and `i + x + (j + y << lb_stride)` are valid
62/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to
63/// prevent overlap.
64unsafe fn transpose_swap_square<T>(
65 arr: &mut [T],
66 lb_stride: usize,
67 lb_size: usize,
68 x: usize,
69 y: usize,
70) {
71 if lb_size <= LB_BLOCK_SIZE {
72 transpose_swap_square_small(arr, lb_stride, lb_size, x, y);
73 } else {
74 let lb_block_size = lb_size - 1;
75 let block_size = 1 << lb_block_size;
76 transpose_swap_square(arr, lb_stride, lb_block_size, x, y);
77 transpose_swap_square(arr, lb_stride, lb_block_size, x + block_size, y);
78 transpose_swap_square(arr, lb_stride, lb_block_size, x, y + block_size);
79 transpose_swap_square(
80 arr,
81 lb_stride,
82 lb_block_size,
83 x + block_size,
84 y + block_size,
85 );
86 }
87}
88
89/// Transpose square matrix in-place
90/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies
91/// `M[i, j] == arr[(i + x << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The transposition
92/// swaps `M[i, j]` and `M[j, i]`.
93///
94/// SAFETY:
95/// Make sure that `(i + x << lb_stride) + j + x` is a valid index in `arr` for all
96/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap.
97pub(crate) unsafe fn transpose_in_place_square<T>(
98 arr: &mut [T],
99 lb_stride: usize,
100 lb_size: usize,
101 x: usize,
102) {
103 if lb_size <= LB_BLOCK_SIZE {
104 transpose_in_place_square_small(arr, lb_stride, lb_size, x);
105 } else {
106 let lb_block_size = lb_size - 1;
107 let block_size = 1 << lb_block_size;
108 transpose_in_place_square(arr, lb_stride, lb_block_size, x);
109 transpose_swap_square(arr, lb_stride, lb_block_size, x, x + block_size);
110 transpose_in_place_square(arr, lb_stride, lb_block_size, x + block_size);
111 }
112}