ruint/algorithms/gcd/matrix.rs

#![allow(clippy::use_self)]

use crate::Uint;

/// ⚠️ Lehmer update matrix
///
/// **Warning.** This struct is not part of the stable API.
///
/// Signs are implicit; the boolean `.4` encodes which of the two sign
/// patterns applies. The signs and layout of the matrix are:
///
/// ```text
///     true          false
///  [ .0  -.1]    [-.0   .1]
///  [-.2   .3]    [ .2  -.3]
/// ```
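///
/// For example, `Matrix(2, 5, 5, 12, false)`, the update matrix computed for
/// the pair `(252, 105)` in the tests below, stands for
///
/// ```text
///  [-2    5]
///  [ 5  -12]
/// ```
///
/// and applying it yields `(-2 * 252 + 5 * 105, 5 * 252 - 12 * 105) = (21, 0)`.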
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Matrix(pub u64, pub u64, pub u64, pub u64, pub bool);

impl Matrix {
    pub const IDENTITY: Self = Self(1, 0, 0, 1, true);

    /// Returns the matrix product `self * other`.
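    ///
    /// The sign flag of the product is `true` exactly when both factors carry
    /// the same flag: the `false` layout is the negation of the `true` layout,
    /// so the implicit signs multiply, and `self.4 ^ !other.4` is that product
    /// expressed on booleans.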
    #[inline]
    #[allow(clippy::suspicious_operation_groupings)]
    #[must_use]
    pub const fn compose(self, other: Self) -> Self {
        Self(
            self.0 * other.0 + self.1 * other.2,
            self.0 * other.1 + self.1 * other.3,
            self.2 * other.0 + self.3 * other.2,
            self.2 * other.1 + self.3 * other.3,
            self.4 ^ !other.4,
        )
    }

    /// Applies the matrix to a `Uint`.
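    ///
    /// Replaces `(a, b)` with `(.0 * a - .1 * b, .3 * b - .2 * a)` when the
    /// flag is `true`, and with the negated pair otherwise; for a valid
    /// Lehmer update matrix both results are non-negative.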
    #[inline]
    pub fn apply<const BITS: usize, const LIMBS: usize>(
        &self,
        a: &mut Uint<BITS, LIMBS>,
        b: &mut Uint<BITS, LIMBS>,
    ) {
        if BITS == 0 {
            return;
        }
        // OPT: We can avoid the temporary if we implement a dedicated matrix
        // multiplication.
        let (c, d) = if self.4 {
            (
                Uint::from(self.0) * *a - Uint::from(self.1) * *b,
                Uint::from(self.3) * *b - Uint::from(self.2) * *a,
            )
        } else {
            (
                Uint::from(self.1) * *b - Uint::from(self.0) * *a,
                Uint::from(self.2) * *a - Uint::from(self.3) * *b,
            )
        };
        *a = c;
        *b = d;
    }

    /// Applies the matrix to a `u128`.
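    ///
    /// A minimal illustration (internal API, mirroring `test_from_u64_example`
    /// below):
    ///
    /// ```rust,ignore
    /// let m = Matrix(2, 5, 5, 12, false);
    /// assert_eq!(m.apply_u128(252, 105), (21, 0));
    /// ```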
    #[inline]
    #[must_use]
    pub const fn apply_u128(&self, a: u128, b: u128) -> (u128, u128) {
        // Intermediate values can overflow but the final result will fit, so we
        // compute mod 2^128.
        if self.4 {
            (
                (self.0 as u128)
                    .wrapping_mul(a)
                    .wrapping_sub((self.1 as u128).wrapping_mul(b)),
                (self.3 as u128)
                    .wrapping_mul(b)
                    .wrapping_sub((self.2 as u128).wrapping_mul(a)),
            )
        } else {
            (
                (self.1 as u128)
                    .wrapping_mul(b)
                    .wrapping_sub((self.0 as u128).wrapping_mul(a)),
                (self.2 as u128)
                    .wrapping_mul(a)
                    .wrapping_sub((self.3 as u128).wrapping_mul(b)),
            )
        }
    }

    /// Compute a Lehmer update matrix from two `Uint`s.
    ///
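    /// Only the 128 most significant bits of `a` and `b` are used; by
    /// Lehmer's observation (see `from_u64_prefix`) a matrix computed from
    /// such a prefix is still a valid update matrix for the full-width values.
    ///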
    /// # Panics
    ///
    /// Panics if `b > a`.
    #[inline]
    #[must_use]
    pub fn from<const BITS: usize, const LIMBS: usize>(
        a: Uint<BITS, LIMBS>,
        b: Uint<BITS, LIMBS>,
    ) -> Self {
        assert!(a >= b);

        // Grab the first 128 bits.
        let s = a.bit_len();
        if s <= 64 {
            Self::from_u64(a.try_into().unwrap(), b.try_into().unwrap())
        } else if s <= 128 {
            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
        } else {
            let a = a >> (s - 128);
            let b = b >> (s - 128);
            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
        }
    }

    /// Compute the Lehmer update matrix for small values.
    ///
    /// This is essentially Euclid's extended GCD algorithm for 64 bits.
    ///
    /// # Panics
    ///
    /// Panics if `r0 < r1`.
    // OPT: Would this be faster using extended binary gcd?
    // See <https://en.algorithmica.org/hpc/algorithms/gcd>
    #[inline]
    #[must_use]
    pub fn from_u64(mut r0: u64, mut r1: u64) -> Self {
        debug_assert!(r0 >= r1);
        if r1 == 0_u64 {
            return Matrix::IDENTITY;
        }
        let mut q00 = 1_u64;
        let mut q01 = 0_u64;
        let mut q10 = 0_u64;
        let mut q11 = 1_u64;
        loop {
            // Loop is unrolled once to avoid swapping variables and tracking parity.
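            // The first half-iteration reduces r0 and updates the (q00, q01)
            // row; the second half reduces r1 and updates (q10, q11). The two
            // return sites hand back the rows in the order that maps
            // (r0, r1) to (gcd, 0), with the sign flag encoding the parity of
            // the number of division steps performed.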
            let q = r0 / r1;
            r0 -= q * r1;
            q00 += q * q10;
            q01 += q * q11;
            if r0 == 0_u64 {
                return Matrix(q10, q11, q00, q01, false);
            }
            let q = r1 / r0;
            r1 -= q * r0;
            q10 += q * q00;
            q11 += q * q01;
            if r1 == 0_u64 {
                return Matrix(q00, q01, q10, q11, true);
            }
        }
    }

    /// Compute the largest valid Lehmer update matrix for a prefix.
    ///
    /// Compute the Lehmer update matrix for `a0` and `a1` such that the
    /// matrix is valid for any two large integers starting with the bits of
    /// `a0` and `a1`.
    ///
    /// See also `mpn_hgcd2` in GMP, but ours handles the double precision bit
    /// separately in `lehmer_double`.
    /// <https://gmplib.org/repo/gmp-6.1/file/tip/mpn/generic/hgcd2.c#l226>
    ///
    /// # Panics
    ///
    /// Panics if `a0` does not have the highest bit set.
    /// Panics if `a0 < a1`.
    #[inline]
    #[must_use]
    #[allow(clippy::redundant_else)]
    #[allow(clippy::cognitive_complexity)] // REFACTOR: Improve
    pub fn from_u64_prefix(a0: u64, mut a1: u64) -> Self {
        const LIMIT: u64 = 1_u64 << 32;
        debug_assert!(a0 >= 1_u64 << 63);
        debug_assert!(a0 >= a1);

        // Here we do something original: the cofactors undergo identical
        // operations, which makes them a candidate for SIMD instructions.
        // They also never exceed 32 bits, so we can SWAR them in a single u64.
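        // Packing: `k = (u << 32) | v`, so the single u64 additions below
        // (e.g. `k0 + q * k1`) update both cofactors at once. This is sound
        // precisely because neither half (nor the intermediate products) ever
        // reaches 2^32, so no carry can cross the 32-bit boundary.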
        let mut k0 = 1_u64 << 32; // u0 = 1, v0 = 0
        let mut k1 = 1_u64; // u1 = 0, v1 = 1
        let mut even = true;
        if a1 < LIMIT {
            return Matrix::IDENTITY;
        }

        // Compute a2
        let q = a0 / a1;
        let mut a2 = a0 - q * a1;
        let mut k2 = k0 + q * k1;
        if a2 < LIMIT {
            let u2 = k2 >> 32;
            let v2 = k2 % LIMIT;

            // Test i + 1 (odd)
            if a2 >= v2 && a1 - a2 >= u2 {
                return Matrix(0, 1, u2, v2, false);
            } else {
                return Matrix::IDENTITY;
            }
        }

        // Compute a3
        let q = a1 / a2;
        let mut a3 = a1 - q * a2;
        let mut k3 = k1 + q * k2;

        // Loop until a3 < LIMIT, maintaining the last three values
        // of a and the last four values of k.
        while a3 >= LIMIT {
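            // Rotate the window: after these assignments `a1` and `a2` hold
            // the previous `a2` and `a3`, and `a3` temporarily holds the
            // previous `a2`, so `a3 -= q * a2` below computes the next
            // remainder in place. The `k` values rotate the same way.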
            a1 = a2;
            a2 = a3;
            a3 = a1;
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
            if a3 < LIMIT {
                even = false;
                break;
            }
            a1 = a2;
            a2 = a3;
            a3 = a1;
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
        }
        // Unpack k into cofactors u and v
        let u0 = k0 >> 32;
        let u1 = k1 >> 32;
        let u2 = k2 >> 32;
        let u3 = k3 >> 32;
        let v0 = k0 % LIMIT;
        let v1 = k1 % LIMIT;
        let v2 = k2 % LIMIT;
        let v3 = k3 % LIMIT;
        debug_assert!(a2 >= LIMIT);
        debug_assert!(a3 < LIMIT);

        // Use Jebelean's exact condition to determine which outputs are correct.
        // Statistically, i + 2 should be correct about two-thirds of the time.
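        // The update at an odd index i is valid when a_i >= v_i and
        // a_{i-1} - a_i >= u_i + u_{i-1}; at an even index the roles of u and
        // v swap. The branches below apply this test at indices i + 1 and i + 2.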
        if even {
            // Test i + 1 (odd)
            debug_assert!(a2 >= v2);
            if a1 - a2 >= u2 + u1 {
                // Test i + 2 (even)
                if a3 >= u3 && a2 - a3 >= v3 + v2 {
                    // Correct value is i + 2
                    Matrix(u2, v2, u3, v3, true)
                } else {
                    // Correct value is i + 1
                    Matrix(u1, v1, u2, v2, false)
                }
            } else {
                // Correct value is i
                Matrix(u0, v0, u1, v1, true)
            }
        } else {
            // Test i + 1 (even)
            debug_assert!(a2 >= u2);
            if a1 - a2 >= v2 + v1 {
                // Test i + 2 (odd)
                if a3 >= v3 && a2 - a3 >= u3 + u2 {
                    // Correct value is i + 2
                    Matrix(u2, v2, u3, v3, false)
                } else {
                    // Correct value is i + 1
                    Matrix(u1, v1, u2, v2, true)
                }
            } else {
                // Correct value is i
                Matrix(u0, v0, u1, v1, false)
            }
        }
    }

    /// Compute the Lehmer update matrix in full 64-bit precision.
    ///
    /// Jebelean solves this by starting in double precision, followed by
    /// single precision once the values are small enough.
    /// Cohen instead runs a single precision round, refreshes the `r0` and
    /// `r1` values and continues with another single precision round on top.
    /// Our approach is similar to Cohen's, but instead of doing the second
    /// round on the same matrix, we start with a fresh matrix and multiply
    /// both at the end. This requires 8 additional multiplications, but
    /// allows us to use the tighter stopping conditions from Jebelean. It
    /// also seems the simplest of these solutions.
    // OPT: We can update r0 and r1 in place. This won't remove the partially
    // redundant call to lehmer_update, but it reduces memory usage.
    #[inline]
    #[must_use]
    pub fn from_u128_prefix(r0: u128, r1: u128) -> Self {
        debug_assert!(r0 >= r1);
        let s = r0.leading_zeros();
        let r0s = r0 << s;
        let r1s = r1 << s;
        let q = Self::from_u64_prefix((r0s >> 64) as u64, (r1s >> 64) as u64);
        if q == Matrix::IDENTITY {
            return q;
        }
        // We can return q here and have a perfectly valid single-word Lehmer GCD.
        q
        // OPT: Fix the below method to get double-word Lehmer GCD.

        // Recompute r0 and r1 and take the high bits.
        // TODO: Is it safe to do this based on just the u128 prefix?
        // let (r0, r1) = q.apply_u128(r0, r1);
        // let s = r0.leading_zeros();
        // let r0s = r0 << s;
        // let r1s = r1 << s;
        // let qn = Self::from_u64_prefix((r0s >> 64) as u64, (r1s >> 64) as u64);

        // // Multiply matrices qn * q
        // qn.compose(q)
    }
}

#[cfg(test)]
#[allow(clippy::cast_lossless)]
#[allow(clippy::many_single_char_names)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::{
        cmp::{max, min},
        mem::swap,
        str::FromStr,
    };
    use proptest::{proptest, test_runner::Config};

    fn gcd(mut a: u128, mut b: u128) -> u128 {
        while b != 0 {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    fn gcd_uint<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        while b != Uint::ZERO {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    #[test]
    fn test_from_u64_example() {
        let (a, b) = (252, 105);
        let m = Matrix::from_u64(a, b);
        assert_eq!(m, Matrix(2, 5, 5, 12, false));
        let (a, b) = m.apply_u128(a as u128, b as u128);
        assert_eq!(a, 21);
        assert_eq!(b, 0);
    }

    #[test]
    fn test_from_u64() {
        proptest!(|(a: u64, b: u64)| {
            let (a, b) = (max(a, b), min(a, b));
            let m = Matrix::from_u64(a, b);
            let (c, d) = m.apply_u128(a as u128, b as u128);
            assert!(c >= d);
            assert_eq!(c, gcd(a as u128, b as u128));
            assert_eq!(d, 0);
        });
    }

    #[test]
    fn test_from_u64_prefix() {
        proptest!(|(a: u128, b: u128)| {
            // Prepare input
            let (a, b) = (max(a, b), min(a, b));
            let s = a.leading_zeros();
            let (sa, sb) = (a << s, b << s);

            let m = Matrix::from_u64_prefix((sa >> 64) as u64, (sb >> 64) as u64);
            let (c, d) = m.apply_u128(a, b);
            assert!(c >= d);
            if m == Matrix::IDENTITY {
                assert_eq!(c, a);
                assert_eq!(d, b);
            } else {
                assert!(c <= a);
                assert!(d < b);
                assert_eq!(gcd(a, b), gcd(c, d));
            }
        });
    }

    fn test_from_uint_one<const BITS: usize, const LIMBS: usize>(
        a: Uint<BITS, LIMBS>,
        b: Uint<BITS, LIMBS>,
    ) {
        let (a, b) = (max(a, b), min(a, b));
        let m = Matrix::from(a, b);
        let (mut c, mut d) = (a, b);
        m.apply(&mut c, &mut d);
        assert!(c >= d);
        if m == Matrix::IDENTITY {
            assert_eq!(c, a);
            assert_eq!(d, b);
        } else {
            assert!(c <= a);
            assert!(d < b);
            assert_eq!(gcd_uint(a, b), gcd_uint(c, d));
        }
    }

    #[test]
    fn test_from_uint_cases() {
        // This case fails with the double-word version above.
        type U129 = Uint<129, 3>;
        test_from_uint_one(
            U129::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap(),
            U129::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap(),
        );
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)] // Generated code
    fn test_from_uint_proptest() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config { cases: 10, ..Default::default() };
            proptest!(config, |(a: U, b: U)| {
                test_from_uint_one(a, b);
            });
        });
    }
454}