ruint/
pow.rs

1use crate::Uint;
2
3impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
4    /// Raises self to the power of `exp`.
5    ///
6    /// Returns None if the result would overflow.
7    #[inline]
8    #[must_use]
9    pub fn checked_pow(self, exp: Self) -> Option<Self> {
10        match self.overflowing_pow(exp) {
11            (x, false) => Some(x),
12            (_, true) => None,
13        }
14    }
15
16    /// Raises self to the power of `exp` and if the result would overflow.
17    ///
18    /// # Examples
19    ///
20    /// ```
21    /// # use ruint::{Uint, uint};
22    /// # uint!{
23    /// assert_eq!(
24    ///     36_U64.overflowing_pow(12_U64),
25    ///     (0x41c21cb8e1000000_U64, false)
26    /// );
27    /// assert_eq!(
28    ///     36_U64.overflowing_pow(13_U64),
29    ///     (0x3f4c09ffa4000000_U64, true)
30    /// );
31    /// assert_eq!(
32    ///     36_U68.overflowing_pow(13_U68),
33    ///     (0x093f4c09ffa4000000_U68, false)
34    /// );
35    /// assert_eq!(16_U65.overflowing_pow(32_U65), (0_U65, true));
36    /// # }
37    /// ```
38    /// Small cases:
39    /// ```
40    /// # use ruint::{Uint, uint};
41    /// # uint!{
42    /// assert_eq!(0_U0.overflowing_pow(0_U0), (0_U0, false));
43    /// assert_eq!(0_U1.overflowing_pow(0_U1), (1_U1, false));
44    /// assert_eq!(0_U1.overflowing_pow(1_U1), (0_U1, false));
45    /// assert_eq!(1_U1.overflowing_pow(0_U1), (1_U1, false));
46    /// assert_eq!(1_U1.overflowing_pow(1_U1), (1_U1, false));
47    /// # }
48    /// ```
49    #[inline]
50    #[must_use]
51    pub fn overflowing_pow(mut self, mut exp: Self) -> (Self, bool) {
52        if BITS == 0 {
53            return (self, false);
54        }
55
56        // Exponentiation by squaring
57        let mut overflow = false;
58        let mut base_overflow = false;
59        let mut result = Self::from(1);
60        while exp != Self::ZERO {
61            // Multiply by base
62            if exp.bit(0) {
63                let (r, o) = result.overflowing_mul(self);
64                result = r;
65                overflow |= o | base_overflow;
66            }
67
68            // Square base
69            let (s, o) = self.overflowing_mul(self);
70            self = s;
71            base_overflow |= o;
72            exp >>= 1;
73        }
74        (result, overflow)
75    }
76
77    /// Raises self to the power of `exp`, wrapping around on overflow.
78    #[inline]
79    #[must_use]
80    pub fn pow(self, exp: Self) -> Self {
81        self.wrapping_pow(exp)
82    }
83
84    /// Raises self to the power of `exp`, saturating on overflow.
85    #[inline]
86    #[must_use]
87    pub fn saturating_pow(self, exp: Self) -> Self {
88        match self.overflowing_pow(exp) {
89            (x, false) => x,
90            (_, true) => Self::MAX,
91        }
92    }
93
94    /// Raises self to the power of `exp`, wrapping around on overflow.
95    #[inline]
96    #[must_use]
97    pub fn wrapping_pow(mut self, mut exp: Self) -> Self {
98        if BITS == 0 {
99            return self;
100        }
101
102        // Exponentiation by squaring
103        let mut result = Self::from(1);
104        while exp != Self::ZERO {
105            // Multiply by base
106            if exp.bit(0) {
107                result = result.wrapping_mul(self);
108            }
109
110            // Square base
111            self = self.wrapping_mul(self);
112            exp >>= 1;
113        }
114        result
115    }
116
117    /// Construct from double precision binary logarithm.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// # use ruint::{Uint, uint, aliases::*};
123    /// # uint!{
124    /// assert_eq!(U64::approx_pow2(-2.0), Some(0_U64));
125    /// assert_eq!(U64::approx_pow2(-1.0), Some(1_U64));
126    /// assert_eq!(U64::approx_pow2(0.0), Some(1_U64));
127    /// assert_eq!(U64::approx_pow2(1.0), Some(2_U64));
128    /// assert_eq!(U64::approx_pow2(1.6), Some(3_U64));
129    /// assert_eq!(U64::approx_pow2(2.0), Some(4_U64));
130    /// assert_eq!(U64::approx_pow2(64.0), None);
131    /// assert_eq!(U64::approx_pow2(10.385), Some(1337_U64));
132    /// # }
133    /// ```
134    #[cfg(feature = "std")]
135    #[must_use]
136    #[allow(clippy::missing_inline_in_public_items)]
137    pub fn approx_pow2(exp: f64) -> Option<Self> {
138        const LN2_1P5: f64 = 0.584_962_500_721_156_2_f64;
139        const EXP2_63: f64 = 9_223_372_036_854_775_808_f64;
140
141        // FEATURE: Round negative to zero.
142        #[allow(clippy::cast_precision_loss)] // Self::BITS ~< 2^52 and so fits f64.
143        if exp < LN2_1P5 {
144            if exp < -1.0 {
145                return Some(Self::ZERO);
146            }
147            return Self::try_from(1).ok();
148        }
149        #[allow(clippy::cast_precision_loss)]
150        if exp > Self::BITS as f64 {
151            return None;
152        }
153
154        // Since exp < BITS, it has an integer and fractional part.
155        #[allow(clippy::cast_possible_truncation)] // exp <= BITS <= usize::MAX.
156        #[allow(clippy::cast_sign_loss)] // exp >= 0.
157        let shift = exp.trunc() as usize;
158        let fract = exp.fract();
159
160        // Compute the leading 64 bits
161        // Since `fract < 1.0` we have `fract.exp2() < 2`, so we can rescale by
162        // 2^63 and cast to u64.
163        #[allow(clippy::cast_possible_truncation)] // fract < 1.0
164        #[allow(clippy::cast_sign_loss)] // fract >= 0.
165        let bits = (fract.exp2() * EXP2_63) as u64;
166        // Note: If `fract` is zero this will result in `u64::MAX`.
167
168        if shift >= 63 {
169            // OPT: A dedicated function avoiding full-sized shift.
170            Some(Self::try_from(bits).ok()?.checked_shl(shift - 63)?)
171        } else {
172            let shift = 63 - shift;
173            // Divide `bits` by `2^shift`, rounding to nearest.
174            let bits = (bits >> shift) + ((bits >> (shift - 1)) & 1);
175            Self::try_from(bits).ok()
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::iter::repeat;
    use proptest::proptest;

    /// Powers of two must agree with left-shifting one, including exponents
    /// at and just past the bit width.
    #[test]
    fn test_pow2_shl() {
        const_for!(BITS in NON_ZERO if (BITS >= 2) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(exponent in 0..=BITS+1)| {
                let power = U::from(2).pow(U::from(exponent));
                let shifted = U::from(1) << exponent;
                assert_eq!(power, shifted);
            });
        });
    }

    /// `pow` must agree with multiplying the base together `exponent` times.
    #[test]
    fn test_pow_product() {
        const_for!(BITS in NON_ZERO if (BITS >= 64) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(base_value in 2_u64..100, exponent in 0_usize..100)| {
                let base = U::from(base_value);
                let expected: U = repeat(base).take(exponent).product();
                assert_eq!(base.pow(U::from(exponent)), expected);
            });
        });
    }
}
210}