ruint/
base_convert.rs

1use crate::{
2    algorithms::{addmul_nx1, mul_nx1},
3    Uint,
4};
5use core::fmt;
6
/// Errors reported by [`from_base_le`][Uint::from_base_le] and
/// [`from_base_be`][Uint::from_base_be].
// The type name deliberately echoes the module name, hence the lint opt-out.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BaseConvertError {
    /// The digits encode a value that does not fit the target type.
    Overflow,

    /// The requested number base `.0` is smaller than two.
    InvalidBase(u64),

    /// The digit `.0` is not a valid digit for the requested base `.1`.
    InvalidDigit(u64, u64),
}

// Only available when the standard library is: `Error` is named via `std`.
#[cfg(feature = "std")]
impl std::error::Error for BaseConvertError {}

impl fmt::Display for BaseConvertError {
    /// Writes a short, human-readable description of the error, including
    /// the offending base and/or digit where the variant carries them.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payloads are `u64`, so the variants can be matched by value.
        match *self {
            Self::InvalidDigit(digit, base) => {
                write!(f, "digit {digit} is out of range for base {base}")
            }
            Self::InvalidBase(base) => {
                write!(f, "the requested number base {base} is less than two")
            }
            Self::Overflow => f.write_str("the value is too large to fit the target type"),
        }
    }
}
39
40impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
41    /// Returns an iterator over the base `base` digits of the number in
42    /// little-endian order.
43    ///
44    /// Pro tip: instead of setting `base = 10`, set it to the highest
45    /// power of `10` that still fits `u64`. This way much fewer iterations
46    /// are required to extract all the digits.
47    // OPT: Internalize this trick so the user won't have to worry about it.
48    /// # Panics
49    ///
50    /// Panics if the base is less than 2.
51    #[inline]
52    pub fn to_base_le(&self, base: u64) -> impl Iterator<Item = u64> {
53        assert!(base > 1);
54        SpigotLittle {
55            base,
56            limbs: self.limbs,
57        }
58    }
59
60    /// Returns an iterator over the base `base` digits of the number in
61    /// big-endian order.
62    ///
63    /// Pro tip: instead of setting `base = 10`, set it to the highest
64    /// power of `10` that still fits `u64`. This way much fewer iterations
65    /// are required to extract all the digits.
66    ///
67    /// # Panics
68    ///
69    /// Panics if the base is less than 2.
70    #[inline]
71    #[cfg(feature = "alloc")] // OPT: Find an allocation free method. Maybe extract from the top?
72    pub fn to_base_be(&self, base: u64) -> impl Iterator<Item = u64> {
73        struct OwnedVecIterator {
74            vec: alloc::vec::Vec<u64>,
75        }
76
77        impl Iterator for OwnedVecIterator {
78            type Item = u64;
79
80            #[inline]
81            fn next(&mut self) -> Option<Self::Item> {
82                self.vec.pop()
83            }
84        }
85
86        assert!(base > 1);
87        OwnedVecIterator {
88            vec: self.to_base_le(base).collect(),
89        }
90    }
91
92    /// Constructs the [`Uint`] from digits in the base `base` in little-endian.
93    ///
94    /// # Errors
95    ///
96    /// * [`BaseConvertError::InvalidBase`] if the base is less than 2.
97    /// * [`BaseConvertError::InvalidDigit`] if a digit is out of range.
98    /// * [`BaseConvertError::Overflow`] if the number is too large to fit.
99    #[inline]
100    pub fn from_base_le<I>(base: u64, digits: I) -> Result<Self, BaseConvertError>
101    where
102        I: IntoIterator<Item = u64>,
103    {
104        if base < 2 {
105            return Err(BaseConvertError::InvalidBase(base));
106        }
107        if BITS == 0 {
108            for digit in digits {
109                if digit >= base {
110                    return Err(BaseConvertError::InvalidDigit(digit, base));
111                }
112                if digit != 0 {
113                    return Err(BaseConvertError::Overflow);
114                }
115            }
116            return Ok(Self::ZERO);
117        }
118
119        let mut iter = digits.into_iter();
120        let mut result = Self::ZERO;
121        let mut power = Self::from(1);
122        for digit in iter.by_ref() {
123            if digit >= base {
124                return Err(BaseConvertError::InvalidDigit(digit, base));
125            }
126
127            // Add digit to result
128            let overflow = addmul_nx1(&mut result.limbs, power.as_limbs(), digit);
129            if overflow != 0 || result.limbs[LIMBS - 1] > Self::MASK {
130                return Err(BaseConvertError::Overflow);
131            }
132
133            // Update power
134            let overflow = mul_nx1(&mut power.limbs, base);
135            if overflow != 0 || power.limbs[LIMBS - 1] > Self::MASK {
136                // Following digits must be zero
137                break;
138            }
139        }
140        for digit in iter {
141            if digit >= base {
142                return Err(BaseConvertError::InvalidDigit(digit, base));
143            }
144            if digit != 0 {
145                return Err(BaseConvertError::Overflow);
146            }
147        }
148        Ok(result)
149    }
150
151    /// Constructs the [`Uint`] from digits in the base `base` in big-endian.
152    ///
153    /// # Errors
154    ///
155    /// * [`BaseConvertError::InvalidBase`] if the base is less than 2.
156    /// * [`BaseConvertError::InvalidDigit`] if a digit is out of range.
157    /// * [`BaseConvertError::Overflow`] if the number is too large to fit.
158    #[inline]
159    pub fn from_base_be<I: IntoIterator<Item = u64>>(
160        base: u64,
161        digits: I,
162    ) -> Result<Self, BaseConvertError> {
163        // OPT: Special handling of bases that divide 2^64, and bases that are
164        // powers of 2.
165        // OPT: Same trick as with `to_base_le`, find the largest power of base
166        // that fits `u64` and accumulate there first.
167        if base < 2 {
168            return Err(BaseConvertError::InvalidBase(base));
169        }
170
171        let mut result = Self::ZERO;
172        for digit in digits {
173            if digit >= base {
174                return Err(BaseConvertError::InvalidDigit(digit, base));
175            }
176            // Multiply by base.
177            // OPT: keep track of non-zero limbs and mul the minimum.
178            let mut carry: u128 = u128::from(digit);
179            #[allow(clippy::cast_possible_truncation)]
180            for limb in &mut result.limbs {
181                carry += u128::from(*limb) * u128::from(base);
182                *limb = carry as u64;
183                carry >>= 64;
184            }
185            if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) {
186                return Err(BaseConvertError::Overflow);
187            }
188        }
189
190        Ok(result)
191    }
192}
193
/// Iterator state for extracting little-endian digits: the remaining value
/// (as little-endian limbs) and the base to divide it by.
struct SpigotLittle<const LIMBS: usize> {
    base:  u64,
    limbs: [u64; LIMBS],
}

impl<const LIMBS: usize> Iterator for SpigotLittle<LIMBS> {
    type Item = u64;

    /// Divides the remaining value by the base in place and returns the
    /// remainder as the next digit, or `None` once the value was already
    /// zero before the division.
    #[inline]
    #[allow(clippy::cast_possible_truncation)] // Doesn't truncate
    fn next(&mut self) -> Option<Self::Item> {
        // Knuth Algorithm S.
        let divisor = u128::from(self.base);
        let mut any_nonzero = false;
        let mut rem = 0_u128;
        // Long division from the most significant limb down.
        // OPT: If we keep track of leading zero limbs we can half iterations.
        for limb in self.limbs.iter_mut().rev() {
            any_nonzero |= *limb != 0;
            let dividend = (rem << 64) | u128::from(*limb);
            *limb = (dividend / divisor) as u64;
            rem = dividend % divisor;
        }
        if any_nonzero {
            Some(rem as u64)
        } else {
            None
        }
    }
}
222
#[cfg(test)]
#[allow(clippy::unreadable_literal)]
#[allow(clippy::zero_prefixed_literal)]
mod tests {
    use super::*;

    /// The highest power of ten that still fits a `u64` (10^19).
    const BASE_1E19: u64 = 10000000000000000000_u64;

    // 90630363884335538722706632492458228784305343302099024356772372330524102404852
    const N: Uint<256, 4> = Uint::from_limbs([
        0xa8ec92344438aaf4_u64,
        0x9819ebdbd1faaab1_u64,
        0x573b1a7064c19c1a_u64,
        0xc85ef7d79691fe79_u64,
    ]);

    /// Base 10^19 digits of `N`, least significant first. The zero prefixes
    /// keep every group 19 characters wide so it lines up with the decimal
    /// expansion above.
    const N_DIGITS_LE: [u64; 5] = [
        2372330524102404852,
        0534330209902435677,
        7066324924582287843,
        0630363884335538722,
        9,
    ];

    #[test]
    fn test_to_base_le() {
        let small: Vec<u64> = Uint::<64, 1>::from(123456789).to_base_le(10).collect();
        assert_eq!(small, vec![9, 8, 7, 6, 5, 4, 3, 2, 1]);

        let large: Vec<u64> = N.to_base_le(BASE_1E19).collect();
        assert_eq!(large, N_DIGITS_LE);
    }

    #[test]
    fn test_from_base_le() {
        let small = Uint::<64, 1>::from_base_le(10, [9, 8, 7, 6, 5, 4, 3, 2, 1]);
        assert_eq!(small, Ok(Uint::<64, 1>::from(123456789)));

        let large = Uint::<256, 4>::from_base_le(BASE_1E19, N_DIGITS_LE.iter().copied());
        assert_eq!(large.unwrap(), N);
    }

    #[test]
    fn test_to_base_be() {
        let small: Vec<u64> = Uint::<64, 1>::from(123456789).to_base_be(10).collect();
        assert_eq!(small, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);

        // Big-endian output is the little-endian fixture back to front.
        let mut expected = N_DIGITS_LE;
        expected.reverse();
        let large: Vec<u64> = N.to_base_be(BASE_1E19).collect();
        assert_eq!(large, expected);
    }

    #[test]
    fn test_from_base_be() {
        let small = Uint::<64, 1>::from_base_be(10, [1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(small, Ok(Uint::<64, 1>::from(123456789)));

        let large = Uint::<256, 4>::from_base_be(BASE_1E19, N_DIGITS_LE.iter().rev().copied());
        assert_eq!(large.unwrap(), N);
    }

    #[test]
    fn test_from_base_be_overflow() {
        // A zero-bit integer accepts only (possibly empty) all-zero input.
        let no_digits = Uint::<0, 0>::from_base_be(10, core::iter::empty());
        assert_eq!(no_digits, Ok(Uint::<0, 0>::ZERO));

        let zero_digit = Uint::<0, 0>::from_base_be(10, core::iter::once(0));
        assert_eq!(zero_digit, Ok(Uint::<0, 0>::ZERO));

        let one_digit = Uint::<0, 0>::from_base_be(10, core::iter::once(1));
        assert_eq!(one_digit, Err(BaseConvertError::Overflow));

        // 100 does not fit a single bit.
        let too_big = Uint::<1, 1>::from_base_be(10, [1, 0, 0].into_iter());
        assert_eq!(too_big, Err(BaseConvertError::Overflow));
    }
}