1#![doc = include_str!("../README.md")]
2#![warn(clippy::all, clippy::pedantic, clippy::cargo, clippy::nursery)]
3
4use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
5use std::fmt::{self, Write};
6
7#[doc = include_str!("../README.md")]
9#[proc_macro]
10pub fn uint(stream: TokenStream) -> TokenStream {
11 Transformer::new(None).transform_stream(stream)
12}
13
14#[proc_macro]
22#[doc(hidden)]
23pub fn uint_with_path(stream: TokenStream) -> TokenStream {
24 let mut stream_iter = stream.into_iter();
25 let Some(TokenTree::Group(group)) = stream_iter.next() else {
26 return error(
27 Span::call_site(),
28 "Expected a group containing the `ruint` crate path",
29 )
30 .into();
31 };
32 Transformer::new(Some(group.stream())).transform_stream(stream_iter.collect())
33}
34
/// Which `ruint` type a suffixed literal expands to.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless
/// enum is total (and `clippy::nursery`, enabled for this crate, asks for it).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum LiteralBaseType {
    /// Selected by a `U` suffix; rendered as `Uint`.
    Uint,
    /// Selected by a `B` suffix; rendered as `Bits`.
    Bits,
}

impl LiteralBaseType {
    /// Characters that can start a bit-width suffix (`U` or `B`).
    const PATTERN: &'static [char] = &['U', 'B'];
}

impl fmt::Display for LiteralBaseType {
    /// Writes the type name (`Uint` or `Bits`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Uint => f.write_str("Uint"),
            Self::Bits => f.write_str("Bits"),
        }
    }
}

impl std::str::FromStr for LiteralBaseType {
    type Err = ();

    /// Parses a single suffix marker: `"U"` or `"B"`; any other string fails with `()`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "U" => Ok(Self::Uint),
            "B" => Ok(Self::Bits),
            _ => Err(()),
        }
    }
}
65
/// Builds a `compile_error!{ "<message>" }` invocation whose `compile_error`
/// identifier and brace group carry `span`, wrapped in an invisible
/// (`Delimiter::None`) group so it can stand in for a single token tree.
fn error(span: Span, message: &str) -> TokenTree {
    // The macro argument: `{ "<message>" }`.
    let mut argument = Group::new(
        Delimiter::Brace,
        TokenStream::from(TokenTree::Literal(Literal::string(message))),
    );
    argument.set_span(span);

    // `compile_error` `!` `{ ... }`
    let mut invocation = TokenStream::new();
    invocation.extend([
        TokenTree::Ident(Ident::new("compile_error", span)),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(argument),
    ]);

    TokenTree::Group(Group::new(Delimiter::None, invocation))
}
85
/// Parses the digits of an integer literal into 64-bit limbs, least significant first.
///
/// A `0x` (hex), `0o` (octal) or `0b` (binary) prefix selects the base; otherwise
/// the value is decimal. `_` separators are skipped. Zero yields `vec![0]`.
///
/// # Errors
///
/// Returns a message when a character is neither a hex digit nor `_`, or when a
/// digit is not valid in the selected base (e.g. `2` in binary, `a` in decimal).
fn parse_digits(value: &str) -> Result<Vec<u64>, String> {
    // `strip_prefix` never slices through a multi-byte character (unlike
    // `split_at(2)`), so unusual literal text cannot make the macro panic.
    let (base, digits) = if let Some(rest) = value.strip_prefix("0x") {
        (16_u8, rest)
    } else if let Some(rest) = value.strip_prefix("0o") {
        (8, rest)
    } else if let Some(rest) = value.strip_prefix("0b") {
        (2, rest)
    } else {
        (10, value)
    };

    let mut limbs = vec![0_u64];
    for c in digits.chars() {
        if c == '_' {
            continue;
        }
        // `to_digit(16)` accepts exactly `0-9`, `a-f` and `A-F`.
        let digit = c
            .to_digit(16)
            .map(u64::from)
            .ok_or_else(|| format!("Invalid character '{c}'"))?;
        // A valid digit is strictly less than the base; `digit == base`
        // (e.g. `2` in binary) must be rejected too.
        if digit >= u64::from(base) {
            return Err(format!(
                "Invalid digit {c} in base {base} (did you forget the `0x` prefix?)"
            ));
        }

        // limbs = limbs * base + digit, propagating the carry into higher limbs.
        let mut carry = digit;
        // Truncation is intentional: the low 64 bits stay in the limb and the
        // high 64 bits become the carry for the next limb.
        #[allow(clippy::cast_possible_truncation)]
        for limb in &mut limbs {
            let product = u128::from(*limb) * u128::from(base) + u128::from(carry);
            *limb = product as u64;
            carry = (product >> 64) as u64;
        }
        if carry > 0 {
            limbs.push(carry);
        }
    }
    Ok(limbs)
}
133
/// Normalizes least-significant-first `limbs` to exactly the limb count of a
/// `bits`-bit integer.
///
/// Surplus high-order zero limbs are dropped and missing limbs are filled with
/// zeros. Returns `None` when the value does not fit in `bits` bits.
fn pad_limbs(bits: usize, mut limbs: Vec<u64>) -> Option<Vec<u64>> {
    let num_limbs = (bits + 63) / 64;
    // Largest value the most significant limb is allowed to hold.
    let top_mask: u64 = match (bits, bits % 64) {
        (0, _) => 0,
        (_, 0) => u64::MAX,
        (_, partial) => (1 << partial) - 1,
    };

    while limbs.len() > num_limbs && limbs.last().copied() == Some(0) {
        limbs.pop();
    }
    if limbs.len() < num_limbs {
        limbs.resize(num_limbs, 0);
    }

    let top = limbs.last().copied().unwrap_or(0);
    let fits = limbs.len() <= num_limbs && top <= top_mask;
    fits.then_some(limbs)
}
162
163fn parse_suffix(source: &str) -> Option<(LiteralBaseType, usize, &str)> {
164 let suffix_index = source.rfind(LiteralBaseType::PATTERN)?;
166 let (value, suffix) = source.split_at(suffix_index);
167 let (base_type, bits) = suffix.split_at(1);
168 let base_type = base_type.parse::<LiteralBaseType>().ok()?;
169 let bits = bits.parse::<usize>().ok()?;
170
171 if base_type == LiteralBaseType::Bits && value.starts_with("0x") && !value.ends_with('_') {
173 return None;
174 }
175 Some((base_type, bits, value))
176}
177
178struct Transformer {
179 ruint_crate: TokenStream,
183}
184
185impl Transformer {
186 fn new(ruint_crate: Option<TokenStream>) -> Self {
187 Self {
188 ruint_crate: ruint_crate.unwrap_or_else(|| "::ruint".parse().unwrap()),
189 }
190 }
191
192 fn construct(&self, base_type: LiteralBaseType, bits: usize, limbs: &[u64]) -> TokenStream {
194 let mut limbs_str = String::new();
195 for limb in limbs {
196 write!(&mut limbs_str, "0x{limb:016x}_u64, ").unwrap();
197 }
198 let limbs_str = limbs_str.trim_end_matches(", ");
199 let limbs = (bits + 63) / 64;
200 let source = format!("::{base_type}::<{bits}, {limbs}>::from_limbs([{limbs_str}])");
201
202 let mut tokens = self.ruint_crate.clone();
203 tokens.extend(source.parse::<TokenStream>().unwrap());
204 tokens
205 }
206
207 fn transform_literal(&self, source: &str) -> Result<Option<TokenStream>, String> {
209 let Some((base_type, bits, value)) = parse_suffix(source) else {
211 return Ok(None);
212 };
213
214 let limbs = parse_digits(value)?;
217
218 let Some(limbs) = pad_limbs(bits, limbs) else {
220 let value = value.trim_end_matches('_');
221 return Err(format!("Value too large for {base_type}<{bits}>: {value}"));
222 };
223
224 Ok(Some(self.construct(base_type, bits, &limbs)))
225 }
226
227 fn transform_tree(&self, tree: TokenTree) -> TokenTree {
229 match tree {
230 TokenTree::Group(group) => {
231 let delimiter = group.delimiter();
232 let span = group.span();
233 let stream = self.transform_stream(group.stream());
234 let mut transformed = Group::new(delimiter, stream);
235 transformed.set_span(span);
236 TokenTree::Group(transformed)
237 }
238 TokenTree::Literal(a) => {
239 let span = a.span();
240 let source = a.to_string();
241 let mut tree = match self.transform_literal(&source) {
242 Ok(Some(stream)) => TokenTree::Group({
243 let mut group = Group::new(Delimiter::None, stream);
244 group.set_span(span);
245 group
246 }),
247 Ok(None) => TokenTree::Literal(a),
248 Err(message) => error(span, &message),
249 };
250 tree.set_span(span);
251 tree
252 }
253 tree => tree,
254 }
255 }
256
257 fn transform_stream(&self, stream: TokenStream) -> TokenStream {
259 stream
260 .into_iter()
261 .map(|tree| self.transform_tree(tree))
262 .collect()
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_size() {
        // Zero in every base collapses to a single zero limb.
        assert_eq!(parse_digits("0"), Ok(vec![0]));
        assert_eq!(parse_digits("00000"), Ok(vec![0]));
        assert_eq!(parse_digits("0x00"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000"), Ok(vec![0]));
        assert_eq!(parse_digits("0b0000000"), Ok(vec![0]));

        // A zero-bit integer has no limbs and can only represent zero.
        assert_eq!(pad_limbs(0, vec![0]), Some(vec![]));
        assert_eq!(pad_limbs(0, vec![0, 0]), Some(vec![]));
        assert_eq!(pad_limbs(0, vec![1]), None);
    }

    #[test]
    fn test_bases() {
        assert_eq!(parse_digits("10"), Ok(vec![10]));
        assert_eq!(parse_digits("0x10"), Ok(vec![16]));
        assert_eq!(parse_digits("0b10"), Ok(vec![2]));
        assert_eq!(parse_digits("0o10"), Ok(vec![8]));
    }

    #[test]
    fn test_suffix() {
        assert_eq!(
            parse_suffix("1U256"),
            Some((LiteralBaseType::Uint, 256, "1"))
        );
        assert_eq!(
            parse_suffix("0x1_B8"),
            Some((LiteralBaseType::Bits, 8, "0x1_"))
        );
        // Without `_`, a trailing `B` in a hex literal may be a digit, so it is ignored.
        assert_eq!(parse_suffix("0x1B8"), None);
        assert_eq!(parse_suffix("42"), None);
    }

    #[test]
    #[allow(clippy::unreadable_literal)]
    fn test_overflow_during_parsing() {
        assert_eq!(parse_digits("258664426012969093929703085429980814127835149614277183275038967946009968870203535512256352201271898244626862047232"), Ok(vec![0, 15125697203588300800, 6414901478162127871, 13296924585243691235, 13584922160258634318, 121098312706494698]));
        assert_eq!(parse_digits("2135987035920910082395021706169552114602704522356652769947041607822219725780640550022962086936576"), Ok(vec![0, 0, 0, 0, 0, 1]));
    }
}