1#![doc = include_str!("../README.md")]
2#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
3#![allow(clippy::manual_div_ceil)]
4
5use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
6use std::fmt::{self, Write};
7
8#[doc = include_str!("../README.md")]
10#[proc_macro]
11pub fn uint(stream: TokenStream) -> TokenStream {
12 Transformer::new(None).transform_stream(stream)
13}
14
15#[proc_macro]
23#[doc(hidden)]
24pub fn uint_with_path(stream: TokenStream) -> TokenStream {
25 let mut stream_iter = stream.into_iter();
26 let Some(TokenTree::Group(group)) = stream_iter.next() else {
27 return error(
28 Span::call_site(),
29 "Expected a group containing the `ruint` crate path",
30 )
31 .into();
32 };
33 Transformer::new(Some(group.stream())).transform_stream(stream_iter.collect())
34}
35
/// The `ruint` type a suffixed literal expands to, selected by the suffix
/// letter (`U` or `B`).
///
/// `Eq` is derived alongside `PartialEq` because equality is total for this
/// fieldless enum (and `clippy::derive_partial_eq_without_eq` asks for it).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum LiteralBaseType {
    /// Selected by the `U` suffix (e.g. `42U256`); expands to a `Uint` value.
    Uint,
    /// Selected by the `B` suffix (e.g. `0x2a_B64`); expands to a `Bits` value.
    Bits,
}

impl LiteralBaseType {
    /// Every letter that can start a type suffix, one per variant.
    const PATTERN: &'static [char] = &['U', 'B'];
}
45
46impl fmt::Display for LiteralBaseType {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 match self {
49 Self::Uint => f.write_str("Uint"),
50 Self::Bits => f.write_str("Bits"),
51 }
52 }
53}
54
55impl std::str::FromStr for LiteralBaseType {
56 type Err = ();
57
58 fn from_str(s: &str) -> Result<Self, Self::Err> {
59 match s {
60 "U" => Ok(Self::Uint),
61 "B" => Ok(Self::Bits),
62 _ => Err(()),
63 }
64 }
65}
66
/// Builds a `compile_error! { "<message>" }` invocation attributed to `span`.
///
/// The invocation is wrapped in an invisible (`Delimiter::None`) group so the
/// result is a single token tree that can stand in for the offending token.
fn error(span: Span, message: &str) -> TokenTree {
    // The braced argument carrying the message text.
    let mut argument = Group::new(
        Delimiter::Brace,
        TokenTree::Literal(Literal::string(message)).into(),
    );
    argument.set_span(span);

    let invocation = TokenStream::from_iter([
        TokenTree::Ident(Ident::new("compile_error", span)),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(argument),
    ]);

    TokenTree::Group(Group::new(Delimiter::None, invocation))
}
86
/// Parses the numeric part of a literal (its type suffix already removed)
/// into little-endian 64-bit limbs.
///
/// A leading `0x`, `0o` or `0b` selects base 16, 8 or 2; anything else is
/// read as base 10. Underscores are skipped as digit separators. The result
/// always contains at least one limb and may end in zero limbs; `pad_limbs`
/// normalises the length afterwards.
///
/// # Errors
///
/// Returns a human-readable message when the text contains a character that
/// is neither a hexadecimal digit nor `_`, or a digit that is not valid in
/// the selected base (e.g. `2` in binary or `a` in decimal).
fn parse_digits(value: &str) -> Result<Vec<u64>, String> {
    // `str::get` returns `None` instead of panicking when byte 2 falls inside
    // a multi-byte character, so odd literals fall through to base 10 and are
    // then rejected by the digit check below.
    let (base, digits) = match (value.get(..2), value.get(2..)) {
        (Some("0x"), Some(rest)) => (16_u8, rest),
        (Some("0o"), Some(rest)) => (8, rest),
        (Some("0b"), Some(rest)) => (2, rest),
        _ => (10, value),
    };

    let mut limbs = vec![0_u64];
    for c in digits.chars() {
        let digit = match c {
            '0'..='9' => c as u64 - '0' as u64,
            'a'..='f' => c as u64 - 'a' as u64 + 10,
            'A'..='F' => c as u64 - 'A' as u64 + 10,
            '_' => continue,
            _ => return Err(format!("Invalid character '{c}'")),
        };
        // Valid digits are strictly smaller than the base.
        if digit >= u64::from(base) {
            return Err(format!(
                "Invalid digit {c} in base {base} (did you forget the `0x` prefix?)"
            ));
        }

        // limbs = limbs * base + digit, carrying from the least significant limb up.
        let mut carry = digit;
        #[allow(clippy::cast_possible_truncation)] // deliberate split into low/high halves
        for limb in &mut limbs {
            let product = u128::from(*limb) * u128::from(base) + u128::from(carry);
            *limb = product as u64;
            carry = (product >> 64) as u64;
        }
        if carry > 0 {
            limbs.push(carry);
        }
    }
    Ok(limbs)
}
134
/// Resizes little-endian `limbs` to exactly `ceil(bits / 64)` limbs.
///
/// Surplus most-significant zero limbs are dropped and missing ones are
/// zero-filled. Returns `None` when the value does not fit in `bits` bits:
/// either a non-zero limb lies beyond the limb count, or the top limb has
/// bits set above the width.
fn pad_limbs(bits: usize, mut limbs: Vec<u64>) -> Option<Vec<u64>> {
    let num_limbs = (bits + 63) / 64;

    // Bits that may legitimately be set in the most significant limb.
    let mask: u64 = match (bits, bits % 64) {
        (0, _) => 0,
        (_, 0) => u64::MAX,
        (_, top_bits) => (1 << top_bits) - 1,
    };

    // Drop zero limbs from the most significant end while there are too many.
    while limbs.len() > num_limbs && limbs.last() == Some(&0) {
        limbs.pop();
    }
    // Zero-extend up to the required count.
    if limbs.len() < num_limbs {
        limbs.resize(num_limbs, 0);
    }

    let top = limbs.last().copied().unwrap_or(0);
    (limbs.len() == num_limbs && top <= mask).then_some(limbs)
}
163
164fn parse_suffix(source: &str) -> Option<(LiteralBaseType, usize, &str)> {
165 let suffix_index = source.rfind(LiteralBaseType::PATTERN)?;
167 let (value, suffix) = source.split_at(suffix_index);
168 let (base_type, bits) = suffix.split_at(1);
169 let base_type = base_type.parse::<LiteralBaseType>().ok()?;
170 let bits = bits.parse::<usize>().ok()?;
171
172 if base_type == LiteralBaseType::Bits && value.starts_with("0x") && !value.ends_with('_') {
174 return None;
175 }
176 Some((base_type, bits, value))
177}
178
179struct Transformer {
180 ruint_crate: TokenStream,
184}
185
186impl Transformer {
187 fn new(ruint_crate: Option<TokenStream>) -> Self {
188 Self {
189 ruint_crate: ruint_crate.unwrap_or_else(|| "::ruint".parse().unwrap()),
190 }
191 }
192
193 fn construct(&self, base_type: LiteralBaseType, bits: usize, limbs: &[u64]) -> TokenStream {
195 let mut limbs_str = String::new();
196 for limb in limbs {
197 write!(&mut limbs_str, "0x{limb:016x}_u64, ").unwrap();
198 }
199 let limbs_str = limbs_str.trim_end_matches(", ");
200 let limbs = (bits + 63) / 64;
201 let source = format!("::{base_type}::<{bits}, {limbs}>::from_limbs([{limbs_str}])");
202
203 let mut tokens = self.ruint_crate.clone();
204 tokens.extend(source.parse::<TokenStream>().unwrap());
205 tokens
206 }
207
208 fn transform_literal(&self, source: &str) -> Result<Option<TokenStream>, String> {
210 let Some((base_type, bits, value)) = parse_suffix(source) else {
212 return Ok(None);
213 };
214
215 let limbs = parse_digits(value)?;
218
219 let Some(limbs) = pad_limbs(bits, limbs) else {
221 let value = value.trim_end_matches('_');
222 return Err(format!("Value too large for {base_type}<{bits}>: {value}"));
223 };
224
225 Ok(Some(self.construct(base_type, bits, &limbs)))
226 }
227
228 fn transform_tree(&self, tree: TokenTree) -> TokenTree {
230 match tree {
231 TokenTree::Group(group) => {
232 let delimiter = group.delimiter();
233 let span = group.span();
234 let stream = self.transform_stream(group.stream());
235 let mut transformed = Group::new(delimiter, stream);
236 transformed.set_span(span);
237 TokenTree::Group(transformed)
238 }
239 TokenTree::Literal(a) => {
240 let span = a.span();
241 let source = a.to_string();
242 let mut tree = match self.transform_literal(&source) {
243 Ok(Some(stream)) => TokenTree::Group({
244 let mut group = Group::new(Delimiter::None, stream);
245 group.set_span(span);
246 group
247 }),
248 Ok(None) => TokenTree::Literal(a),
249 Err(message) => error(span, &message),
250 };
251 tree.set_span(span);
252 tree
253 }
254 tree => tree,
255 }
256 }
257
258 fn transform_stream(&self, stream: TokenStream) -> TokenStream {
260 stream
261 .into_iter()
262 .map(|tree| self.transform_tree(tree))
263 .collect()
264 }
265}
266
#[cfg(test)]
mod tests {
    //! Unit tests for the literal parsing helpers.
    use super::*;

    #[test]
    fn test_zero_size() {
        assert_eq!(parse_digits("0"), Ok(vec![0]));
        assert_eq!(parse_digits("00000"), Ok(vec![0]));
        assert_eq!(parse_digits("0x00"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000000"), Ok(vec![0]));
    }

    #[test]
    fn test_bases() {
        assert_eq!(parse_digits("10"), Ok(vec![10]));
        assert_eq!(parse_digits("0x10"), Ok(vec![16]));
        assert_eq!(parse_digits("0b10"), Ok(vec![2]));
        assert_eq!(parse_digits("0o10"), Ok(vec![8]));
    }

    #[test]
    #[allow(clippy::unreadable_literal)]
    fn test_overflow_during_parsing() {
        assert_eq!(parse_digits("258664426012969093929703085429980814127835149614277183275038967946009968870203535512256352201271898244626862047232"), Ok(vec![0, 15125697203588300800, 6414901478162127871, 13296924585243691235, 13584922160258634318, 121098312706494698]));
        assert_eq!(parse_digits("2135987035920910082395021706169552114602704522356652769947041607822219725780640550022962086936576"), Ok(vec![0, 0, 0, 0, 0, 1]));
    }

    #[test]
    fn test_pad_limbs() {
        assert_eq!(pad_limbs(0, vec![0]), Some(vec![]));
        assert_eq!(pad_limbs(0, vec![1]), None);
        assert_eq!(pad_limbs(8, vec![255]), Some(vec![255]));
        assert_eq!(pad_limbs(8, vec![256]), None);
        assert_eq!(pad_limbs(65, vec![1]), Some(vec![1, 0]));
        assert_eq!(pad_limbs(65, vec![0, 2]), None);
        assert_eq!(pad_limbs(64, vec![5, 0, 0]), Some(vec![5]));
    }

    #[test]
    fn test_parse_suffix() {
        assert_eq!(parse_suffix("10U256"), Some((LiteralBaseType::Uint, 256, "10")));
        // Hex `Bits` literals need an underscore before the suffix.
        assert_eq!(parse_suffix("0x10B64"), None);
        assert_eq!(parse_suffix("0x10_B64"), Some((LiteralBaseType::Bits, 64, "0x10_")));
        assert_eq!(parse_suffix("10u8"), None);
        assert_eq!(parse_suffix("10U"), None);
    }
}
300}