ruint_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
3#![allow(clippy::manual_div_ceil)]
4
5use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
6use std::fmt::{self, Write};
7
8// Repeat the crate doc.
9#[doc = include_str!("../README.md")]
10#[proc_macro]
11pub fn uint(stream: TokenStream) -> TokenStream {
12    Transformer::new(None).transform_stream(stream)
13}
14
15/// Same as [`uint`], but with the first token always being a
16/// [group](proc_macro::Group) containing the `ruint` crate path.
17///
18/// This allows the macro to be used in a crates that don't on `ruint` through a
19/// wrapper `macro_rules!` that passes `$crate` as the path.
20///
21/// This is an implementation detail and should not be used directly.
22#[proc_macro]
23#[doc(hidden)]
24pub fn uint_with_path(stream: TokenStream) -> TokenStream {
25    let mut stream_iter = stream.into_iter();
26    let Some(TokenTree::Group(group)) = stream_iter.next() else {
27        return error(
28            Span::call_site(),
29            "Expected a group containing the `ruint` crate path",
30        )
31        .into();
32    };
33    Transformer::new(Some(group.stream())).transform_stream(stream_iter.collect())
34}
35
/// The `ruint` type selected by a literal's suffix letter.
///
/// `Eq` is derived alongside `PartialEq` because equality on this enum is
/// total (and `clippy::derive_partial_eq_without_eq` flags the omission).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum LiteralBaseType {
    /// `U` suffix; displayed as `Uint`.
    Uint,
    /// `B` suffix; displayed as `Bits`.
    Bits,
}

impl LiteralBaseType {
    /// Every character that can start a suffix.
    ///
    /// Must stay in sync with the `FromStr` implementation below, which maps
    /// each of these single-character strings to a variant.
    const PATTERN: &'static [char] = &['U', 'B'];
}

impl fmt::Display for LiteralBaseType {
    /// Writes the type name used in generated code (`Uint` or `Bits`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Uint => f.write_str("Uint"),
            Self::Bits => f.write_str("Bits"),
        }
    }
}

impl std::str::FromStr for LiteralBaseType {
    type Err = ();

    /// Parses a suffix letter: `"U"` → [`Self::Uint`], `"B"` → [`Self::Bits`].
    ///
    /// Any other string (including lowercase letters) yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "U" => Ok(Self::Uint),
            "B" => Ok(Self::Bits),
            _ => Err(()),
        }
    }
}
66
/// Builds a token tree that expands to `compile_error!{"<message>"}`.
///
/// The `compile_error` identifier and the brace group both carry `span`, so
/// the reported error points at the tokens that caused it.
// FEATURE: (BLOCKED) Replace with Diagnostic API when stable.
// See <https://doc.rust-lang.org/stable/proc_macro/struct.Diagnostic.html>
fn error(span: Span, message: &str) -> TokenTree {
    // Same shape as syn's error: https://docs.rs/syn/1.0.70/src/syn/error.rs.html#243
    let mut arguments = Group::new(
        Delimiter::Brace,
        TokenStream::from(TokenTree::Literal(Literal::string(message))),
    );
    arguments.set_span(span);

    let mut invocation = TokenStream::new();
    invocation.extend([
        TokenTree::Ident(Ident::new("compile_error", span)),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(arguments),
    ]);
    // Wrap in an invisible group so the invocation behaves as a single token tree.
    TokenTree::Group(Group::new(Delimiter::None, invocation))
}
86
/// Parses the digits of an integer literal into little-endian 64-bit limbs.
///
/// A `0x`, `0o` or `0b` prefix selects base 16, 8 or 2; anything else is read
/// as decimal. `_` separators are skipped. The result always has at least one
/// limb, and its most significant limb is non-zero unless the value is zero
/// (which is returned as `[0]`).
///
/// # Errors
///
/// Returns a message when `value` contains a character that is neither a
/// hexadecimal digit nor `_`, or a digit that is not valid in the selected
/// base (a digit must be strictly less than the base).
fn parse_digits(value: &str) -> Result<Vec<u64>, String> {
    // Select the base from the prefix. `strip_prefix` cannot panic on
    // non-ASCII input, unlike slicing at a fixed byte offset.
    let (base, digits) = if let Some(rest) = value.strip_prefix("0x") {
        (16_u8, rest)
    } else if let Some(rest) = value.strip_prefix("0o") {
        (8, rest)
    } else if let Some(rest) = value.strip_prefix("0b") {
        (2, rest)
    } else {
        (10, value)
    };

    // Parse digits in base
    let mut limbs = vec![0_u64];
    for c in digits.chars() {
        if c == '_' {
            continue;
        }
        // Recognise every hex digit first, so that e.g. `a` in a decimal
        // literal gets the more helpful "invalid digit" message below.
        let Some(digit) = c.to_digit(16) else {
            return Err(format!("Invalid character '{c}'"));
        };
        let digit = u64::from(digit);
        // Valid digits are 0..base; a digit equal to the base is invalid too.
        if digit >= u64::from(base) {
            return Err(format!(
                "Invalid digit {c} in base {base} (did you forget the `0x` prefix?)"
            ));
        }

        // limbs = limbs * base + digit, propagating the carry upward.
        let mut carry = digit;
        // Truncation is intended: the low 64 bits stay in the limb and the
        // high bits become the carry into the next limb.
        #[allow(clippy::cast_possible_truncation)]
        for limb in &mut limbs {
            let product = u128::from(*limb) * u128::from(base) + u128::from(carry);
            *limb = product as u64;
            carry = (product >> 64) as u64;
        }
        if carry > 0 {
            limbs.push(carry);
        }
    }
    Ok(limbs)
}
134
/// Resizes little-endian `limbs` to exactly the number of limbs a
/// `bits`-bit integer uses (`(bits + 63) / 64`).
///
/// Surplus high limbs are dropped only if they are zero; missing limbs are
/// filled with zeros. Returns `None` when the value does not fit in `bits`
/// bits, either because a non-zero limb lies beyond the target length or
/// because the top limb exceeds the bits available to it.
fn pad_limbs(bits: usize, mut limbs: Vec<u64>) -> Option<Vec<u64>> {
    let num_limbs = (bits + 63) / 64;
    // Largest value the most significant limb may hold.
    let top_mask = match bits % 64 {
        0 if bits == 0 => 0,
        0 => u64::MAX,
        used => (1_u64 << used) - 1,
    };

    // Strip zero limbs that lie beyond the target length.
    while limbs.len() > num_limbs && limbs.ends_with(&[0]) {
        limbs.pop();
    }
    // Anything still beyond the target length is a non-zero high limb.
    if limbs.len() > num_limbs {
        return None;
    }
    limbs.resize(num_limbs, 0);

    match limbs.last() {
        Some(&top) if top > top_mask => None,
        _ => Some(limbs),
    }
}
163
164fn parse_suffix(source: &str) -> Option<(LiteralBaseType, usize, &str)> {
165    // Parse into value, bits, and base type.
166    let suffix_index = source.rfind(LiteralBaseType::PATTERN)?;
167    let (value, suffix) = source.split_at(suffix_index);
168    let (base_type, bits) = suffix.split_at(1);
169    let base_type = base_type.parse::<LiteralBaseType>().ok()?;
170    let bits = bits.parse::<usize>().ok()?;
171
172    // Ignore hexadecimal Bits literals without `_` before the suffix.
173    if base_type == LiteralBaseType::Bits && value.starts_with("0x") && !value.ends_with('_') {
174        return None;
175    }
176    Some((base_type, bits, value))
177}
178
179struct Transformer {
180    /// The `ruint` crate path.
181    /// Note that this stream's span must be used in order for the `$crate` to
182    /// work.
183    ruint_crate: TokenStream,
184}
185
186impl Transformer {
187    fn new(ruint_crate: Option<TokenStream>) -> Self {
188        Self {
189            ruint_crate: ruint_crate.unwrap_or_else(|| "::ruint".parse().unwrap()),
190        }
191    }
192
193    /// Construct a `<{base_type}><{bits}>` literal from `limbs`.
194    fn construct(&self, base_type: LiteralBaseType, bits: usize, limbs: &[u64]) -> TokenStream {
195        let mut limbs_str = String::new();
196        for limb in limbs {
197            write!(&mut limbs_str, "0x{limb:016x}_u64, ").unwrap();
198        }
199        let limbs_str = limbs_str.trim_end_matches(", ");
200        let limbs = (bits + 63) / 64;
201        let source = format!("::{base_type}::<{bits}, {limbs}>::from_limbs([{limbs_str}])");
202
203        let mut tokens = self.ruint_crate.clone();
204        tokens.extend(source.parse::<TokenStream>().unwrap());
205        tokens
206    }
207
208    /// Transforms a [`Literal`] and returns the substitute [`TokenStream`].
209    fn transform_literal(&self, source: &str) -> Result<Option<TokenStream>, String> {
210        // Check if literal has a suffix we accept.
211        let Some((base_type, bits, value)) = parse_suffix(source) else {
212            return Ok(None);
213        };
214
215        // Parse `value` into limbs.
216        // At this point we are confident the literal was for us, so we throw errors.
217        let limbs = parse_digits(value)?;
218
219        // Pad limbs to the correct length.
220        let Some(limbs) = pad_limbs(bits, limbs) else {
221            let value = value.trim_end_matches('_');
222            return Err(format!("Value too large for {base_type}<{bits}>: {value}"));
223        };
224
225        Ok(Some(self.construct(base_type, bits, &limbs)))
226    }
227
228    /// Recurse down tree and transform all literals.
229    fn transform_tree(&self, tree: TokenTree) -> TokenTree {
230        match tree {
231            TokenTree::Group(group) => {
232                let delimiter = group.delimiter();
233                let span = group.span();
234                let stream = self.transform_stream(group.stream());
235                let mut transformed = Group::new(delimiter, stream);
236                transformed.set_span(span);
237                TokenTree::Group(transformed)
238            }
239            TokenTree::Literal(a) => {
240                let span = a.span();
241                let source = a.to_string();
242                let mut tree = match self.transform_literal(&source) {
243                    Ok(Some(stream)) => TokenTree::Group({
244                        let mut group = Group::new(Delimiter::None, stream);
245                        group.set_span(span);
246                        group
247                    }),
248                    Ok(None) => TokenTree::Literal(a),
249                    Err(message) => error(span, &message),
250                };
251                tree.set_span(span);
252                tree
253            }
254            tree => tree,
255        }
256    }
257
258    /// Iterate over a [`TokenStream`] and transform all [`TokenTree`]s.
259    fn transform_stream(&self, stream: TokenStream) -> TokenStream {
260        stream
261            .into_iter()
262            .map(|tree| self.transform_tree(tree))
263            .collect()
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_size() {
        assert_eq!(parse_digits("0"), Ok(vec![0]));
        assert_eq!(parse_digits("00000"), Ok(vec![0]));
        assert_eq!(parse_digits("0x00"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000000"), Ok(vec![0]));

        // Zero fits in a zero-bit integer, which has no limbs at all;
        // any non-zero value does not.
        assert_eq!(pad_limbs(0, vec![0]), Some(vec![]));
        assert_eq!(pad_limbs(0, parse_digits("0x00").unwrap()), Some(vec![]));
        assert_eq!(pad_limbs(0, parse_digits("0b0000000").unwrap()), Some(vec![]));
        assert_eq!(pad_limbs(0, vec![1]), None);
    }

    #[test]
    fn test_bases() {
        assert_eq!(parse_digits("10"), Ok(vec![10]));
        assert_eq!(parse_digits("0x10"), Ok(vec![16]));
        assert_eq!(parse_digits("0b10"), Ok(vec![2]));
        assert_eq!(parse_digits("0o10"), Ok(vec![8]));
    }

    #[test]
    fn test_pad_limbs() {
        assert_eq!(pad_limbs(8, vec![255]), Some(vec![255]));
        assert_eq!(pad_limbs(8, vec![256]), None);
        assert_eq!(pad_limbs(64, vec![u64::MAX]), Some(vec![u64::MAX]));
        assert_eq!(pad_limbs(64, vec![1, 0, 0]), Some(vec![1]));
        assert_eq!(pad_limbs(64, vec![0, 1]), None);
        assert_eq!(pad_limbs(128, vec![5]), Some(vec![5, 0]));
    }

    #[test]
    fn test_suffix() {
        assert_eq!(parse_suffix("42U256"), Some((LiteralBaseType::Uint, 256, "42")));
        assert_eq!(parse_suffix("0x1f_U64"), Some((LiteralBaseType::Uint, 64, "0x1f_")));
        assert_eq!(parse_suffix("0b101B3"), Some((LiteralBaseType::Bits, 3, "0b101")));
        assert_eq!(parse_suffix("0x1_B8"), Some((LiteralBaseType::Bits, 8, "0x1_")));
        // A hex literal whose digits merely contain `B` is left alone.
        assert_eq!(parse_suffix("0x1B8"), None);
        // No suffix, a lowercase suffix, or a suffix without a bit count.
        assert_eq!(parse_suffix("42"), None);
        assert_eq!(parse_suffix("42u256"), None);
        assert_eq!(parse_suffix("0xB"), None);
    }

    #[test]
    #[allow(clippy::unreadable_literal)]
    fn test_overflow_during_parsing() {
        assert_eq!(parse_digits("258664426012969093929703085429980814127835149614277183275038967946009968870203535512256352201271898244626862047232"), Ok(vec![0, 15125697203588300800, 6414901478162127871, 13296924585243691235, 13584922160258634318, 121098312706494698]));
        assert_eq!(parse_digits("2135987035920910082395021706169552114602704522356652769947041607822219725780640550022962086936576"), Ok(vec![0, 0, 0, 0, 0, 1]));
    }
}