ruint/algorithms/gcd/
mod.rs

1#![allow(clippy::module_name_repetitions)]
2
3// TODO: https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md
4
5// TODO: Make these algorithms work on limb slices.
6mod matrix;
7
8pub use self::matrix::Matrix as LehmerMatrix;
9use crate::Uint;
10use core::mem::swap;
11
12/// ⚠️ Lehmer's GCD algorithms.
13///
14/// **Warning.** This struct is not part of the stable API.
15///
16/// See [`gcd_extended`] for documentation.
17#[inline]
18#[must_use]
19pub fn gcd<const BITS: usize, const LIMBS: usize>(
20    mut a: Uint<BITS, LIMBS>,
21    mut b: Uint<BITS, LIMBS>,
22) -> Uint<BITS, LIMBS> {
23    if b > a {
24        swap(&mut a, &mut b);
25    }
26    while b != Uint::ZERO {
27        debug_assert!(a >= b);
28        let m = LehmerMatrix::from(a, b);
29        if m == LehmerMatrix::IDENTITY {
30            // Lehmer step failed to find a factor, which happens when
31            // the factor is very large. We do a regular Euclidean step, which
32            // will make a lot of progress since `q` will be large.
33            a %= b;
34            swap(&mut a, &mut b);
35        } else {
36            m.apply(&mut a, &mut b);
37        }
38    }
39    a
40}
41
42/// ⚠️ Lehmer's extended GCD.
43///
44/// **Warning.** This struct is not part of the stable API.
45///
46/// Returns `(gcd, x, y, sign)` such that `gcd = a * x + b * y`.
47///
48/// # Algorithm
49///
50/// A variation of Euclids algorithm where repeated 64-bit approximations are
51/// used to make rapid progress on.
52///
53/// See Jebelean (1994) "A Double-Digit Lehmer-Euclid Algorithm for Finding the
54/// GCD of Long Integers".
55///
56/// The function `lehmer_double` takes two `U256`'s and returns a 64-bit matrix.
57///
58/// The function `lehmer_update` updates state variables using this matrix. If
59/// the matrix makes no progress (because 64 bit precision is not enough) a full
60/// precision Euclid step is done, but this happens rarely.
61///
62/// See also `mpn_gcdext_lehmer_n` in GMP.
63/// <https://gmplib.org/repo/gmp-6.1/file/tip/mpn/generic/gcdext_lehmer.c#l146>
64#[inline]
65#[must_use]
66pub fn gcd_extended<const BITS: usize, const LIMBS: usize>(
67    mut a: Uint<BITS, LIMBS>,
68    mut b: Uint<BITS, LIMBS>,
69) -> (
70    Uint<BITS, LIMBS>,
71    Uint<BITS, LIMBS>,
72    Uint<BITS, LIMBS>,
73    bool,
74) {
75    if BITS == 0 {
76        return (Uint::ZERO, Uint::ZERO, Uint::ZERO, false);
77    }
78    let swapped = a < b;
79    if swapped {
80        swap(&mut a, &mut b);
81    }
82
83    // Initialize state matrix to identity.
84    let mut s0 = Uint::from(1);
85    let mut s1 = Uint::ZERO;
86    let mut t0 = Uint::ZERO;
87    let mut t1 = Uint::from(1);
88    let mut even = true;
89    while b != Uint::ZERO {
90        debug_assert!(a >= b);
91        let m = LehmerMatrix::from(a, b);
92        if m == LehmerMatrix::IDENTITY {
93            // Lehmer step failed to find a factor, which happens when
94            // the factor is very large. We do a regular Euclidean step, which
95            // will make a lot of progress since `q` will be large.
96            let q = a / b;
97            a -= q * b;
98            swap(&mut a, &mut b);
99            s0 -= q * s1;
100            swap(&mut s0, &mut s1);
101            t0 -= q * t1;
102            swap(&mut t0, &mut t1);
103            even = !even;
104        } else {
105            m.apply(&mut a, &mut b);
106            m.apply(&mut s0, &mut s1);
107            m.apply(&mut t0, &mut t1);
108            even ^= !m.4;
109        }
110    }
111    // TODO: Compute using absolute value instead of patching sign.
112    if even {
113        // t negative
114        t0 = Uint::ZERO - t0;
115    } else {
116        // s negative
117        s0 = Uint::ZERO - s0;
118    }
119    if swapped {
120        swap(&mut s0, &mut t0);
121        even = !even;
122    }
123    (a, s0, t0, even)
124}
125
126/// ⚠️ Modular inversion using extended GCD.
127///
128/// It uses the Bezout identity
129///
130/// ```text
131///    a * modulus + b * num = gcd(modulus, num)
132/// ````
133///
134/// where `a` and `b` are the cofactors from the extended Euclidean algorithm.
135/// A modular inverse only exists if `modulus` and `num` are coprime, in which
136/// case `gcd(modulus, num)` is one. Reducing both sides by the modulus then
137/// results in the equation `b * num = 1 (mod modulus)`. In other words, the
138/// cofactor `b` is the modular inverse of `num`.
139///
140/// It differs from `gcd_extended` in that it only computes the required
141/// cofactor, and returns `None` if the GCD is not one (i.e. when `num` does
142/// not have an inverse).
143#[inline]
144#[must_use]
145pub fn inv_mod<const BITS: usize, const LIMBS: usize>(
146    num: Uint<BITS, LIMBS>,
147    modulus: Uint<BITS, LIMBS>,
148) -> Option<Uint<BITS, LIMBS>> {
149    if BITS == 0 || modulus == Uint::ZERO {
150        return None;
151    }
152    let mut a = modulus;
153    let mut b = num;
154    if b >= a {
155        b %= a;
156    }
157    if b == Uint::ZERO {
158        return None;
159    }
160
161    let mut t0 = Uint::ZERO;
162    let mut t1 = Uint::from(1);
163    let mut even = true;
164    while b != Uint::ZERO {
165        debug_assert!(a >= b);
166        let m = LehmerMatrix::from(a, b);
167        if m == LehmerMatrix::IDENTITY {
168            // Lehmer step failed to find a factor, which happens when
169            // the factor is very large. We do a regular Euclidean step, which
170            // will make a lot of progress since `q` will be large.
171            let q = a / b;
172            a -= q * b;
173            swap(&mut a, &mut b);
174            t0 -= q * t1;
175            swap(&mut t0, &mut t1);
176            even = !even;
177        } else {
178            m.apply(&mut a, &mut b);
179            m.apply(&mut t0, &mut t1);
180            even ^= !m.4;
181        }
182    }
183    if a == Uint::from(1) {
184        // When `even` t0 is negative and in twos-complement form
185        Some(if even { modulus + t0 } else { t0 })
186    } else {
187        None
188    }
189}
190
#[cfg(test)]
#[allow(clippy::cast_lossless)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use proptest::{proptest, test_runner::Config};

    #[test]
    fn test_gcd_one() {
        use core::str::FromStr;
        const BITS: usize = 129;
        const LIMBS: usize = nlimbs(BITS);
        type U = Uint<BITS, LIMBS>;
        let a = U::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap();
        let b = U::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap();
        assert_eq!(gcd(a, b), gcd_ref(a, b));
    }

    /// Reference implementation: the plain Euclidean algorithm, used as the
    /// oracle for the Lehmer-based functions above.
    fn gcd_ref<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        while b != Uint::ZERO {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)] // Generated code
    fn test_gcd() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config { cases: 10, ..Default::default()};
            proptest!(config, |(a: U, b: U)| {
                assert_eq!(gcd(a, b), gcd_ref(a, b));
            });
        });
    }

    #[test]
    fn test_gcd_extended() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config { cases: 5, ..Default::default() };
            proptest!(config, |(a: U, b: U)| {
                let (g, x, y, sign) = gcd_extended(a, b);
                assert_eq!(g, gcd_ref(a, b));
                // `x` and `y` are magnitudes; `sign` says how they combine.
                if sign {
                    assert_eq!(a * x - b * y, g);
                } else {
                    assert_eq!(b * y - a * x, g);
                }
            });
        });
    }

    #[test]
    fn test_inv_mod() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let config = Config { cases: 5, ..Default::default() };
            proptest!(config, |(num: U, modulus: U)| {
                let g = gcd_ref(num, modulus);
                match inv_mod(num, modulus) {
                    Some(inv) => {
                        // An inverse can only exist for coprime arguments, and
                        // it must actually invert `num` modulo `modulus`.
                        assert_eq!(g, U::from(1));
                        assert_eq!(num.mul_mod(inv, modulus), U::from(1));
                    }
                    None => {
                        // `None` is only allowed when no inverse exists, or for
                        // the zero-width type and the moduli 0 and 1.
                        assert!(BITS == 0 || modulus <= U::from(1) || g != U::from(1));
                    }
                }
            });
        });
    }
}