ruint_macro/lib.rs

#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic, clippy::cargo, clippy::nursery)]

use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
use std::fmt::{self, Write};

// Repeat the crate doc.
#[doc = include_str!("../README.md")]
#[proc_macro]
pub fn uint(stream: TokenStream) -> TokenStream {
    Transformer::new(None).transform_stream(stream)
}

/// Same as [`uint`], but with the first token always being a
/// [group](proc_macro::Group) containing the `ruint` crate path.
///
/// This allows the macro to be used in crates that don't depend on `ruint`
/// directly, through a wrapper `macro_rules!` that passes `$crate` as the path.
///
/// This is an implementation detail and should not be used directly.
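///
/// A minimal sketch of such a wrapper (the macro name and the re-export path
/// of `uint_with_path` are illustrative assumptions, not taken from this
/// crate):
///
/// ```ignore
/// #[macro_export]
/// macro_rules! uint {
///     ($($tt:tt)*) => {
///         $crate::uint_with_path! { { $crate } $($tt)* }
///     };
/// }
/// ```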
#[proc_macro]
#[doc(hidden)]
pub fn uint_with_path(stream: TokenStream) -> TokenStream {
    let mut stream_iter = stream.into_iter();
    let Some(TokenTree::Group(group)) = stream_iter.next() else {
        return error(
            Span::call_site(),
            "Expected a group containing the `ruint` crate path",
        )
        .into();
    };
    Transformer::new(Some(group.stream())).transform_stream(stream_iter.collect())
}

#[derive(Copy, Clone, PartialEq, Debug)]
enum LiteralBaseType {
    Uint,
    Bits,
}

impl LiteralBaseType {
    const PATTERN: &'static [char] = &['U', 'B'];
}

impl fmt::Display for LiteralBaseType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Uint => f.write_str("Uint"),
            Self::Bits => f.write_str("Bits"),
        }
    }
}

impl std::str::FromStr for LiteralBaseType {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "U" => Ok(Self::Uint),
            "B" => Ok(Self::Bits),
            _ => Err(()),
        }
    }
}

/// Construct a compiler error message.
// FEATURE: (BLOCKED) Replace with Diagnostic API when stable.
// See <https://doc.rust-lang.org/stable/proc_macro/struct.Diagnostic.html>
fn error(span: Span, message: &str) -> TokenTree {
    // See: https://docs.rs/syn/1.0.70/src/syn/error.rs.html#243
    let tokens = TokenStream::from_iter(vec![
        TokenTree::Ident(Ident::new("compile_error", span)),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group({
            let mut group = Group::new(
                Delimiter::Brace,
                TokenStream::from_iter(vec![TokenTree::Literal(Literal::string(message))]),
            );
            group.set_span(span);
            group
        }),
    ]);
    TokenTree::Group(Group::new(Delimiter::None, tokens))
}

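/// Parses the digits of an integer literal into little-endian 64-bit limbs.
///
/// Accepts an optional `0x`, `0o`, or `0b` base prefix (base 10 otherwise) and
/// skips `_` separators. For example, `"0x1_0000_0000_0000_0000"` parses to
/// `[0, 1]`.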
fn parse_digits(value: &str) -> Result<Vec<u64>, String> {
    // Parse base
    let (base, digits) = if value.len() >= 2 {
        let (prefix, remainder) = value.split_at(2);
        match prefix {
            "0x" => (16_u8, remainder),
            "0o" => (8, remainder),
            "0b" => (2, remainder),
            _ => (10, value),
        }
    } else {
        (10, value)
    };

    // Parse digits in base
    let mut limbs = vec![0_u64];
    for c in digits.chars() {
        // Read next digit
        let digit = match c {
            '0'..='9' => c as u64 - '0' as u64,
            'a'..='f' => c as u64 - 'a' as u64 + 10,
            'A'..='F' => c as u64 - 'A' as u64 + 10,
            '_' => continue,
            _ => return Err(format!("Invalid character '{c}'")),
        };
        #[allow(clippy::cast_lossless)]
        if digit >= base as u64 {
            return Err(format!(
                "Invalid digit {c} in base {base} (did you forget the `0x` prefix?)"
            ));
        }

        // Multiply result by base and add digit
        let mut carry = digit;
        #[allow(clippy::cast_lossless)]
        #[allow(clippy::cast_possible_truncation)]
        for limb in &mut limbs {
            let product = (*limb as u128) * (base as u128) + (carry as u128);
            *limb = product as u64;
            carry = (product >> 64) as u64;
        }
        if carry > 0 {
            limbs.push(carry);
        }
    }
    Ok(limbs)
}

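/// Pads or trims `limbs` to the limb count required for a `bits`-wide value.
///
/// Zero high limbs are dropped, missing limbs are filled with zeros, and
/// `None` is returned if the value does not fit in `bits` bits.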
fn pad_limbs(bits: usize, mut limbs: Vec<u64>) -> Option<Vec<u64>> {
    // Get limb count and mask
    let num_limbs = (bits + 63) / 64;
    let mask = if bits == 0 {
        0
    } else {
        let bits = bits % 64;
        if bits == 0 {
            u64::MAX
        } else {
            (1 << bits) - 1
        }
    };

    // Remove trailing zeros, pad with zeros
    while limbs.len() > num_limbs && limbs.last() == Some(&0) {
        limbs.pop();
    }
    while limbs.len() < num_limbs {
        limbs.push(0);
    }

    // Validate length
    if limbs.len() > num_limbs || limbs.last().copied().unwrap_or(0) > mask {
        return None;
    }
    Some(limbs)
}

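/// Splits a literal into its value part and a trailing `U<bits>` or `B<bits>`
/// suffix, returning the base type, the bit size, and the value preceding the
/// suffix. Returns `None` for literals without a recognized suffix.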
fn parse_suffix(source: &str) -> Option<(LiteralBaseType, usize, &str)> {
    // Parse into value, bits, and base type.
    let suffix_index = source.rfind(LiteralBaseType::PATTERN)?;
    let (value, suffix) = source.split_at(suffix_index);
    let (base_type, bits) = suffix.split_at(1);
    let base_type = base_type.parse::<LiteralBaseType>().ok()?;
    let bits = bits.parse::<usize>().ok()?;

    // Ignore hexadecimal Bits literals without `_` before the suffix.
    if base_type == LiteralBaseType::Bits && value.starts_with("0x") && !value.ends_with('_') {
        return None;
    }
    Some((base_type, bits, value))
}

struct Transformer {
    /// The `ruint` crate path.
    /// Note that this stream's spans must be preserved in order for `$crate`
    /// to resolve correctly.
    ruint_crate: TokenStream,
}

impl Transformer {
    fn new(ruint_crate: Option<TokenStream>) -> Self {
        Self {
            ruint_crate: ruint_crate.unwrap_or_else(|| "::ruint".parse().unwrap()),
        }
    }

    /// Construct a `<{base_type}><{bits}>` literal from `limbs`.
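    ///
    /// For example, with the default `::ruint` path, `Uint` with `bits = 64`
    /// and limbs `[0x10]` produces
    /// `::ruint::Uint::<64, 1>::from_limbs([0x0000000000000010_u64])`.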
    fn construct(&self, base_type: LiteralBaseType, bits: usize, limbs: &[u64]) -> TokenStream {
        let mut limbs_str = String::new();
        for limb in limbs {
            write!(&mut limbs_str, "0x{limb:016x}_u64, ").unwrap();
        }
        let limbs_str = limbs_str.trim_end_matches(", ");
        let limbs = (bits + 63) / 64;
        let source = format!("::{base_type}::<{bits}, {limbs}>::from_limbs([{limbs_str}])");

        let mut tokens = self.ruint_crate.clone();
        tokens.extend(source.parse::<TokenStream>().unwrap());
        tokens
    }

    /// Transforms a [`Literal`] and returns the substitute [`TokenStream`].
    fn transform_literal(&self, source: &str) -> Result<Option<TokenStream>, String> {
        // Check if literal has a suffix we accept.
        let Some((base_type, bits, value)) = parse_suffix(source) else {
            return Ok(None);
        };

        // Parse `value` into limbs.
        // At this point we are confident the literal was for us, so we throw errors.
        let limbs = parse_digits(value)?;

        // Pad limbs to the correct length.
        let Some(limbs) = pad_limbs(bits, limbs) else {
            let value = value.trim_end_matches('_');
            return Err(format!("Value too large for {base_type}<{bits}>: {value}"));
        };

        Ok(Some(self.construct(base_type, bits, &limbs)))
    }

    /// Recurse down tree and transform all literals.
    fn transform_tree(&self, tree: TokenTree) -> TokenTree {
        match tree {
            TokenTree::Group(group) => {
                let delimiter = group.delimiter();
                let span = group.span();
                let stream = self.transform_stream(group.stream());
                let mut transformed = Group::new(delimiter, stream);
                transformed.set_span(span);
                TokenTree::Group(transformed)
            }
            TokenTree::Literal(a) => {
                let span = a.span();
                let source = a.to_string();
                let mut tree = match self.transform_literal(&source) {
                    Ok(Some(stream)) => TokenTree::Group({
                        let mut group = Group::new(Delimiter::None, stream);
                        group.set_span(span);
                        group
                    }),
                    Ok(None) => TokenTree::Literal(a),
                    Err(message) => error(span, &message),
                };
                tree.set_span(span);
                tree
            }
            tree => tree,
        }
    }

    /// Iterate over a [`TokenStream`] and transform all [`TokenTree`]s.
    fn transform_stream(&self, stream: TokenStream) -> TokenStream {
        stream
            .into_iter()
            .map(|tree| self.transform_tree(tree))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_size() {
        assert_eq!(parse_digits("0"), Ok(vec![0]));
        assert_eq!(parse_digits("00000"), Ok(vec![0]));
        assert_eq!(parse_digits("0x00"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000000"), Ok(vec![0]));
    }
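
    // Illustrative checks of `pad_limbs`: padding to the limb count for the
    // requested width and rejecting values that overflow it.
    #[test]
    fn test_pad_limbs() {
        assert_eq!(pad_limbs(64, vec![1]), Some(vec![1]));
        assert_eq!(pad_limbs(128, vec![1]), Some(vec![1, 0]));
        assert_eq!(pad_limbs(4, vec![16]), None);
    }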

    #[test]
    fn test_bases() {
        assert_eq!(parse_digits("10"), Ok(vec![10]));
        assert_eq!(parse_digits("0x10"), Ok(vec![16]));
        assert_eq!(parse_digits("0b10"), Ok(vec![2]));
        assert_eq!(parse_digits("0o10"), Ok(vec![8]));
    }
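
    // Illustrative checks of `parse_suffix`, including the rule that a
    // hexadecimal `Bits` literal needs a `_` before the suffix.
    #[test]
    fn test_parse_suffix() {
        assert_eq!(parse_suffix("10_U64"), Some((LiteralBaseType::Uint, 64, "10_")));
        assert_eq!(parse_suffix("0x10_B8"), Some((LiteralBaseType::Bits, 8, "0x10_")));
        assert_eq!(parse_suffix("0x10B8"), None);
        assert_eq!(parse_suffix("10"), None);
    }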

    #[test]
    #[allow(clippy::unreadable_literal)]
    fn test_overflow_during_parsing() {
        assert_eq!(parse_digits("258664426012969093929703085429980814127835149614277183275038967946009968870203535512256352201271898244626862047232"), Ok(vec![0, 15125697203588300800, 6414901478162127871, 13296924585243691235, 13584922160258634318, 121098312706494698]));
        assert_eq!(parse_digits("2135987035920910082395021706169552114602704522356652769947041607822219725780640550022962086936576"), Ok(vec![0, 0, 0, 0, 0, 1]));
298    }
299}