ruint/
bits.rs

1use core::ops::{
2    BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr,
3    ShrAssign,
4};
5
6use crate::Uint;
7
8impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9    /// Returns whether a specific bit is set.
10    ///
11    /// Returns `false` if `index` exceeds the bit width of the number.
12    #[must_use]
13    #[inline]
14    pub const fn bit(&self, index: usize) -> bool {
15        if index >= BITS {
16            return false;
17        }
18        let (limbs, bits) = (index / 64, index % 64);
19        self.limbs[limbs] & (1 << bits) != 0
20    }
21
22    /// Sets a specific bit to a value.
23    #[inline]
24    pub fn set_bit(&mut self, index: usize, value: bool) {
25        if index >= BITS {
26            return;
27        }
28        let (limbs, bits) = (index / 64, index % 64);
29        if value {
30            self.limbs[limbs] |= 1 << bits;
31        } else {
32            self.limbs[limbs] &= !(1 << bits);
33        }
34    }
35
36    /// Returns a specific byte. The byte at index `0` is the least significant
37    /// byte (little endian).
38    ///
39    /// # Panics
40    ///
41    /// Panics if `index` is greater than or equal to the byte width of the
42    /// number.
43    ///
44    /// # Examples
45    ///
46    /// ```
47    /// # use ruint::uint;
48    /// let x = uint!(0x1234567890_U64);
49    /// let bytes = [
50    ///     x.byte(0), // 0x90
51    ///     x.byte(1), // 0x78
52    ///     x.byte(2), // 0x56
53    ///     x.byte(3), // 0x34
54    ///     x.byte(4), // 0x12
55    ///     x.byte(5), // 0x00
56    ///     x.byte(6), // 0x00
57    ///     x.byte(7), // 0x00
58    /// ];
59    /// assert_eq!(bytes, x.to_le_bytes());
60    /// ```
61    ///
62    /// Panics if out of range.
63    ///
64    /// ```should_panic
65    /// # use ruint::uint;
66    /// let x = uint!(0x1234567890_U64);
67    /// let _ = x.byte(8);
68    /// ```
69    #[inline]
70    #[must_use]
71    #[track_caller]
72    pub const fn byte(&self, index: usize) -> u8 {
73        #[cfg(target_endian = "little")]
74        {
75            self.as_le_slice()[index]
76        }
77
78        #[cfg(target_endian = "big")]
79        #[allow(clippy::cast_possible_truncation)] // intentional
80        {
81            (self.limbs[index / 8] >> ((index % 8) * 8)) as u8
82        }
83    }
84
85    /// Returns a specific byte, or `None` if `index` is out of range. The byte
86    /// at index `0` is the least significant byte (little endian).
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// # use ruint::uint;
92    /// let x = uint!(0x1234567890_U64);
93    /// assert_eq!(x.checked_byte(0), Some(0x90));
94    /// assert_eq!(x.checked_byte(7), Some(0x00));
95    /// // Out of range
96    /// assert_eq!(x.checked_byte(8), None);
97    /// ```
98    #[inline]
99    #[must_use]
100    pub const fn checked_byte(&self, index: usize) -> Option<u8> {
101        if index < Self::BYTES {
102            Some(self.byte(index))
103        } else {
104            None
105        }
106    }
107
108    /// Reverses the order of bits in the integer. The least significant bit
109    /// becomes the most significant bit, second least-significant bit becomes
110    /// second most-significant bit, etc.
111    #[inline]
112    #[must_use]
113    pub fn reverse_bits(mut self) -> Self {
114        self.limbs.reverse();
115        for limb in &mut self.limbs {
116            *limb = limb.reverse_bits();
117        }
118        if BITS % 64 != 0 {
119            self >>= 64 - BITS % 64;
120        }
121        self
122    }
123
124    /// Inverts all the bits in the integer.
125    #[inline]
126    #[must_use]
127    pub const fn not(mut self) -> Self {
128        if BITS == 0 {
129            return Self::ZERO;
130        }
131
132        let mut i = 0;
133        while i < LIMBS {
134            self.limbs[i] = !self.limbs[i];
135            i += 1;
136        }
137
138        self.masked()
139    }
140
141    /// Returns the number of leading zeros in the binary representation of
142    /// `self`.
143    #[inline]
144    #[must_use]
145    pub const fn leading_zeros(&self) -> usize {
146        let mut i = LIMBS;
147        while i > 0 {
148            i -= 1;
149            if self.limbs[i] != 0 {
150                let n = LIMBS - 1 - i;
151                let skipped = n * 64;
152                let fixed = Self::MASK.leading_zeros() as usize;
153                let top = self.limbs[i].leading_zeros() as usize;
154                return skipped + top - fixed;
155            }
156        }
157
158        BITS
159    }
160
161    /// Returns the number of leading ones in the binary representation of
162    /// `self`.
163    #[inline]
164    #[must_use]
165    pub const fn leading_ones(&self) -> usize {
166        Self::not(*self).leading_zeros()
167    }
168
169    /// Returns the number of trailing zeros in the binary representation of
170    /// `self`.
171    #[inline]
172    #[must_use]
173    pub fn trailing_zeros(&self) -> usize {
174        self.as_limbs()
175            .iter()
176            .position(|&limb| limb != 0)
177            .map_or(BITS, |n| {
178                n * 64 + self.as_limbs()[n].trailing_zeros() as usize
179            })
180    }
181
182    /// Returns the number of trailing ones in the binary representation of
183    /// `self`.
184    #[inline]
185    #[must_use]
186    pub fn trailing_ones(&self) -> usize {
187        self.as_limbs()
188            .iter()
189            .position(|&limb| limb != u64::MAX)
190            .map_or(BITS, |n| {
191                n * 64 + self.as_limbs()[n].trailing_ones() as usize
192            })
193    }
194
195    /// Returns the number of ones in the binary representation of `self`.
196    #[inline]
197    #[must_use]
198    pub const fn count_ones(&self) -> usize {
199        let mut total = 0;
200
201        let mut i = 0;
202        while i < LIMBS {
203            total += self.limbs[i].count_ones() as usize;
204            i += 1;
205        }
206
207        total
208    }
209
210    /// Returns the number of zeros in the binary representation of `self`.
211    #[must_use]
212    #[inline]
213    pub const fn count_zeros(&self) -> usize {
214        BITS - self.count_ones()
215    }
216
217    /// Returns the dynamic length of this number in bits, ignoring leading
218    /// zeros.
219    ///
220    /// For the maximum length of the type, use [`Uint::BITS`](Self::BITS).
221    #[must_use]
222    #[inline]
223    pub const fn bit_len(&self) -> usize {
224        BITS - self.leading_zeros()
225    }
226
227    /// Returns the dynamic length of this number in bytes, ignoring leading
228    /// zeros.
229    ///
230    /// For the maximum length of the type, use [`Uint::BYTES`](Self::BYTES).
231    #[must_use]
232    #[inline]
233    pub const fn byte_len(&self) -> usize {
234        (self.bit_len() + 7) / 8
235    }
236
237    /// Returns the most significant 64 bits of the number and the exponent.
238    ///
239    /// Given return value $(\mathtt{bits}, \mathtt{exponent})$, the `self` can
240    /// be approximated as
241    ///
242    /// $$
243    /// \mathtt{self} ≈ \mathtt{bits} ⋅ 2^\mathtt{exponent}
244    /// $$
245    ///
246    /// If `self` is $<≥> 2^{63}$, then `exponent` will be zero and `bits` will
247    /// have leading zeros.
248    #[inline]
249    #[must_use]
250    pub fn most_significant_bits(&self) -> (u64, usize) {
251        let first_set_limb = self
252            .as_limbs()
253            .iter()
254            .rposition(|&limb| limb != 0)
255            .unwrap_or(0);
256        if first_set_limb == 0 {
257            (self.as_limbs().first().copied().unwrap_or(0), 0)
258        } else {
259            let hi = self.as_limbs()[first_set_limb];
260            let lo = self.as_limbs()[first_set_limb - 1];
261            let leading_zeros = hi.leading_zeros();
262            let bits = if leading_zeros > 0 {
263                (hi << leading_zeros) | (lo >> (64 - leading_zeros))
264            } else {
265                hi
266            };
267            let exponent = first_set_limb * 64 - leading_zeros as usize;
268            (bits, exponent)
269        }
270    }
271
272    /// Checked left shift by `rhs` bits.
273    ///
274    /// Returns $\mathtt{self} ⋅ 2^{\mathtt{rhs}}$ or [`None`] if the result
275    /// would $≥ 2^{\mathtt{BITS}}$. That is, it returns [`None`] if the bits
276    /// shifted out would be non-zero.
277    ///
278    /// Note: This differs from [`u64::checked_shl`] which returns `None` if the
279    /// shift is larger than BITS (which is IMHO not very useful).
280    #[inline(always)]
281    #[must_use]
282    pub fn checked_shl(self, rhs: usize) -> Option<Self> {
283        match self.overflowing_shl(rhs) {
284            (value, false) => Some(value),
285            _ => None,
286        }
287    }
288
289    /// Saturating left shift by `rhs` bits.
290    ///
291    /// Returns $\mathtt{self} ⋅ 2^{\mathtt{rhs}}$ or [`Uint::MAX`] if the
292    /// result would $≥ 2^{\mathtt{BITS}}$. That is, it returns
293    /// [`Uint::MAX`] if the bits shifted out would be non-zero.
294    #[inline(always)]
295    #[must_use]
296    pub fn saturating_shl(self, rhs: usize) -> Self {
297        match self.overflowing_shl(rhs) {
298            (value, false) => value,
299            _ => Self::MAX,
300        }
301    }
302
303    /// Left shift by `rhs` bits with overflow detection.
304    ///
305    /// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
306    /// If the product is $≥ 2^{\mathtt{BITS}}$ it returns `true`. That is, it
307    /// returns true if the bits shifted out are non-zero.
308    ///
309    /// Note: This differs from [`u64::overflowing_shl`] which returns `true` if
310    /// the shift is larger than `BITS` (which is IMHO not very useful).
311    #[inline]
312    #[must_use]
313    pub fn overflowing_shl(self, rhs: usize) -> (Self, bool) {
314        let (limbs, bits) = (rhs / 64, rhs % 64);
315        if limbs >= LIMBS {
316            return (Self::ZERO, self != Self::ZERO);
317        }
318
319        let word_bits = 64;
320        let mut r = Self::ZERO;
321        let mut carry = 0;
322        for i in 0..Self::LIMBS - limbs {
323            let x = self.limbs[i];
324            r.limbs[i + limbs] = (x << bits) | carry;
325            carry = (x >> (word_bits - bits - 1)) >> 1;
326        }
327        r.apply_mask();
328        (r, carry != 0)
329    }
330
331    /// Left shift by `rhs` bits.
332    ///
333    /// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
334    ///
335    /// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
336    /// by `BITS` (which is IMHO not very useful).
337    #[cfg(not(target_os = "zkvm"))]
338    #[inline(always)]
339    #[must_use]
340    pub fn wrapping_shl(self, rhs: usize) -> Self {
341        self.overflowing_shl(rhs).0
342    }
343
344    /// Left shift by `rhs` bits.
345    ///
346    /// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
347    ///
348    /// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
349    /// by `BITS` (which is IMHO not very useful).
350    #[cfg(target_os = "zkvm")]
351    #[inline(always)]
352    #[must_use]
353    pub fn wrapping_shl(mut self, rhs: usize) -> Self {
354        if BITS == 256 {
355            if rhs >= 256 {
356                return Self::ZERO;
357            }
358            use crate::support::zkvm::zkvm_u256_wrapping_shl_impl;
359            let rhs = rhs as u64;
360            unsafe {
361                zkvm_u256_wrapping_shl_impl(
362                    self.limbs.as_mut_ptr() as *mut u8,
363                    self.limbs.as_ptr() as *const u8,
364                    [rhs, 0, 0, 0].as_ptr() as *const u8,
365                );
366            }
367            self
368        } else {
369            self.overflowing_shl(rhs).0
370        }
371    }
372
373    /// Checked right shift by `rhs` bits.
374    ///
375    /// $$
376    /// \frac{\mathtt{self}}{2^{\mathtt{rhs}}}
377    /// $$
378    ///
379    /// Returns the above or [`None`] if the division is not exact. This is the
380    /// same as
381    ///
382    /// Note: This differs from [`u64::checked_shr`] which returns `None` if the
383    /// shift is larger than BITS (which is IMHO not very useful).
384    #[inline(always)]
385    #[must_use]
386    pub fn checked_shr(self, rhs: usize) -> Option<Self> {
387        match self.overflowing_shr(rhs) {
388            (value, false) => Some(value),
389            _ => None,
390        }
391    }
392
393    /// Right shift by `rhs` bits with underflow detection.
394    ///
395    /// $$
396    /// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
397    /// $$
398    ///
399    /// Returns the above and `false` if the division was exact, and `true` if
400    /// it was rounded down. This is the same as non-zero bits being shifted
401    /// out.
402    ///
403    /// Note: This differs from [`u64::overflowing_shr`] which returns `true` if
404    /// the shift is larger than `BITS` (which is IMHO not very useful).
405    #[inline]
406    #[must_use]
407    pub fn overflowing_shr(self, rhs: usize) -> (Self, bool) {
408        let (limbs, bits) = (rhs / 64, rhs % 64);
409        if limbs >= LIMBS {
410            return (Self::ZERO, self != Self::ZERO);
411        }
412
413        let word_bits = 64;
414        let mut r = Self::ZERO;
415        let mut carry = 0;
416        for i in 0..LIMBS - limbs {
417            let x = self.limbs[LIMBS - 1 - i];
418            r.limbs[LIMBS - 1 - i - limbs] = (x >> bits) | carry;
419            carry = (x << (word_bits - bits - 1)) << 1;
420        }
421        (r, carry != 0)
422    }
423
424    /// Right shift by `rhs` bits.
425    ///
426    /// $$
427    /// \mathtt{wrapping\\_shr}(\mathtt{self}, \mathtt{rhs}) =
428    /// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
429    /// $$
430    ///
431    /// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
432    /// by `BITS` (which is IMHO not very useful).
433    #[cfg(not(target_os = "zkvm"))]
434    #[inline(always)]
435    #[must_use]
436    pub fn wrapping_shr(self, rhs: usize) -> Self {
437        self.overflowing_shr(rhs).0
438    }
439
440    /// Right shift by `rhs` bits.
441    ///
442    /// $$
443    /// \mathtt{wrapping\\_shr}(\mathtt{self}, \mathtt{rhs}) =
444    /// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
445    /// $$
446    ///
447    /// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
448    /// by `BITS` (which is IMHO not very useful).
449    #[cfg(target_os = "zkvm")]
450    #[inline(always)]
451    #[must_use]
452    pub fn wrapping_shr(mut self, rhs: usize) -> Self {
453        if BITS == 256 {
454            if rhs >= 256 {
455                return Self::ZERO;
456            }
457            use crate::support::zkvm::zkvm_u256_wrapping_shr_impl;
458            let rhs = rhs as u64;
459            unsafe {
460                zkvm_u256_wrapping_shr_impl(
461                    self.limbs.as_mut_ptr() as *mut u8,
462                    self.limbs.as_ptr() as *const u8,
463                    [rhs, 0, 0, 0].as_ptr() as *const u8,
464                );
465            }
466            self
467        } else {
468            self.overflowing_shr(rhs).0
469        }
470    }
471
472    /// Arithmetic shift right by `rhs` bits.
473    #[cfg(not(target_os = "zkvm"))]
474    #[inline]
475    #[must_use]
476    pub fn arithmetic_shr(self, rhs: usize) -> Self {
477        if BITS == 0 {
478            return Self::ZERO;
479        }
480        let sign = self.bit(BITS - 1);
481        let mut r = self >> rhs;
482        if sign {
483            r |= Self::MAX << BITS.saturating_sub(rhs);
484        }
485        r
486    }
487
488    /// Arithmetic shift right by `rhs` bits.
489    #[cfg(target_os = "zkvm")]
490    #[inline]
491    #[must_use]
492    pub fn arithmetic_shr(mut self, rhs: usize) -> Self {
493        if BITS == 256 {
494            let rhs = if rhs >= 256 { 255 } else { rhs };
495            use crate::support::zkvm::zkvm_u256_arithmetic_shr_impl;
496            let rhs = rhs as u64;
497            unsafe {
498                zkvm_u256_arithmetic_shr_impl(
499                    self.limbs.as_mut_ptr() as *mut u8,
500                    self.limbs.as_ptr() as *const u8,
501                    [rhs, 0, 0, 0].as_ptr() as *const u8,
502                );
503            }
504            self
505        } else {
506            if BITS == 0 {
507                return Self::ZERO;
508            }
509            let sign = self.bit(BITS - 1);
510            let mut r = self >> rhs;
511            if sign {
512                r |= Self::MAX << BITS.saturating_sub(rhs);
513            }
514            r
515        }
516    }
517
518    /// Shifts the bits to the left by a specified amount, `rhs`, wrapping the
519    /// truncated bits to the end of the resulting integer.
520    #[inline]
521    #[must_use]
522    #[allow(clippy::missing_const_for_fn)] // False positive
523    pub fn rotate_left(self, rhs: usize) -> Self {
524        if BITS == 0 {
525            return Self::ZERO;
526        }
527        let rhs = rhs % BITS;
528        (self << rhs) | (self >> (BITS - rhs))
529    }
530
531    /// Shifts the bits to the right by a specified amount, `rhs`, wrapping the
532    /// truncated bits to the beginning of the resulting integer.
533    #[inline(always)]
534    #[must_use]
535    pub fn rotate_right(self, rhs: usize) -> Self {
536        if BITS == 0 {
537            return Self::ZERO;
538        }
539        let rhs = rhs % BITS;
540        self.rotate_left(BITS - rhs)
541    }
542}
543
impl<const BITS: usize, const LIMBS: usize> Not for Uint<BITS, LIMBS> {
    type Output = Self;

    /// Returns the bitwise complement of `self`.
    ///
    /// Delegates to the inherent const `Uint::not` (inherent associated
    /// functions take precedence over this trait method, so this does not
    /// recurse).
    #[cfg(not(target_os = "zkvm"))]
    #[inline]
    fn not(self) -> Self::Output {
        Self::not(self)
    }

    /// Returns the bitwise complement of `self` on zkVM targets.
    ///
    /// For 256-bit values this calls the zkVM wrapping-subtraction routine
    /// with `Self::MAX` and `self` as operands. Presumably it computes
    /// `MAX - self`, which equals the bitwise complement because subtracting
    /// from all ones never borrows (confirm operand order). Other widths flip
    /// every limb and then clear the unused bits of the top limb with
    /// `Self::MASK`. A zero-width integer yields zero.
    #[cfg(target_os = "zkvm")]
    #[inline(always)]
    fn not(mut self) -> Self::Output {
        use crate::support::zkvm::zkvm_u256_wrapping_sub_impl;
        if BITS == 256 {
            // SAFETY: NOTE(review) assumes the accelerator reads 32 bytes from
            // each input pointer and writes 32 bytes to the output; four `u64`
            // limbs provide that. The output aliases the second input, so this
            // also assumes in-place operation is supported — confirm.
            unsafe {
                zkvm_u256_wrapping_sub_impl(
                    self.limbs.as_mut_ptr() as *mut u8,
                    Self::MAX.limbs.as_ptr() as *const u8,
                    self.limbs.as_ptr() as *const u8,
                );
            }
            self
        } else {
            if BITS == 0 {
                return Self::ZERO;
            }
            for limb in &mut self.limbs {
                *limb = u64::not(*limb);
            }
            // Flipping set the unused bits above `BITS`; clear them again.
            self.limbs[LIMBS - 1] &= Self::MASK;
            self
        }
    }
}
578
579impl<const BITS: usize, const LIMBS: usize> Not for &Uint<BITS, LIMBS> {
580    type Output = Uint<BITS, LIMBS>;
581
582    #[inline]
583    fn not(self) -> Self::Output {
584        (*self).not()
585    }
586}
587
/// Implements a limb-wise bitwise operator (`$trait`/`$fn`) and its compound
/// assignment form (`$trait_assign`/`$fn_assign`) for every owned/borrowed
/// combination of `Uint` operands.
///
/// `$fn_zkvm_impl` names the accelerator routine in `crate::support::zkvm`
/// that is used for 256-bit operands on zkVM targets.
macro_rules! impl_bit_op {
    ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident, $fn_zkvm_impl:ident) => {
        impl<const BITS: usize, const LIMBS: usize> $trait_assign<Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            #[inline(always)]
            fn $fn_assign(&mut self, rhs: Uint<BITS, LIMBS>) {
                self.$fn_assign(&rhs);
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait_assign<&Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            #[cfg(not(target_os = "zkvm"))]
            #[inline]
            fn $fn_assign(&mut self, rhs: &Uint<BITS, LIMBS>) {
                // Bitwise operations on two in-range values never set bits
                // above `BITS`, so no masking is needed afterwards.
                for (dst, &src) in self.limbs.iter_mut().zip(rhs.limbs.iter()) {
                    u64::$fn_assign(dst, src);
                }
            }

            #[cfg(target_os = "zkvm")]
            #[inline(always)]
            fn $fn_assign(&mut self, rhs: &Uint<BITS, LIMBS>) {
                if BITS == 256 {
                    use crate::support::zkvm::$fn_zkvm_impl;
                    // SAFETY: NOTE(review) assumes the accelerator reads 32
                    // bytes from each input pointer and writes 32 bytes to the
                    // output; four `u64` limbs provide that. The output aliases
                    // the first input, so this also assumes in-place operation
                    // is supported — confirm.
                    unsafe {
                        $fn_zkvm_impl(
                            self.limbs.as_mut_ptr() as *mut u8,
                            self.limbs.as_ptr() as *const u8,
                            rhs.limbs.as_ptr() as *const u8,
                        );
                    }
                } else {
                    for (dst, &src) in self.limbs.iter_mut().zip(rhs.limbs.iter()) {
                        u64::$fn_assign(dst, src);
                    }
                }
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            #[inline(always)]
            fn $fn(mut self, rhs: Uint<BITS, LIMBS>) -> Self::Output {
                self.$fn_assign(rhs);
                self
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<&Uint<BITS, LIMBS>>
            for Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            #[inline(always)]
            fn $fn(mut self, rhs: &Uint<BITS, LIMBS>) -> Self::Output {
                self.$fn_assign(rhs);
                self
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<Uint<BITS, LIMBS>>
            for &Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            #[inline(always)]
            fn $fn(self, mut rhs: Uint<BITS, LIMBS>) -> Self::Output {
                // The bitwise operators are commutative, so the owned
                // right-hand side can serve as the accumulator.
                rhs.$fn_assign(self);
                rhs
            }
        }

        impl<const BITS: usize, const LIMBS: usize> $trait<&Uint<BITS, LIMBS>>
            for &Uint<BITS, LIMBS>
        {
            type Output = Uint<BITS, LIMBS>;

            #[inline(always)]
            fn $fn(self, rhs: &Uint<BITS, LIMBS>) -> Self::Output {
                // `Uint` is `Copy`; dereference instead of calling `clone`.
                (*self).$fn(rhs)
            }
        }
    };
}
678
679impl_bit_op!(
680    BitOr,
681    bitor,
682    BitOrAssign,
683    bitor_assign,
684    zkvm_u256_bitor_impl
685);
686impl_bit_op!(
687    BitAnd,
688    bitand,
689    BitAndAssign,
690    bitand_assign,
691    zkvm_u256_bitand_impl
692);
693impl_bit_op!(
694    BitXor,
695    bitxor,
696    BitXorAssign,
697    bitxor_assign,
698    zkvm_u256_bitxor_impl
699);
700
701impl<const BITS: usize, const LIMBS: usize> Shl<Self> for Uint<BITS, LIMBS> {
702    type Output = Self;
703
704    #[inline(always)]
705    fn shl(self, rhs: Self) -> Self::Output {
706        // This check shortcuts, and prevents panics on the `[0]` later
707        if BITS == 0 {
708            return self;
709        }
710        // Rationale: if BITS is larger than 2**64 - 1, it means we're running
711        // on a 128-bit platform with 2.3 exabytes of memory. In this case,
712        // the code produces incorrect output.
713        #[allow(clippy::cast_possible_truncation)]
714        self.wrapping_shl(rhs.as_limbs()[0] as usize)
715    }
716}
717
718impl<const BITS: usize, const LIMBS: usize> Shl<&Self> for Uint<BITS, LIMBS> {
719    type Output = Self;
720
721    #[inline(always)]
722    fn shl(self, rhs: &Self) -> Self::Output {
723        self << *rhs
724    }
725}
726
727impl<const BITS: usize, const LIMBS: usize> Shr<Self> for Uint<BITS, LIMBS> {
728    type Output = Self;
729
730    #[inline(always)]
731    fn shr(self, rhs: Self) -> Self::Output {
732        // This check shortcuts, and prevents panics on the `[0]` later
733        if BITS == 0 {
734            return self;
735        }
736        // Rationale: if BITS is larger than 2**64 - 1, it means we're running
737        // on a 128-bit platform with 2.3 exabytes of memory. In this case,
738        // the code produces incorrect output.
739        #[allow(clippy::cast_possible_truncation)]
740        self.wrapping_shr(rhs.as_limbs()[0] as usize)
741    }
742}
743
744impl<const BITS: usize, const LIMBS: usize> Shr<&Self> for Uint<BITS, LIMBS> {
745    type Output = Self;
746
747    #[inline(always)]
748    fn shr(self, rhs: &Self) -> Self::Output {
749        self >> *rhs
750    }
751}
752
753impl<const BITS: usize, const LIMBS: usize> ShlAssign<Self> for Uint<BITS, LIMBS> {
754    #[inline(always)]
755    fn shl_assign(&mut self, rhs: Self) {
756        *self = *self << rhs;
757    }
758}
759
760impl<const BITS: usize, const LIMBS: usize> ShlAssign<&Self> for Uint<BITS, LIMBS> {
761    #[inline(always)]
762    fn shl_assign(&mut self, rhs: &Self) {
763        *self = *self << rhs;
764    }
765}
766
767impl<const BITS: usize, const LIMBS: usize> ShrAssign<Self> for Uint<BITS, LIMBS> {
768    #[inline(always)]
769    fn shr_assign(&mut self, rhs: Self) {
770        *self = *self >> rhs;
771    }
772}
773
774impl<const BITS: usize, const LIMBS: usize> ShrAssign<&Self> for Uint<BITS, LIMBS> {
775    #[inline(always)]
776    fn shr_assign(&mut self, rhs: &Self) {
777        *self = *self >> rhs;
778    }
779}
780
/// Implements `Shl`/`Shr` (owned and borrowed amount) and
/// `ShlAssign`/`ShrAssign` (owned and borrowed amount) for `Uint` with each
/// listed primitive integer type as the shift amount.
macro_rules! impl_shift {
    // By-value primitive amount: cast to `usize` and use the wrapping shifts.
    (@main $u:ty) => {
        impl<const BITS: usize, const LIMBS: usize> Shl<$u> for Uint<BITS, LIMBS> {
            type Output = Self;

            #[inline(always)]
            #[allow(clippy::cast_possible_truncation)]
            fn shl(self, rhs: $u) -> Self::Output {
                // Negative signed amounts sign-extend to very large `usize`
                // values, which shift every bit out.
                self.wrapping_shl(rhs as usize)
            }
        }

        impl<const BITS: usize, const LIMBS: usize> Shr<$u> for Uint<BITS, LIMBS> {
            type Output = Self;

            #[inline(always)]
            #[allow(clippy::cast_possible_truncation)]
            fn shr(self, rhs: $u) -> Self::Output {
                // Negative signed amounts sign-extend to very large `usize`
                // values, which shift every bit out.
                self.wrapping_shr(rhs as usize)
            }
        }
    };

    // Borrowed primitive amount: dereference and reuse the by-value impls.
    (@ref $u:ty) => {
        impl<const BITS: usize, const LIMBS: usize> Shl<&$u> for Uint<BITS, LIMBS> {
            type Output = Self;

            #[inline(always)]
            fn shl(self, rhs: &$u) -> Self::Output {
                let amount: $u = *rhs;
                self << amount
            }
        }

        impl<const BITS: usize, const LIMBS: usize> Shr<&$u> for Uint<BITS, LIMBS> {
            type Output = Self;

            #[inline(always)]
            fn shr(self, rhs: &$u) -> Self::Output {
                let amount: $u = *rhs;
                self >> amount
            }
        }
    };

    // Compound assignment, forwarding to the matching `Shl`/`Shr` impl.
    (@assign $u:ty) => {
        impl<const BITS: usize, const LIMBS: usize> ShlAssign<$u> for Uint<BITS, LIMBS> {
            #[inline(always)]
            fn shl_assign(&mut self, rhs: $u) {
                *self = Shl::shl(*self, rhs);
            }
        }

        impl<const BITS: usize, const LIMBS: usize> ShrAssign<$u> for Uint<BITS, LIMBS> {
            #[inline(always)]
            fn shr_assign(&mut self, rhs: $u) {
                *self = Shr::shr(*self, rhs);
            }
        }
    };

    // Public entry point: a comma-separated list of shift-amount types.
    ($($u:ty),*) => {
        $(
            impl_shift!(@main $u);
            impl_shift!(@ref $u);
            impl_shift!(@assign $u);
            impl_shift!(@assign &$u);
        )*
    };
}
852
853impl_shift!(usize, u8, u16, u32, isize, i8, i16, i32);
854
855// Only when losslessy castable to usize.
856#[cfg(target_pointer_width = "64")]
857impl_shift!(u64, i64);
858
859#[cfg(test)]
860mod tests {
861    use core::cmp::min;
862
863    use proptest::proptest;
864
865    use super::*;
866    use crate::{aliases::U128, const_for, nlimbs};
867
868    #[test]
869    fn test_leading_zeros() {
870        assert_eq!(Uint::<0, 0>::ZERO.leading_zeros(), 0);
871        assert_eq!(Uint::<1, 1>::ZERO.leading_zeros(), 1);
872        assert_eq!(Uint::<1, 1>::ONE.leading_zeros(), 0);
873        const_for!(BITS in NON_ZERO {
874            const LIMBS: usize = nlimbs(BITS);
875            type U = Uint::<BITS, LIMBS>;
876            assert_eq!(U::ZERO.leading_zeros(), BITS);
877            assert_eq!(U::MAX.leading_zeros(), 0);
878            assert_eq!(U::ONE.leading_zeros(), BITS - 1);
879            proptest!(|(value: U)| {
880                let zeros = value.leading_zeros();
881                assert!(zeros <= BITS);
882                assert!(zeros < BITS || value == U::ZERO);
883                if zeros < BITS {
884                    let (left, overflow) = value.overflowing_shl(zeros);
885                    assert!(!overflow);
886                    assert!(left.leading_zeros() == 0 || value == U::ZERO);
887                    assert!(left.bit(BITS - 1));
888                    assert_eq!(value >> (BITS - zeros), Uint::ZERO);
889                }
890            });
891        });
892        proptest!(|(value: u128)| {
893            let uint = U128::from(value);
894            assert_eq!(uint.leading_zeros(), value.leading_zeros() as usize);
895        });
896    }
897
898    #[test]
899    fn test_leading_ones() {
900        assert_eq!(Uint::<0, 0>::ZERO.leading_ones(), 0);
901        assert_eq!(Uint::<1, 1>::ZERO.leading_ones(), 0);
902        assert_eq!(Uint::<1, 1>::ONE.leading_ones(), 1);
903    }
904
905    #[test]
906    fn test_most_significant_bits() {
907        const_for!(BITS in NON_ZERO {
908            const LIMBS: usize = nlimbs(BITS);
909            type U = Uint::<BITS, LIMBS>;
910            proptest!(|(value: u64)| {
911                let value = if U::LIMBS <= 1 { value & U::MASK } else { value };
912                assert_eq!(U::from(value).most_significant_bits(), (value, 0));
913            });
914        });
915        proptest!(|(mut limbs: [u64; 2])| {
916            if limbs[1] == 0 {
917                limbs[1] = 1;
918            }
919            let (bits, exponent) = U128::from_limbs(limbs).most_significant_bits();
920            assert!(bits >= 1_u64 << 63);
921            assert_eq!(exponent, 64 - limbs[1].leading_zeros() as usize);
922        });
923    }
924
925    #[test]
926    fn test_checked_shl() {
927        assert_eq!(
928            Uint::<65, 2>::from_limbs([0x0010_0000_0000_0000, 0]).checked_shl(1),
929            Some(Uint::<65, 2>::from_limbs([0x0020_0000_0000_0000, 0]))
930        );
931        assert_eq!(
932            Uint::<127, 2>::from_limbs([0x0010_0000_0000_0000, 0]).checked_shl(64),
933            Some(Uint::<127, 2>::from_limbs([0, 0x0010_0000_0000_0000]))
934        );
935    }
936
937    #[test]
938    #[allow(
939        clippy::cast_lossless,
940        clippy::cast_possible_truncation,
941        clippy::cast_possible_wrap
942    )]
943    fn test_small() {
944        const_for!(BITS in [1, 2, 8, 16, 32, 63, 64] {
945            type U = Uint::<BITS, 1>;
946            proptest!(|(a: U, b: U)| {
947                assert_eq!(a | b, U::from_limbs([a.limbs[0] | b.limbs[0]]));
948                assert_eq!(a & b, U::from_limbs([a.limbs[0] & b.limbs[0]]));
949                assert_eq!(a ^ b, U::from_limbs([a.limbs[0] ^ b.limbs[0]]));
950            });
951            proptest!(|(a: U, s in 0..BITS)| {
952                assert_eq!(a << s, U::from_limbs([a.limbs[0] << s & U::MASK]));
953                assert_eq!(a >> s, U::from_limbs([a.limbs[0] >> s]));
954            });
955        });
956        proptest!(|(a: Uint::<32, 1>, s in 0_usize..=34)| {
957            assert_eq!(a.reverse_bits(), Uint::from((a.limbs[0] as u32).reverse_bits() as u64));
958            assert_eq!(a.rotate_left(s), Uint::from((a.limbs[0] as u32).rotate_left(s as u32) as u64));
959            assert_eq!(a.rotate_right(s), Uint::from((a.limbs[0] as u32).rotate_right(s as u32) as u64));
960            if s < 32 {
961                let arr_shifted = (((a.limbs[0] as i32) >> s) as u32) as u64;
962                assert_eq!(a.arithmetic_shr(s), Uint::from_limbs([arr_shifted]));
963            }
964        });
965        proptest!(|(a: Uint::<64, 1>, s in 0_usize..=66)| {
966            assert_eq!(a.reverse_bits(), Uint::from(a.limbs[0].reverse_bits()));
967            assert_eq!(a.rotate_left(s), Uint::from(a.limbs[0].rotate_left(s as u32)));
968            assert_eq!(a.rotate_right(s), Uint::from(a.limbs[0].rotate_right(s as u32)));
969            if s < 64 {
970                let arr_shifted = ((a.limbs[0] as i64) >> s) as u64;
971                assert_eq!(a.arithmetic_shr(s), Uint::from_limbs([arr_shifted]));
972            }
973        });
974    }
975
976    #[test]
977    fn test_shift_reverse() {
978        const_for!(BITS in SIZES {
979            const LIMBS: usize = nlimbs(BITS);
980            type U = Uint::<BITS, LIMBS>;
981            proptest!(|(value: U, shift in 0..=BITS + 2)| {
982                let left = (value << shift).reverse_bits();
983                let right = value.reverse_bits() >> shift;
984                assert_eq!(left, right);
985            });
986        });
987    }
988
989    #[test]
990    fn test_rotate() {
991        const_for!(BITS in SIZES {
992            const LIMBS: usize = nlimbs(BITS);
993            type U = Uint::<BITS, LIMBS>;
994            proptest!(|(value: U, shift in  0..=BITS + 2)| {
995                let rotated = value.rotate_left(shift).rotate_right(shift);
996                assert_eq!(value, rotated);
997            });
998        });
999    }
1000
1001    #[test]
1002    fn test_arithmetic_shr() {
1003        const_for!(BITS in SIZES {
1004            const LIMBS: usize = nlimbs(BITS);
1005            type U = Uint::<BITS, LIMBS>;
1006            proptest!(|(value: U, shift in  0..=BITS + 2)| {
1007                let shifted = value.arithmetic_shr(shift);
1008                assert_eq!(shifted.leading_ones(), match value.leading_ones() {
1009                    0 => 0,
1010                    n => min(BITS, n + shift)
1011                });
1012            });
1013        });
1014    }
1015
1016    #[test]
1017    fn test_overflowing_shr() {
1018        // Test: Single limb right shift from 40u64 by 1 bit.
1019        // Expects resulting integer: 20 with no fractional part.
1020        assert_eq!(
1021            Uint::<64, 1>::from_limbs([40u64]).overflowing_shr(1),
1022            (Uint::<64, 1>::from(20), false)
1023        );
1024
1025        // Test: Single limb right shift from 41u64 by 1 bit.
1026        // Expects resulting integer: 20 with a detected fractional part.
1027        assert_eq!(
1028            Uint::<64, 1>::from_limbs([41u64]).overflowing_shr(1),
1029            (Uint::<64, 1>::from(20), true)
1030        );
1031
1032        // Test: Two limbs right shift from 0x0010_0000_0000_0000 and 0 by 1 bit.
1033        // Expects resulting limbs: [0x0080_0000_0000_000, 0] with no fractional part.
1034        assert_eq!(
1035            Uint::<65, 2>::from_limbs([0x0010_0000_0000_0000, 0]).overflowing_shr(1),
1036            (Uint::<65, 2>::from_limbs([0x0008_0000_0000_0000, 0]), false)
1037        );
1038
1039        // Test: Shift beyond single limb capacity with MAX value.
1040        // Expects the highest possible value in 256-bit representation with a detected
1041        // fractional part.
1042        assert_eq!(
1043            Uint::<256, 4>::MAX.overflowing_shr(65),
1044            (
1045                Uint::<256, 4>::from_str_radix(
1046                    "7fffffffffffffffffffffffffffffffffffffffffffffff",
1047                    16
1048                )
1049                .unwrap(),
1050                true
1051            )
1052        );
1053        // Test: Large 4096-bit integer right shift by 34 bits.
1054        // Expects a specific value with no fractional part.
1055        assert_eq!(
1056            Uint::<4096, 64>::from_str_radix("3ffffffffffffffffffffffffffffc00000000", 16,)
1057                .unwrap()
1058                .overflowing_shr(34),
1059            (
1060                Uint::<4096, 64>::from_str_radix("fffffffffffffffffffffffffffff", 16).unwrap(),
1061                false
1062            )
1063        );
1064        // Test: Extremely large 4096-bit integer right shift by 100 bits.
1065        // Expects a specific value with no fractional part.
1066        assert_eq!(
1067            Uint::<4096, 64>::from_str_radix(
1068                "fffffffffffffffffffffffffffff0000000000000000000000000",
1069                16,
1070            )
1071            .unwrap()
1072            .overflowing_shr(100),
1073            (
1074                Uint::<4096, 64>::from_str_radix("fffffffffffffffffffffffffffff", 16).unwrap(),
1075                false
1076            )
1077        );
1078        // Test: Complex 4096-bit integer right shift by 1 bit.
1079        // Expects a specific value with no fractional part.
1080        assert_eq!(
1081            Uint::<4096, 64>::from_str_radix(
1082                "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0bdbfe",
1083                16,
1084            )
1085            .unwrap()
1086            .overflowing_shr(1),
1087            (
1088                Uint::<4096, 64>::from_str_radix(
1089                    "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffff85edff",
1090                    16
1091                )
1092                .unwrap(),
1093                false
1094            )
1095        );
1096        // Test: Large 4096-bit integer right shift by 1000 bits.
1097        // Expects a specific value with no fractional part.
1098        assert_eq!(
1099            Uint::<4096, 64>::from_str_radix(
1100                "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
1101                16,
1102            )
1103            .unwrap()
1104            .overflowing_shr(1000),
1105            (
1106                Uint::<4096, 64>::from_str_radix(
1107                    "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
1108                    16
1109                )
1110                .unwrap(),
1111                false
1112            )
1113        );
1114        // Test: MAX 4096-bit integer right shift by 34 bits.
1115        // Expects a specific value with a detected fractional part.
1116        assert_eq!(
1117            Uint::<4096, 64>::MAX
1118            .overflowing_shr(34),
1119            (
1120                Uint::<4096, 64>::from_str_radix(
1121                    "3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
1122                    16
1123                )
1124                .unwrap(),
1125                true
1126            )
1127        );
1128    }
1129}