12use strength_reduce::StrengthReducedUsize;
3use num_integer;
/// Computes the multiplicative inverse of `a` modulo `n`: the unique `x` in `0..n` such that
/// `(a * x) % n == 1 % n`.
///
/// This is the extended Euclidean algorithm, tracking only the Bezout coefficient of `a`
/// (the coefficient of `n` is never needed).
///
/// `n` must be non-zero and `a` must be coprime to `n`; otherwise the result is meaningless.
/// Both preconditions are checked with debug assertions only.
fn multiplicative_inverse(a: usize, n: usize) -> usize {
    debug_assert!(n > 0, "modulus must be non-zero");

    // The coefficients alternate in sign, so track them as signed integers and reduce once at
    // the end. Bezout coefficients never exceed `n` in magnitude, so i128 holds every
    // intermediate value for any pair of `usize` inputs.
    let mut t: i128 = 0;
    let mut t_new: i128 = 1;

    let mut r = n;
    let mut r_new = a;

    while r_new > 0 {
        let quotient = r / r_new;

        let r_next = r % r_new;
        r = r_new;
        r_new = r_next;

        let t_next = t - quotient as i128 * t_new;
        t = t_new;
        t_new = t_next;
    }

    // `r` now holds gcd(a, n), which must be 1 for an inverse to exist.
    debug_assert_eq!(r, 1, "a must be coprime to n");

    // Map the (possibly negative) coefficient into 0..n. Reducing here also covers n == 1,
    // where the only residue is 0.
    let modulus = n as i128;
    let reduced = t % modulus;
    if reduced < 0 {
        (reduced + modulus) as usize
    } else {
        reduced as usize
    }
}
3536/// Transpose the input array in-place.
37///
38/// Given an input array of size input_width * input_height, representing flattened 2D data stored in row-major order,
39/// transpose the rows and columns of that input array, in-place.
40///
41/// Despite being in-place, this algorithm requires max(width, height) in scratch space.
42///
43/// ```
44/// // row-major order: the rows of our 2D array are contiguous,
45/// // and the columns are strided
46/// let original_array = vec![ 1, 2, 3,
47/// 4, 5, 6];
48/// let mut input_array = original_array.clone();
49///
50/// // Treat our 6-element array as a 2D 3x2 array, and transpose it to a 2x3 array
51/// // transpose_inplace requires max(width, height) scratch space, which is in this case 3
52/// let mut scratch = vec![0; 3];
53/// transpose::transpose_inplace(&mut input_array, &mut scratch, 3, 2);
54///
55/// // The rows have become the columns, and the columns have become the rows
56/// let expected_array = vec![ 1, 4,
57/// 2, 5,
58/// 3, 6];
59/// assert_eq!(input_array, expected_array);
60///
61/// // If we transpose it again, we should get our original data back.
62/// transpose::transpose_inplace(&mut input_array, &mut scratch, 2, 3);
63/// assert_eq!(original_array, input_array);
64/// ```
65///
66/// # Panics
67///
68/// Panics if `input.len() != input_width * input_height` or if `scratch.len() != max(width, height)`
69pub fn transpose_inplace<T: Copy>(buffer: &mut [T], scratch: &mut [T], width: usize, height: usize) {
70assert_eq!(width.checked_mul(height), Some(buffer.len()));
71assert_eq!(core::cmp::max(width, height), scratch.len());
7273let gcd = StrengthReducedUsize::new(num_integer::gcd(width, height));
74let a = StrengthReducedUsize::new(height / gcd);
75let b = StrengthReducedUsize::new(width / gcd);
76let a_inverse = multiplicative_inverse(a.get(), b.get());
77let strength_reduced_height = StrengthReducedUsize::new(height);
7879let index_fn = |x, y| x + y * width;
8081if gcd.get() > 1 {
82for x in 0..width {
83let column_offset = (x / b) % strength_reduced_height;
84let wrapping_point = height - column_offset;
8586// wrapped rotation -- do the "right half" of the array, then the "left half"
87for y in 0..wrapping_point {
88 scratch[y] = buffer[index_fn(x, y + column_offset)];
89 }
90for y in wrapping_point..height {
91 scratch[y] = buffer[index_fn(x, y + column_offset - height)];
92 }
9394// copy the data back into the column
95for y in 0..height {
96 buffer[index_fn(x, y)] = scratch[y];
97 }
98 }
99 }
100101// Permute the rows
102{
103let row_scratch = &mut scratch[0..width];
104105for (y, row) in buffer.chunks_exact_mut(width).enumerate() {
106for x in 0..width {
107let helper_val = if y <= height + x%gcd - gcd.get() { x + y*(width-1) } else { x + y*(width-1) + height };
108let (helper_div, helper_mod) = StrengthReducedUsize::div_rem(helper_val, gcd);
109110let gather_x = (a_inverse * helper_div)%b + b.get()*helper_mod;
111 row_scratch[x] = row[gather_x];
112 }
113114 row.copy_from_slice(row_scratch);
115 }
116 }
117118// Permute the columns
119for x in 0..width {
120let column_offset = x % strength_reduced_height;
121let wrapping_point = height - column_offset;
122123// wrapped rotation -- do the "right half" of the array, then the "left half"
124for y in 0..wrapping_point {
125 scratch[y] = buffer[index_fn(x, y + column_offset)];
126 }
127for y in wrapping_point..height {
128 scratch[y] = buffer[index_fn(x, y + column_offset - height)];
129 }
130131// Copy the data back to the buffer, but shuffle it as we do so
132for y in 0..height {
133let shuffled_y = (y * width - (y / a)) % strength_reduced_height;
134 buffer[index_fn(x, y)] = scratch[shuffled_y];
135 }
136 }
137}