ruint/
bits.rs

1use core::ops::{
2    BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr,
3    ShrAssign,
4};
5
6use crate::Uint;
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9    /// Returns whether a specific bit is set.
10    ///
11    /// Returns `false` if `index` exceeds the bit width of the number.
12    #[must_use]
13    #[inline]
14    pub const fn bit(&self, index: usize) -> bool {
15        if index >= BITS {
16            return false;
17        }
18        let (limbs, bits) = (index / 64, index % 64);
19        self.limbs[limbs] & (1 << bits) != 0
20    }
21
22    /// Sets a specific bit to a value.
23    #[inline]
24    pub fn set_bit(&mut self, index: usize, value: bool) {
25        if index >= BITS {
26            return;
27        }
28        let (limbs, bits) = (index / 64, index % 64);
29        if value {
30            self.limbs[limbs] |= 1 << bits;
31        } else {
32            self.limbs[limbs] &= !(1 << bits);
33        }
34    }
35
36    /// Returns a specific byte. The byte at index `0` is the least significant
37    /// byte (little endian).
38    ///
39    /// # Panics
40    ///
41    /// Panics if `index` is greater than or equal to the byte width of the
42    /// number.
43    ///
44    /// # Examples
45    ///
46    /// ```
47    /// # use ruint::uint;
48    /// let x = uint!(0x1234567890_U64);
49    /// let bytes = [
50    ///     x.byte(0), // 0x90
51    ///     x.byte(1), // 0x78
52    ///     x.byte(2), // 0x56
53    ///     x.byte(3), // 0x34
54    ///     x.byte(4), // 0x12
55    ///     x.byte(5), // 0x00
56    ///     x.byte(6), // 0x00
57    ///     x.byte(7), // 0x00
58    /// ];
59    /// assert_eq!(bytes, x.to_le_bytes());
60    /// ```
61    ///
62    /// Panics if out of range.
63    ///
64    /// ```should_panic
65    /// # use ruint::uint;
66    /// let x = uint!(0x1234567890_U64);
67    /// let _ = x.byte(8);
68    /// ```
69    #[inline]
70    #[must_use]
71    #[track_caller]
72    pub const fn byte(&self, index: usize) -> u8 {
73        #[cfg(target_endian = "little")]
74        {
75            self.as_le_slice()[index]
76        }
77
78        #[cfg(target_endian = "big")]
79        #[allow(clippy::cast_possible_truncation)] // intentional
80        {
81            (self.limbs[index / 8] >> ((index % 8) * 8)) as u8
82        }
83    }
84
85    /// Returns a specific byte, or `None` if `index` is out of range. The byte
86    /// at index `0` is the least significant byte (little endian).
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// # use ruint::uint;
92    /// let x = uint!(0x1234567890_U64);
93    /// assert_eq!(x.checked_byte(0), Some(0x90));
94    /// assert_eq!(x.checked_byte(7), Some(0x00));
95    /// // Out of range
96    /// assert_eq!(x.checked_byte(8), None);
97    /// ```
98    #[inline]
99    #[must_use]
100    pub const fn checked_byte(&self, index: usize) -> Option<u8> {
101        if index < Self::BYTES {
102            Some(self.byte(index))
103        } else {
104            None
105        }
106    }
107
108    /// Reverses the order of bits in the integer. The least significant bit
109    /// becomes the most significant bit, second least-significant bit becomes
110    /// second most-significant bit, etc.
111    #[inline]
112    #[must_use]
113    pub fn reverse_bits(mut self) -> Self {
114        self.limbs.reverse();
115        for limb in &mut self.limbs {
116            *limb = limb.reverse_bits();
117        }
118        if BITS % 64 != 0 {
119            self >>= 64 - BITS % 64;
120        }
121        self
122    }
123
124    /// Inverts all the bits in the integer.
125    #[inline]
126    #[must_use]
127    pub const fn not(mut self) -> Self {
128        if BITS == 0 {
129            return Self::ZERO;
130        }
131
132        let mut i = 0;
133        while i < LIMBS {
134            self.limbs[i] = !self.limbs[i];
135            i += 1;
136        }
137
138        self.masked()
139    }
140
141    /// Returns the number of leading zeros in the binary representation of
142    /// `self`.
143    #[inline]
144    #[must_use]
145    pub const fn leading_zeros(&self) -> usize {
146        let mut i = LIMBS;
147        while i > 0 {
148            i -= 1;
149            if self.limbs[i] != 0 {
150                let n = LIMBS - 1 - i;
151                let skipped = n * 64;
152                let fixed = Self::MASK.leading_zeros() as usize;
153                let top = self.limbs[i].leading_zeros() as usize;
154                return skipped + top - fixed;
155            }
156        }
157
158        BITS
159    }
160
161    /// Returns the number of leading ones in the binary representation of
162    /// `self`.
163    #[inline]
164    #[must_use]
165    pub const fn leading_ones(&self) -> usize {
166        Self::not(*self).leading_zeros()
167    }
168
169    /// Returns the number of trailing zeros in the binary representation of
170    /// `self`.
171    #[inline]
172    #[must_use]
173    pub fn trailing_zeros(&self) -> usize {
174        self.as_limbs()
175            .iter()
176            .position(|&limb| limb != 0)
177            .map_or(BITS, |n| {
178                n * 64 + self.as_limbs()[n].trailing_zeros() as usize
179            })
180    }
181
182    /// Returns the number of trailing ones in the binary representation of
183    /// `self`.
184    #[inline]
185    #[must_use]
186    pub fn trailing_ones(&self) -> usize {
187        self.as_limbs()
188            .iter()
189            .position(|&limb| limb != u64::MAX)
190            .map_or(BITS, |n| {
191                n * 64 + self.as_limbs()[n].trailing_ones() as usize
192            })
193    }
194
195    /// Returns the number of ones in the binary representation of `self`.
196    #[inline]
197    #[must_use]
198    pub const fn count_ones(&self) -> usize {
199        let mut total = 0;
200
201        let mut i = 0;
202        while i < LIMBS {
203            total += self.limbs[i].count_ones() as usize;
204            i += 1;
205        }
206
207        total
208    }
209
210    /// Returns the number of zeros in the binary representation of `self`.
211    #[must_use]
212    #[inline]
213    pub const fn count_zeros(&self) -> usize {
214        BITS - self.count_ones()
215    }
216
217    /// Returns the dynamic length of this number in bits, ignoring leading
218    /// zeros.
219    ///
220    /// For the maximum length of the type, use [`Uint::BITS`](Self::BITS).
221    #[must_use]
222    #[inline]
223    pub const fn bit_len(&self) -> usize {
224        BITS - self.leading_zeros()
225    }
226
227    /// Returns the dynamic length of this number in bytes, ignoring leading
228    /// zeros.
229    ///
230    /// For the maximum length of the type, use [`Uint::BYTES`](Self::BYTES).
231    #[must_use]
232    #[inline]
233    pub const fn byte_len(&self) -> usize {
234        (self.bit_len() + 7) / 8
235    }
236
237    /// Returns the most significant 64 bits of the number and the exponent.
238    ///
239    /// Given return value $(\mathtt{bits}, \mathtt{exponent})$, the `self` can
240    /// be approximated as
241    ///
242    /// $$
243    /// \mathtt{self} ≈ \mathtt{bits} ⋅ 2^\mathtt{exponent}
244    /// $$
245    ///
246    /// If `self` is $<≥> 2^{63}$, then `exponent` will be zero and `bits` will
247    /// have leading zeros.
248    #[inline]
249    #[must_use]
250    pub fn most_significant_bits(&self) -> (u64, usize) {
251        let first_set_limb = self
252            .as_limbs()
253            .iter()
254            .rposition(|&limb| limb != 0)
255            .unwrap_or(0);
256        if first_set_limb == 0 {
257            (self.as_limbs().first().copied().unwrap_or(0), 0)
258        } else {
259            let hi = self.as_limbs()[first_set_limb];
260            let lo = self.as_limbs()[first_set_limb - 1];
261            let leading_zeros = hi.leading_zeros();
262            let bits = if leading_zeros > 0 {
263                (hi << leading_zeros) | (lo >> (64 - leading_zeros))
264            } else {
265                hi
266            };
267            let exponent = first_set_limb * 64 - leading_zeros as usize;
268            (bits, exponent)
269        }
270    }
271
272    /// Checked left shift by `rhs` bits.
273    ///
274    /// Returns $\mathtt{self} ⋅ 2^{\mathtt{rhs}}$ or [`None`] if the result
275    /// would $≥ 2^{\mathtt{BITS}}$. That is, it returns [`None`] if the bits
276    /// shifted out would be non-zero.
277    ///
278    /// Note: This differs from [`u64::checked_shl`] which returns `None` if the
279    /// shift is larger than BITS (which is IMHO not very useful).
280    #[inline(always)]
281    #[must_use]
282    pub fn checked_shl(self, rhs: usize) -> Option<Self> {
283        match self.overflowing_shl(rhs) {
284            (value, false) => Some(value),
285            _ => None,
286        }
287    }
288
289    /// Saturating left shift by `rhs` bits.
290    ///
291    /// Returns $\mathtt{self} ⋅ 2^{\mathtt{rhs}}$ or [`Uint::MAX`] if the
292    /// result would $≥ 2^{\mathtt{BITS}}$. That is, it returns
293    /// [`Uint::MAX`] if the bits shifted out would be non-zero.
294    #[inline(always)]
295    #[must_use]
296    pub fn saturating_shl(self, rhs: usize) -> Self {
297        match self.overflowing_shl(rhs) {
298            (value, false) => value,
299            _ => Self::MAX,
300        }
301    }
302
303    /// Left shift by `rhs` bits with overflow detection.
304    ///
305    /// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
306    /// If the product is $≥ 2^{\mathtt{BITS}}$ it returns `true`. That is, it
307    /// returns true if the bits shifted out are non-zero.
308    ///
309    /// Note: This differs from [`u64::overflowing_shl`] which returns `true` if
310    /// the shift is larger than `BITS` (which is IMHO not very useful).
311    #[inline]
312    #[must_use]
313    pub fn overflowing_shl(self, rhs: usize) -> (Self, bool) {
314        let (limbs, bits) = (rhs / 64, rhs % 64);
315        if limbs >= LIMBS {
316            return (Self::ZERO, self != Self::ZERO);
317        }
318
319        let word_bits = 64;
320        let mut r = Self::ZERO;
321        let mut carry = 0;
322        for i in 0..Self::LIMBS - limbs {
323            let x = self.limbs[i];
324            r.limbs[i + limbs] = (x << bits) | carry;
325            carry = (x >> (word_bits - bits - 1)) >> 1;
326        }
327        r.apply_mask();
328        (r, carry != 0)
329    }
330
331    /// Left shift by `rhs` bits.
332    ///
333    /// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
334    ///
335    /// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
336    /// by `BITS` (which is IMHO not very useful).
337    #[cfg(not(target_os = "zkvm"))]
338    #[inline(always)]
339    #[must_use]
340    pub fn wrapping_shl(self, rhs: usize) -> Self {
341        self.overflowing_shl(rhs).0
342    }
343
344    /// Left shift by `rhs` bits.
345    ///
346    /// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
347    ///
348    /// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
349    /// by `BITS` (which is IMHO not very useful).
350    #[cfg(target_os = "zkvm")]
351    #[inline(always)]
352    #[must_use]
353    pub fn wrapping_shl(mut self, rhs: usize) -> Self {
354        if BITS == 256 {
355            if rhs >= 256 {
356                return Self::ZERO;
357            }
358            use crate::support::zkvm::zkvm_u256_wrapping_shl_impl;
359            let rhs = rhs as u64;
360            unsafe {
361                zkvm_u256_wrapping_shl_impl(
362                    self.limbs.as_mut_ptr() as *mut u8,
363                    self.limbs.as_ptr() as *const u8,
364                    [rhs].as_ptr() as *const u8,
365                );
366            }
367            self
368        } else {
369            self.overflowing_shl(rhs).0
370        }
371    }
372
373    /// Checked right shift by `rhs` bits.
374    ///
375    /// $$
376    /// \frac{\mathtt{self}}{2^{\mathtt{rhs}}}
377    /// $$
378    ///
379    /// Returns the above or [`None`] if the division is not exact. This is the
380    /// same as
381    ///
382    /// Note: This differs from [`u64::checked_shr`] which returns `None` if the
383    /// shift is larger than BITS (which is IMHO not very useful).
384    #[inline(always)]
385    #[must_use]
386    pub fn checked_shr(self, rhs: usize) -> Option<Self> {
387        match self.overflowing_shr(rhs) {
388            (value, false) => Some(value),
389            _ => None,
390        }
391    }
392
393    /// Right shift by `rhs` bits with underflow detection.
394    ///
395    /// $$
396    /// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
397    /// $$
398    ///
399    /// Returns the above and `false` if the division was exact, and `true` if
400    /// it was rounded down. This is the same as non-zero bits being shifted
401    /// out.
402    ///
403    /// Note: This differs from [`u64::overflowing_shr`] which returns `true` if
404    /// the shift is larger than `BITS` (which is IMHO not very useful).
405    #[inline]
406    #[must_use]
407    pub fn overflowing_shr(self, rhs: usize) -> (Self, bool) {
408        let (limbs, bits) = (rhs / 64, rhs % 64);
409        if limbs >= LIMBS {
410            return (Self::ZERO, self != Self::ZERO);
411        }
412
413        let word_bits = 64;
414        let mut r = Self::ZERO;
415        let mut carry = 0;
416        for i in 0..LIMBS - limbs {
417            let x = self.limbs[LIMBS - 1 - i];
418            r.limbs[LIMBS - 1 - i - limbs] = (x >> bits) | carry;
419            carry = (x << (word_bits - bits - 1)) << 1;
420        }
421        (r, carry != 0)
422    }
423
424    /// Right shift by `rhs` bits.
425    ///
426    /// $$
427    /// \mathtt{wrapping\\_shr}(\mathtt{self}, \mathtt{rhs}) =
428    /// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
429    /// $$
430    ///
431    /// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
432    /// by `BITS` (which is IMHO not very useful).
433    #[cfg(not(target_os = "zkvm"))]
434    #[inline(always)]
435    #[must_use]
436    pub fn wrapping_shr(self, rhs: usize) -> Self {
437        self.overflowing_shr(rhs).0
438    }
439
440    /// Right shift by `rhs` bits.
441    ///
442    /// $$
443    /// \mathtt{wrapping\\_shr}(\mathtt{self}, \mathtt{rhs}) =
444    /// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
445    /// $$
446    ///
447    /// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
448    /// by `BITS` (which is IMHO not very useful).
449    #[cfg(target_os = "zkvm")]
450    #[inline(always)]
451    #[must_use]
452    pub fn wrapping_shr(mut self, rhs: usize) -> Self {
453        if BITS == 256 {
454            if rhs >= 256 {
455                return Self::ZERO;
456            }
457            use crate::support::zkvm::zkvm_u256_wrapping_shr_impl;
458            let rhs = rhs as u64;
459            unsafe {
460                zkvm_u256_wrapping_shr_impl(
461                    self.limbs.as_mut_ptr() as *mut u8,
462                    self.limbs.as_ptr() as *const u8,
463                    [rhs].as_ptr() as *const u8,
464                );
465            }
466            self
467        } else {
468            self.overflowing_shr(rhs).0
469        }
470    }
471
472    /// Arithmetic shift right by `rhs` bits.
473    #[cfg(not(target_os = "zkvm"))]
474    #[inline]
475    #[must_use]
476    pub fn arithmetic_shr(self, rhs: usize) -> Self {
477        if BITS == 0 {
478            return Self::ZERO;
479        }
480        let sign = self.bit(BITS - 1);
481        let mut r = self >> rhs;
482        if sign {
483            r |= Self::MAX << BITS.saturating_sub(rhs);
484        }
485        r
486    }
487
488    /// Arithmetic shift right by `rhs` bits.
489    #[cfg(target_os = "zkvm")]
490    #[inline]
491    #[must_use]
492    pub fn arithmetic_shr(mut self, rhs: usize) -> Self {
493        if BITS == 256 {
494            let rhs = if rhs >= 256 { 255 } else { rhs };
495            use crate::support::zkvm::zkvm_u256_arithmetic_shr_impl;
496            let rhs = rhs as u64;
497            unsafe {
498                zkvm_u256_arithmetic_shr_impl(
499                    self.limbs.as_mut_ptr() as *mut u8,
500                    self.limbs.as_ptr() as *const u8,
501                    [rhs].as_ptr() as *const u8,
502                );
503            }
504            self
505        } else {
506            if BITS == 0 {
507                return Self::ZERO;
508            }
509            let sign = self.bit(BITS - 1);
510            let mut r = self >> rhs;
511            if sign {
512                r |= Self::MAX << BITS.saturating_sub(rhs);
513            }
514            r
515        }
516    }
517
518    /// Shifts the bits to the left by a specified amount, `rhs`, wrapping the
519    /// truncated bits to the end of the resulting integer.
520    #[inline]
521    #[must_use]
522    #[allow(clippy::missing_const_for_fn)] // False positive
523    pub fn rotate_left(self, rhs: usize) -> Self {
524        if BITS == 0 {
525            return Self::ZERO;
526        }
527        let rhs = rhs % BITS;
528        (self << rhs) | (self >> (BITS - rhs))
529    }
530
531    /// Shifts the bits to the right by a specified amount, `rhs`, wrapping the
532    /// truncated bits to the beginning of the resulting integer.
533    #[inline(always)]
534    #[must_use]
535    pub fn rotate_right(self, rhs: usize) -> Self {
536        if BITS == 0 {
537            return Self::ZERO;
538        }
539        let rhs = rhs % BITS;
540        self.rotate_left(BITS - rhs)
541    }
542}
543
544impl<const BITS: usize, const LIMBS: usize> Not for Uint<BITS, LIMBS> {
545    type Output = Self;
546
547    #[cfg(not(target_os = "zkvm"))]
548    #[inline]
549    fn not(self) -> Self::Output {
550        Self::not(self)
551    }
552
553    #[cfg(target_os = "zkvm")]
554    #[inline(always)]
555    fn not(mut self) -> Self::Output {
556        use crate::support::zkvm::zkvm_u256_wrapping_sub_impl;
557        if BITS == 256 {
558            unsafe {
559                zkvm_u256_wrapping_sub_impl(
560                    self.limbs.as_mut_ptr() as *mut u8,
561                    Self::MAX.limbs.as_ptr() as *const u8,
562                    self.limbs.as_ptr() as *const u8,
563                );
564            }
565            self
566        } else {
567            if BITS == 0 {
568                return Self::ZERO;
569            }
570            for limb in &mut self.limbs {
571                *limb = u64::not(*limb);
572            }
573            self.limbs[LIMBS - 1] &= Self::MASK;
574            self
575        }
576    }
577}
578
579impl<const BITS: usize, const LIMBS: usize> Not for &Uint<BITS, LIMBS> {
580    type Output = Uint<BITS, LIMBS>;
581
582    #[inline]
583    fn not(self) -> Self::Output {
584        (*self).not()
585    }
586}
587
/// Implements a limb-wise bitwise operator (`$trait`/`$fn`) and its assigning
/// form (`$trait_assign`/`$fn_assign`) for every owned/borrowed combination
/// of `Uint` operands.
///
/// `$fn_zkvm_impl` names the 256-bit zkVM precompile used on `zkvm` targets.
/// The precompile is only used when `BITS == 256`; other widths take the
/// portable limb loop.
macro_rules! impl_bit_op {
    ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident, $fn_zkvm_impl:ident) => {
        impl<const BITS: usize, const LIMBS: usize> $trait_assign<Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            #[inline(always)]
            fn $fn_assign(&mut self, rhs: Uint<BITS, LIMBS>) {
                self.$fn_assign(&rhs);
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait_assign<&Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            /// Applies the operator limb by limb.
            ///
            /// The `zkvm_u256_*` precompiles presumably work on fixed 32-byte
            /// operands (every other precompile call in this file is guarded
            /// by `BITS == 256`). Calling one on a narrower or wider limb
            /// buffer would access memory out of bounds, so it is used only
            /// for 256-bit values.
            #[inline]
            fn $fn_assign(&mut self, rhs: &Uint<BITS, LIMBS>) {
                #[cfg(target_os = "zkvm")]
                {
                    if BITS == 256 {
                        use crate::support::zkvm::$fn_zkvm_impl;
                        // SAFETY: BITS == 256, so both limb buffers are the
                        // 32-byte operands the u256 precompile expects
                        // (TODO confirm in `crate::support::zkvm`).
                        unsafe {
                            $fn_zkvm_impl(
                                self.limbs.as_mut_ptr() as *mut u8,
                                self.limbs.as_ptr() as *const u8,
                                rhs.limbs.as_ptr() as *const u8,
                            );
                        }
                        return;
                    }
                }

                for (dst, &src) in self.limbs.iter_mut().zip(rhs.limbs.iter()) {
                    u64::$fn_assign(dst, src);
                }
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            #[inline(always)]
            fn $fn(mut self, rhs: Uint<BITS, LIMBS>) -> Self::Output {
                self.$fn_assign(rhs);
                self
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<&Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            #[inline(always)]
            fn $fn(mut self, rhs: &Uint<BITS, LIMBS>) -> Self::Output {
                self.$fn_assign(rhs);
                self
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<Uint<BITS, LIMBS>>
            for &Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            // The operators generated here are commutative, so the owned
            // right-hand side can be reused as the accumulator.
            #[inline(always)]
            fn $fn(self, mut rhs: Uint<BITS, LIMBS>) -> Self::Output {
                rhs.$fn_assign(self);
                rhs
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<&Uint<BITS, LIMBS>>
            for &Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            // `Uint` is `Copy`; dereference instead of calling `clone`.
            #[inline(always)]
            fn $fn(self, rhs: &Uint<BITS, LIMBS>) -> Self::Output {
                (*self).$fn(rhs)
            }
        }
    };
}
672
673impl_bit_op!(
674    BitOr,
675    bitor,
676    BitOrAssign,
677    bitor_assign,
678    zkvm_u256_bitor_impl
679);
680impl_bit_op!(
681    BitAnd,
682    bitand,
683    BitAndAssign,
684    bitand_assign,
685    zkvm_u256_bitand_impl
686);
687impl_bit_op!(
688    BitXor,
689    bitxor,
690    BitXorAssign,
691    bitxor_assign,
692    zkvm_u256_bitxor_impl
693);
694
695impl<const BITS: usize, const LIMBS: usize> Shl<Self> for Uint<BITS, LIMBS> {
696    type Output = Self;
697
698    #[inline(always)]
699    fn shl(self, rhs: Self) -> Self::Output {
700        // This check shortcuts, and prevents panics on the `[0]` later
701        if BITS == 0 {
702            return self;
703        }
704        // Rationale: if BITS is larger than 2**64 - 1, it means we're running
705        // on a 128-bit platform with 2.3 exabytes of memory. In this case,
706        // the code produces incorrect output.
707        #[allow(clippy::cast_possible_truncation)]
708        self.wrapping_shl(rhs.as_limbs()[0] as usize)
709    }
710}
711
712impl<const BITS: usize, const LIMBS: usize> Shl<&Self> for Uint<BITS, LIMBS> {
713    type Output = Self;
714
715    #[inline(always)]
716    fn shl(self, rhs: &Self) -> Self::Output {
717        self << *rhs
718    }
719}
720
721impl<const BITS: usize, const LIMBS: usize> Shr<Self> for Uint<BITS, LIMBS> {
722    type Output = Self;
723
724    #[inline(always)]
725    fn shr(self, rhs: Self) -> Self::Output {
726        // This check shortcuts, and prevents panics on the `[0]` later
727        if BITS == 0 {
728            return self;
729        }
730        // Rationale: if BITS is larger than 2**64 - 1, it means we're running
731        // on a 128-bit platform with 2.3 exabytes of memory. In this case,
732        // the code produces incorrect output.
733        #[allow(clippy::cast_possible_truncation)]
734        self.wrapping_shr(rhs.as_limbs()[0] as usize)
735    }
736}
737
738impl<const BITS: usize, const LIMBS: usize> Shr<&Self> for Uint<BITS, LIMBS> {
739    type Output = Self;
740
741    #[inline(always)]
742    fn shr(self, rhs: &Self) -> Self::Output {
743        self >> *rhs
744    }
745}
746
747impl<const BITS: usize, const LIMBS: usize> ShlAssign<Self> for Uint<BITS, LIMBS> {
748    #[inline(always)]
749    fn shl_assign(&mut self, rhs: Self) {
750        *self = *self << rhs;
751    }
752}
753
754impl<const BITS: usize, const LIMBS: usize> ShlAssign<&Self> for Uint<BITS, LIMBS> {
755    #[inline(always)]
756    fn shl_assign(&mut self, rhs: &Self) {
757        *self = *self << rhs;
758    }
759}
760
761impl<const BITS: usize, const LIMBS: usize> ShrAssign<Self> for Uint<BITS, LIMBS> {
762    #[inline(always)]
763    fn shr_assign(&mut self, rhs: Self) {
764        *self = *self >> rhs;
765    }
766}
767
768impl<const BITS: usize, const LIMBS: usize> ShrAssign<&Self> for Uint<BITS, LIMBS> {
769    #[inline(always)]
770    fn shr_assign(&mut self, rhs: &Self) {
771        *self = *self >> rhs;
772    }
773}
774
775macro_rules! impl_shift {
776    (@main $u:ty) => {
777        impl<const BITS: usize, const LIMBS: usize> Shl<$u> for Uint<BITS, LIMBS> {
778            type Output = Self;
779
780            #[inline(always)]
781            #[allow(clippy::cast_possible_truncation)]
782            fn shl(self, rhs: $u) -> Self::Output {
783                self.wrapping_shl(rhs as usize)
784            }
785        }
786
787        impl<const BITS: usize, const LIMBS: usize> Shr<$u> for Uint<BITS, LIMBS> {
788            type Output = Self;
789
790            #[inline(always)]
791            #[allow(clippy::cast_possible_truncation)]
792            fn shr(self, rhs: $u) -> Self::Output {
793                self.wrapping_shr(rhs as usize)
794            }
795        }
796    };
797
798    (@ref $u:ty) => {
799        impl<const BITS: usize, const LIMBS: usize> Shl<&$u> for Uint<BITS, LIMBS> {
800            type Output = Self;
801
802            #[inline(always)]
803            fn shl(self, rhs: &$u) -> Self::Output {
804                <Self>::shl(self, *rhs)
805            }
806        }
807
808        impl<const BITS: usize, const LIMBS: usize> Shr<&$u> for Uint<BITS, LIMBS> {
809            type Output = Self;
810
811            #[inline(always)]
812            fn shr(self, rhs: &$u) -> Self::Output {
813                <Self>::shr(self, *rhs)
814            }
815        }
816    };
817
818    (@assign $u:ty) => {
819        impl<const BITS: usize, const LIMBS: usize> ShlAssign<$u> for Uint<BITS, LIMBS> {
820            #[inline(always)]
821            fn shl_assign(&mut self, rhs: $u) {
822                *self = *self << rhs;
823            }
824        }
825
826        impl<const BITS: usize, const LIMBS: usize> ShrAssign<$u> for Uint<BITS, LIMBS> {
827            #[inline(always)]
828            fn shr_assign(&mut self, rhs: $u) {
829                *self = *self >> rhs;
830            }
831        }
832    };
833
834    ($u:ty) => {
835        impl_shift!(@main $u);
836        impl_shift!(@ref $u);
837        impl_shift!(@assign $u);
838        impl_shift!(@assign &$u);
839    };
840
841    ($u:ty, $($tail:ty),*) => {
842        impl_shift!($u);
843        impl_shift!($($tail),*);
844    };
845}
846
847impl_shift!(usize, u8, u16, u32, isize, i8, i16, i32);
848
849// Only when losslessy castable to usize.
850#[cfg(target_pointer_width = "64")]
851impl_shift!(u64, i64);
852
853#[cfg(test)]
854mod tests {
855    use core::cmp::min;
856
857    use proptest::proptest;
858
859    use super::*;
860    use crate::{aliases::U128, const_for, nlimbs};
861
862    #[test]
863    fn test_leading_zeros() {
864        assert_eq!(Uint::<0, 0>::ZERO.leading_zeros(), 0);
865        assert_eq!(Uint::<1, 1>::ZERO.leading_zeros(), 1);
866        assert_eq!(Uint::<1, 1>::ONE.leading_zeros(), 0);
867        const_for!(BITS in NON_ZERO {
868            const LIMBS: usize = nlimbs(BITS);
869            type U = Uint::<BITS, LIMBS>;
870            assert_eq!(U::ZERO.leading_zeros(), BITS);
871            assert_eq!(U::MAX.leading_zeros(), 0);
872            assert_eq!(U::ONE.leading_zeros(), BITS - 1);
873            proptest!(|(value: U)| {
874                let zeros = value.leading_zeros();
875                assert!(zeros <= BITS);
876                assert!(zeros < BITS || value == U::ZERO);
877                if zeros < BITS {
878                    let (left, overflow) = value.overflowing_shl(zeros);
879                    assert!(!overflow);
880                    assert!(left.leading_zeros() == 0 || value == U::ZERO);
881                    assert!(left.bit(BITS - 1));
882                    assert_eq!(value >> (BITS - zeros), Uint::ZERO);
883                }
884            });
885        });
886        proptest!(|(value: u128)| {
887            let uint = U128::from(value);
888            assert_eq!(uint.leading_zeros(), value.leading_zeros() as usize);
889        });
890    }
891
892    #[test]
893    fn test_leading_ones() {
894        assert_eq!(Uint::<0, 0>::ZERO.leading_ones(), 0);
895        assert_eq!(Uint::<1, 1>::ZERO.leading_ones(), 0);
896        assert_eq!(Uint::<1, 1>::ONE.leading_ones(), 1);
897    }
898
899    #[test]
900    fn test_most_significant_bits() {
901        const_for!(BITS in NON_ZERO {
902            const LIMBS: usize = nlimbs(BITS);
903            type U = Uint::<BITS, LIMBS>;
904            proptest!(|(value: u64)| {
905                let value = if U::LIMBS <= 1 { value & U::MASK } else { value };
906                assert_eq!(U::from(value).most_significant_bits(), (value, 0));
907            });
908        });
909        proptest!(|(mut limbs: [u64; 2])| {
910            if limbs[1] == 0 {
911                limbs[1] = 1;
912            }
913            let (bits, exponent) = U128::from_limbs(limbs).most_significant_bits();
914            assert!(bits >= 1_u64 << 63);
915            assert_eq!(exponent, 64 - limbs[1].leading_zeros() as usize);
916        });
917    }
918
919    #[test]
920    fn test_checked_shl() {
921        assert_eq!(
922            Uint::<65, 2>::from_limbs([0x0010_0000_0000_0000, 0]).checked_shl(1),
923            Some(Uint::<65, 2>::from_limbs([0x0020_0000_0000_0000, 0]))
924        );
925        assert_eq!(
926            Uint::<127, 2>::from_limbs([0x0010_0000_0000_0000, 0]).checked_shl(64),
927            Some(Uint::<127, 2>::from_limbs([0, 0x0010_0000_0000_0000]))
928        );
929    }
930
931    #[test]
932    #[allow(
933        clippy::cast_lossless,
934        clippy::cast_possible_truncation,
935        clippy::cast_possible_wrap
936    )]
937    fn test_small() {
938        const_for!(BITS in [1, 2, 8, 16, 32, 63, 64] {
939            type U = Uint::<BITS, 1>;
940            proptest!(|(a: U, b: U)| {
941                assert_eq!(a | b, U::from_limbs([a.limbs[0] | b.limbs[0]]));
942                assert_eq!(a & b, U::from_limbs([a.limbs[0] & b.limbs[0]]));
943                assert_eq!(a ^ b, U::from_limbs([a.limbs[0] ^ b.limbs[0]]));
944            });
945            proptest!(|(a: U, s in 0..BITS)| {
946                assert_eq!(a << s, U::from_limbs([a.limbs[0] << s & U::MASK]));
947                assert_eq!(a >> s, U::from_limbs([a.limbs[0] >> s]));
948            });
949        });
950        proptest!(|(a: Uint::<32, 1>, s in 0_usize..=34)| {
951            assert_eq!(a.reverse_bits(), Uint::from((a.limbs[0] as u32).reverse_bits() as u64));
952            assert_eq!(a.rotate_left(s), Uint::from((a.limbs[0] as u32).rotate_left(s as u32) as u64));
953            assert_eq!(a.rotate_right(s), Uint::from((a.limbs[0] as u32).rotate_right(s as u32) as u64));
954            if s < 32 {
955                let arr_shifted = (((a.limbs[0] as i32) >> s) as u32) as u64;
956                assert_eq!(a.arithmetic_shr(s), Uint::from_limbs([arr_shifted]));
957            }
958        });
959        proptest!(|(a: Uint::<64, 1>, s in 0_usize..=66)| {
960            assert_eq!(a.reverse_bits(), Uint::from(a.limbs[0].reverse_bits()));
961            assert_eq!(a.rotate_left(s), Uint::from(a.limbs[0].rotate_left(s as u32)));
962            assert_eq!(a.rotate_right(s), Uint::from(a.limbs[0].rotate_right(s as u32)));
963            if s < 64 {
964                let arr_shifted = ((a.limbs[0] as i64) >> s) as u64;
965                assert_eq!(a.arithmetic_shr(s), Uint::from_limbs([arr_shifted]));
966            }
967        });
968    }
969
970    #[test]
971    fn test_shift_reverse() {
972        const_for!(BITS in SIZES {
973            const LIMBS: usize = nlimbs(BITS);
974            type U = Uint::<BITS, LIMBS>;
975            proptest!(|(value: U, shift in 0..=BITS + 2)| {
976                let left = (value << shift).reverse_bits();
977                let right = value.reverse_bits() >> shift;
978                assert_eq!(left, right);
979            });
980        });
981    }
982
983    #[test]
984    fn test_rotate() {
985        const_for!(BITS in SIZES {
986            const LIMBS: usize = nlimbs(BITS);
987            type U = Uint::<BITS, LIMBS>;
988            proptest!(|(value: U, shift in  0..=BITS + 2)| {
989                let rotated = value.rotate_left(shift).rotate_right(shift);
990                assert_eq!(value, rotated);
991            });
992        });
993    }
994
995    #[test]
996    fn test_arithmetic_shr() {
997        const_for!(BITS in SIZES {
998            const LIMBS: usize = nlimbs(BITS);
999            type U = Uint::<BITS, LIMBS>;
1000            proptest!(|(value: U, shift in  0..=BITS + 2)| {
1001                let shifted = value.arithmetic_shr(shift);
1002                assert_eq!(shifted.leading_ones(), match value.leading_ones() {
1003                    0 => 0,
1004                    n => min(BITS, n + shift)
1005                });
1006            });
1007        });
1008    }
1009
1010    #[test]
1011    fn test_overflowing_shr() {
1012        // Test: Single limb right shift from 40u64 by 1 bit.
1013        // Expects resulting integer: 20 with no fractional part.
1014        assert_eq!(
1015            Uint::<64, 1>::from_limbs([40u64]).overflowing_shr(1),
1016            (Uint::<64, 1>::from(20), false)
1017        );
1018
1019        // Test: Single limb right shift from 41u64 by 1 bit.
1020        // Expects resulting integer: 20 with a detected fractional part.
1021        assert_eq!(
1022            Uint::<64, 1>::from_limbs([41u64]).overflowing_shr(1),
1023            (Uint::<64, 1>::from(20), true)
1024        );
1025
1026        // Test: Two limbs right shift from 0x0010_0000_0000_0000 and 0 by 1 bit.
1027        // Expects resulting limbs: [0x0080_0000_0000_000, 0] with no fractional part.
1028        assert_eq!(
1029            Uint::<65, 2>::from_limbs([0x0010_0000_0000_0000, 0]).overflowing_shr(1),
1030            (Uint::<65, 2>::from_limbs([0x0008_0000_0000_0000, 0]), false)
1031        );
1032
1033        // Test: Shift beyond single limb capacity with MAX value.
1034        // Expects the highest possible value in 256-bit representation with a detected
1035        // fractional part.
1036        assert_eq!(
1037            Uint::<256, 4>::MAX.overflowing_shr(65),
1038            (
1039                Uint::<256, 4>::from_str_radix(
1040                    "7fffffffffffffffffffffffffffffffffffffffffffffff",
1041                    16
1042                )
1043                .unwrap(),
1044                true
1045            )
1046        );
1047        // Test: Large 4096-bit integer right shift by 34 bits.
1048        // Expects a specific value with no fractional part.
1049        assert_eq!(
1050            Uint::<4096, 64>::from_str_radix("3ffffffffffffffffffffffffffffc00000000", 16,)
1051                .unwrap()
1052                .overflowing_shr(34),
1053            (
1054                Uint::<4096, 64>::from_str_radix("fffffffffffffffffffffffffffff", 16).unwrap(),
1055                false
1056            )
1057        );
1058        // Test: Extremely large 4096-bit integer right shift by 100 bits.
1059        // Expects a specific value with no fractional part.
1060        assert_eq!(
1061            Uint::<4096, 64>::from_str_radix(
1062                "fffffffffffffffffffffffffffff0000000000000000000000000",
1063                16,
1064            )
1065            .unwrap()
1066            .overflowing_shr(100),
1067            (
1068                Uint::<4096, 64>::from_str_radix("fffffffffffffffffffffffffffff", 16).unwrap(),
1069                false
1070            )
1071        );
1072        // Test: Complex 4096-bit integer right shift by 1 bit.
1073        // Expects a specific value with no fractional part.
1074        assert_eq!(
1075            Uint::<4096, 64>::from_str_radix(
1076                "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0bdbfe",
1077                16,
1078            )
1079            .unwrap()
1080            .overflowing_shr(1),
1081            (
1082                Uint::<4096, 64>::from_str_radix(
1083                    "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffff85edff",
1084                    16
1085                )
1086                .unwrap(),
1087                false
1088            )
1089        );
1090        // Test: Large 4096-bit integer right shift by 1000 bits.
1091        // Expects a specific value with no fractional part.
1092        assert_eq!(
1093            Uint::<4096, 64>::from_str_radix(
1094                "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
1095                16,
1096            )
1097            .unwrap()
1098            .overflowing_shr(1000),
1099            (
1100                Uint::<4096, 64>::from_str_radix(
1101                    "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
1102                    16
1103                )
1104                .unwrap(),
1105                false
1106            )
1107        );
1108        // Test: MAX 4096-bit integer right shift by 34 bits.
1109        // Expects a specific value with a detected fractional part.
1110        assert_eq!(
1111            Uint::<4096, 64>::MAX
1112            .overflowing_shr(34),
1113            (
1114                Uint::<4096, 64>::from_str_radix(
1115                    "3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
1116                    16
1117                )
1118                .unwrap(),
1119                true
1120            )
1121        );
1122    }
1123}