ruint/algorithms/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//! ⚠️ Collection of bignum algorithms.
//!
//! **Warning.** Most functions in this module are currently not considered part
//! of the stable API and may be changed or removed in future minor releases.

#![allow(missing_docs)] // TODO: document algorithms

use core::cmp::Ordering;

mod add;
pub mod div;
mod gcd;
mod mul;
mod mul_redc;
mod ops;
mod shift;

pub use self::{
    add::{adc_n, sbb_n},
    div::div,
    gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix},
    mul::{add_nx1, addmul, addmul_n, addmul_nx1, mul_nx1, submul_nx1},
    mul_redc::{mul_redc, square_redc},
    ops::{adc, sbb},
    shift::{shift_left_small, shift_right_small},
};

/// A double-width integer assembled from two single-width words of type `T`.
///
/// The associated functions build a double word out of single words using
/// widening arithmetic; the methods take a double word apart again.
trait DoubleWord<T>
where
    Self: Sized + Copy,
{
    /// Builds a double word from its `high` and `low` halves.
    fn join(high: T, low: T) -> Self;

    /// Computes `a + b` as a double word.
    fn add(a: T, b: T) -> Self;

    /// Computes `a * b` as a double word.
    fn mul(a: T, b: T) -> Self;

    /// Computes `a * b + c` as a double word.
    fn muladd(a: T, b: T, c: T) -> Self;

    /// Computes `a * b + c + d` as a double word.
    fn muladd2(a: T, b: T, c: T, d: T) -> Self;

    /// Returns the upper half of the double word.
    fn high(self) -> T;

    /// Returns the lower half of the double word.
    fn low(self) -> T;

    /// Returns both halves at once, as a `(low, high)` pair.
    ///
    /// Note the order: it is the reverse of the argument order of
    /// [`DoubleWord::join`].
    fn split(self) -> (T, T);
}

impl DoubleWord<u64> for u128 {
    /// Assembles a 128-bit value with `high` in bits 64..128 and `low` in
    /// bits 0..64.
    #[inline(always)]
    fn join(high: u64, low: u64) -> Self {
        let upper = Self::from(high) << 64;
        let lower = Self::from(low);
        upper | lower
    }

    /// Widening addition: returns `a + b` as a 128-bit value. Both operands
    /// are widened first, so the sum always fits.
    #[inline(always)]
    fn add(a: u64, b: u64) -> Self {
        let (a, b) = (Self::from(a), Self::from(b));
        a + b
    }

    /// Widening multiplication: returns `a * b` as a 128-bit value.
    #[inline(always)]
    fn mul(a: u64, b: u64) -> Self {
        let (a, b) = (Self::from(a), Self::from(b));
        a * b
    }

    /// Returns `a * b + c` as a 128-bit value.
    ///
    /// This can not overflow: the largest possible result is
    /// `(2^64 - 1)^2 + (2^64 - 1) = 2^128 - 2^64`.
    #[inline(always)]
    fn muladd(a: u64, b: u64, c: u64) -> Self {
        let product = Self::from(a) * Self::from(b);
        product + Self::from(c)
    }

    /// Returns `a * b + c + d` as a 128-bit value.
    ///
    /// This can not overflow: the largest possible result is
    /// `(2^64 - 1)^2 + 2 * (2^64 - 1) = 2^128 - 1`.
    #[inline(always)]
    fn muladd2(a: u64, b: u64, c: u64, d: u64) -> Self {
        let product = Self::from(a) * Self::from(b);
        let addend = Self::from(c) + Self::from(d);
        product + addend
    }

    /// Returns bits 64..128 of `self`.
    #[inline(always)]
    fn high(self) -> u64 {
        (self >> 64) as u64
    }

    /// Returns bits 0..64 of `self`; the upper half is discarded on purpose.
    #[inline(always)]
    #[allow(clippy::cast_possible_truncation)]
    fn low(self) -> u64 {
        self as u64
    }

    /// Returns the two halves as `(low, high)`. Note that this is the
    /// reverse of the argument order taken by `join`.
    #[inline(always)]
    fn split(self) -> (u64, u64) {
        let lo = self.low();
        let hi = self.high();
        (lo, hi)
    }
}

/// Compare two `u64` slices in reverse order.
///
/// Only the common prefix (the first `min(left.len(), right.len())` elements
/// of each slice) is compared element-wise, starting at the highest index and
/// walking down to index zero. The first unequal pair decides the result. If
/// the whole common prefix is equal, the slice lengths are compared instead,
/// so the shorter slice orders as `Less`.
#[inline(always)]
#[must_use]
pub fn cmp(left: &[u64], right: &[u64]) -> Ordering {
    let shared = left.len().min(right.len());

    // Both sides are cut to the same length up front, so the zipped iterator
    // pairs up equal indices even when walked from the back, and no per-element
    // bounds checks are needed.
    let pairs = left[..shared].iter().zip(&right[..shared]);

    for (&a, &b) in pairs.rev() {
        // Sign of `a - b` computed from the two comparison results; this has
        // the same outcome as matching on `a.cmp(&b)`.
        let sign = i8::from(a > b) - i8::from(a < b);
        match sign {
            1 => return Ordering::Greater,
            -1 => return Ordering::Less,
            0 => {}
            // SAFETY: `a > b` and `a < b` can not both be true, so `sign` is
            // the difference of two values in {0, 1} that are never both 1,
            // i.e. one of -1, 0 or 1. All three are handled above.
            _ => unsafe { core::hint::unreachable_unchecked() },
        }
    }

    left.len().cmp(&right.len())
}

/// Adds `lhs + rhs + carry` and returns the wrapped sum together with the
/// carry-out flag.
///
/// The carry-out is `true` when either of the two partial additions wrapped
/// around.
// Stand-in until [Rust#85532](https://github.com/rust-lang/rust/issues/85532) stabilizes.
#[inline]
#[must_use]
pub const fn carrying_add(lhs: u64, rhs: u64, carry: bool) -> (u64, bool) {
    let (partial, wrapped_first) = lhs.overflowing_add(rhs);
    let (sum, wrapped_second) = partial.overflowing_add(carry as u64);
    (sum, wrapped_first || wrapped_second)
}

/// Subtracts `lhs - rhs - borrow` and returns the wrapped difference together
/// with the borrow-out flag.
///
/// The borrow-out is `true` when either of the two partial subtractions
/// wrapped around.
// Stand-in until [Rust#85532](https://github.com/rust-lang/rust/issues/85532) stabilizes.
#[inline]
#[must_use]
pub const fn borrowing_sub(lhs: u64, rhs: u64, borrow: bool) -> (u64, bool) {
    let (partial, wrapped_first) = lhs.overflowing_sub(rhs);
    let (difference, wrapped_second) = partial.overflowing_sub(borrow as u64);
    (difference, wrapped_first || wrapped_second)
}