alloy_sol_macro_expander/expand/
ty.rs

1//! [`Type`] expansion.
2
3use super::ExpCtxt;
4use ast::{Item, Parameters, Spanned, Type, TypeArray};
5use proc_macro2::{Ident, Literal, Span, TokenStream};
6use proc_macro_error2::{abort, emit_error};
7use quote::{quote_spanned, ToTokens};
8use std::{fmt, num::NonZeroU16};
9
/// Largest evaluated fixed-array length for which `can_derive_default` still
/// reports that `Default` can be derived.
const MAX_SUPPORTED_ARRAY_LEN: usize = 32;
/// Largest tuple arity for which `can_derive_default` and
/// `can_derive_builtin_traits` still report that the traits can be derived.
const MAX_SUPPORTED_TUPLE_LEN: usize = 12;
12
13impl ExpCtxt<'_> {
14    /// Expands a single [`Type`] recursively to its `alloy_sol_types::sol_data`
15    /// equivalent.
16    pub fn expand_type(&self, ty: &Type) -> TokenStream {
17        let mut tokens = TokenStream::new();
18        self.expand_type_to(ty, &mut tokens);
19        tokens
20    }
21
22    /// Expands a single [`Type`] recursively to its Rust type equivalent.
23    ///
24    /// This is the same as `<#expand_type(ty) as SolType>::RustType`, but generates
25    /// nicer code for documentation and IDE/LSP support when the type is not
26    /// ambiguous.
27    pub fn expand_rust_type(&self, ty: &Type) -> TokenStream {
28        let mut tokens = TokenStream::new();
29        self.expand_rust_type_to(ty, &mut tokens);
30        tokens
31    }
32
33    /// Expands a single [`Type`] recursively to its `alloy_sol_types::sol_data` equivalent into the
34    /// given buffer.
35    ///
36    /// See [`expand_type`](Self::expand_type) for more information.
37    pub fn expand_type_to(&self, ty: &Type, tokens: &mut TokenStream) {
38        let alloy_sol_types = &self.crates.sol_types;
39        let tts = match *ty {
40            Type::Address(span, _) => quote_spanned! {span=> #alloy_sol_types::sol_data::Address },
41            Type::Bool(span) => quote_spanned! {span=> #alloy_sol_types::sol_data::Bool },
42            Type::String(span) => quote_spanned! {span=> #alloy_sol_types::sol_data::String },
43            Type::Bytes(span) => quote_spanned! {span=> #alloy_sol_types::sol_data::Bytes },
44
45            Type::FixedBytes(span, size) => {
46                assert!(size.get() <= 32);
47                let size = Literal::u16_unsuffixed(size.get());
48                quote_spanned! {span=> #alloy_sol_types::sol_data::FixedBytes<#size> }
49            }
50            Type::Int(span, size) | Type::Uint(span, size) => {
51                let name = match ty {
52                    Type::Int(..) => "Int",
53                    Type::Uint(..) => "Uint",
54                    _ => unreachable!(),
55                };
56                let name = Ident::new(name, span);
57
58                let size = size.map_or(256, NonZeroU16::get);
59                assert!(size <= 256 && size % 8 == 0);
60                let size = Literal::u16_unsuffixed(size);
61
62                quote_spanned! {span=> #alloy_sol_types::sol_data::#name<#size> }
63            }
64
65            Type::Tuple(ref tuple) => {
66                return tuple.paren_token.surround(tokens, |tokens| {
67                    for pair in tuple.types.pairs() {
68                        let (ty, comma) = pair.into_tuple();
69                        self.expand_type_to(ty, tokens);
70                        comma.to_tokens(tokens);
71                    }
72                })
73            }
74            Type::Array(ref array) => {
75                let ty = self.expand_type(&array.ty);
76                let span = array.span();
77                if let Some(size) = self.eval_array_size(array) {
78                    quote_spanned! {span=> #alloy_sol_types::sol_data::FixedArray<#ty, #size> }
79                } else {
80                    quote_spanned! {span=> #alloy_sol_types::sol_data::Array<#ty> }
81                }
82            }
83            Type::Function(ref function) => quote_spanned! {function.span()=>
84                #alloy_sol_types::sol_data::Function
85            },
86            Type::Mapping(ref mapping) => quote_spanned! {mapping.span()=>
87                ::core::compile_error!("Mapping types are not supported here")
88            },
89
90            Type::Custom(ref custom) => {
91                if let Some(Item::Contract(c)) = self.try_item(custom) {
92                    quote_spanned! {c.span()=> #alloy_sol_types::sol_data::Address }
93                } else {
94                    let segments = custom.iter();
95                    quote_spanned! {custom.span()=> #(#segments)::* }
96                }
97            }
98        };
99        tokens.extend(tts);
100    }
101
102    // IMPORTANT: Keep in sync with `sol-types/src/types/data_type.rs`
103    /// Expands a single [`Type`] recursively to its Rust type equivalent into the given buffer.
104    ///
105    /// See [`expand_rust_type`](Self::expand_rust_type) for more information.
106    pub(crate) fn expand_rust_type_to(&self, ty: &Type, tokens: &mut TokenStream) {
107        let alloy_sol_types = &self.crates.sol_types;
108        let tts = match *ty {
109            Type::Address(span, _) => quote_spanned! {span=> #alloy_sol_types::private::Address },
110            Type::Bool(span) => return Ident::new("bool", span).to_tokens(tokens),
111            Type::String(span) => quote_spanned! {span=> #alloy_sol_types::private::String },
112            Type::Bytes(span) => quote_spanned! {span=> #alloy_sol_types::private::Bytes },
113
114            Type::FixedBytes(span, size) => {
115                assert!(size.get() <= 32);
116                let size = Literal::u16_unsuffixed(size.get());
117                quote_spanned! {span=> #alloy_sol_types::private::FixedBytes<#size> }
118            }
119            Type::Int(span, size) | Type::Uint(span, size) => {
120                let size = size.map_or(256, NonZeroU16::get);
121                let primitive = matches!(size, 8 | 16 | 32 | 64 | 128);
122                if primitive {
123                    let prefix = match ty {
124                        Type::Int(..) => "i",
125                        Type::Uint(..) => "u",
126                        _ => unreachable!(),
127                    };
128                    return Ident::new(&format!("{prefix}{size}"), span).to_tokens(tokens);
129                }
130                let prefix = match ty {
131                    Type::Int(..) => "I",
132                    Type::Uint(..) => "U",
133                    _ => unreachable!(),
134                };
135                let name = Ident::new(&format!("{prefix}{size}"), span);
136                quote_spanned! {span=> #alloy_sol_types::private::primitives::aliases::#name }
137            }
138
139            Type::Tuple(ref tuple) => {
140                return tuple.paren_token.surround(tokens, |tokens| {
141                    for pair in tuple.types.pairs() {
142                        let (ty, comma) = pair.into_tuple();
143                        self.expand_rust_type_to(ty, tokens);
144                        comma.to_tokens(tokens);
145                    }
146                })
147            }
148            Type::Array(ref array) => {
149                let ty = self.expand_rust_type(&array.ty);
150                let span = array.span();
151                if let Some(size) = self.eval_array_size(array) {
152                    quote_spanned! {span=> [#ty; #size] }
153                } else {
154                    quote_spanned! {span=> #alloy_sol_types::private::Vec<#ty> }
155                }
156            }
157            Type::Function(ref function) => quote_spanned! {function.span()=>
158                #alloy_sol_types::private::Function
159            },
160            Type::Mapping(ref mapping) => quote_spanned! {mapping.span()=>
161                ::core::compile_error!("Mapping types are not supported here")
162            },
163
164            // Exhaustive fallback to `SolType::RustType`
165            Type::Custom(_) => {
166                let span = ty.span();
167                let ty = self.expand_type(ty);
168                quote_spanned! {span=> <#ty as #alloy_sol_types::SolType>::RustType }
169            }
170        };
171        tokens.extend(tts);
172    }
173
174    /// Calculates the base ABI-encoded size of the given parameters in bytes.
175    ///
176    /// See [`type_base_data_size`] for more information.
177    pub(crate) fn params_base_data_size<P>(&self, params: &Parameters<P>) -> usize {
178        params.iter().map(|param| self.type_base_data_size(&param.ty)).sum()
179    }
180
181    /// Recursively calculates the base ABI-encoded size of the given parameter
182    /// in bytes.
183    ///
184    /// That is, the minimum number of bytes required to encode `self` without
185    /// any dynamic data.
186    pub(crate) fn type_base_data_size(&self, ty: &Type) -> usize {
187        match ty {
188            // static types: 1 word
189            Type::Address(..)
190            | Type::Bool(_)
191            | Type::Int(..)
192            | Type::Uint(..)
193            | Type::FixedBytes(..)
194            | Type::Function(_) => 32,
195
196            // dynamic types: 1 offset word, 1 length word
197            Type::String(_) | Type::Bytes(_) | Type::Array(TypeArray { size: None, .. }) => 64,
198
199            // fixed array: size * encoded size
200            Type::Array(a @ TypeArray { ty: inner, size: Some(_), .. }) => {
201                let Some(size) = self.eval_array_size(a) else { return 0 };
202                self.type_base_data_size(inner).checked_mul(size).unwrap_or(0)
203            }
204
205            // tuple: sum of encoded sizes
206            Type::Tuple(tuple) => tuple.types.iter().map(|ty| self.type_base_data_size(ty)).sum(),
207
208            Type::Custom(name) => match self.try_item(name) {
209                Some(Item::Contract(_)) | Some(Item::Enum(_)) => 32,
210                Some(Item::Error(error)) => {
211                    error.parameters.types().map(|ty| self.type_base_data_size(ty)).sum()
212                }
213                Some(Item::Event(event)) => {
214                    event.parameters.iter().map(|p| self.type_base_data_size(&p.ty)).sum()
215                }
216                Some(Item::Struct(strukt)) => {
217                    strukt.fields.types().map(|ty| self.type_base_data_size(ty)).sum()
218                }
219                Some(Item::Udt(udt)) => self.type_base_data_size(&udt.ty),
220                Some(item) => abort!(item.span(), "Invalid type in struct field: {:?}", item),
221                None => 0,
222            },
223
224            // not applicable
225            Type::Mapping(_) => 0,
226        }
227    }
228
229    /// Returns whether the given type can derive the [`Default`] trait.
230    pub(crate) fn can_derive_default(&self, ty: &Type) -> bool {
231        match ty {
232            Type::Array(a) => {
233                self.eval_array_size(a).map_or(true, |sz| sz <= MAX_SUPPORTED_ARRAY_LEN)
234                    && self.can_derive_default(&a.ty)
235            }
236            Type::Tuple(tuple) => {
237                if tuple.types.len() > MAX_SUPPORTED_TUPLE_LEN {
238                    false
239                } else {
240                    tuple.types.iter().all(|ty| self.can_derive_default(ty))
241                }
242            }
243
244            Type::Custom(name) => match self.try_item(name) {
245                Some(Item::Contract(_)) => true,
246                Some(Item::Enum(_)) => false,
247                Some(Item::Error(error)) => {
248                    error.parameters.types().all(|ty| self.can_derive_default(ty))
249                }
250                Some(Item::Event(event)) => {
251                    event.parameters.iter().all(|p| self.can_derive_default(&p.ty))
252                }
253                Some(Item::Struct(strukt)) => {
254                    strukt.fields.types().all(|ty| self.can_derive_default(ty))
255                }
256                Some(Item::Udt(udt)) => self.can_derive_default(&udt.ty),
257                Some(item) => abort!(item.span(), "Invalid type in struct field: {:?}", item),
258                _ => false,
259            },
260
261            _ => true,
262        }
263    }
264
265    /// Returns whether the given type can derive the builtin traits listed in
266    /// `ExprCtxt::derives`, minus `Default`.
267    pub(crate) fn can_derive_builtin_traits(&self, ty: &Type) -> bool {
268        match ty {
269            Type::Array(a) => self.can_derive_builtin_traits(&a.ty),
270            Type::Tuple(tuple) => {
271                if tuple.types.len() > MAX_SUPPORTED_TUPLE_LEN {
272                    false
273                } else {
274                    tuple.types.iter().all(|ty| self.can_derive_builtin_traits(ty))
275                }
276            }
277
278            Type::Custom(name) => match self.try_item(name) {
279                Some(Item::Contract(_)) | Some(Item::Enum(_)) => true,
280                Some(Item::Error(error)) => {
281                    error.parameters.types().all(|ty| self.can_derive_builtin_traits(ty))
282                }
283                Some(Item::Event(event)) => {
284                    event.parameters.iter().all(|p| self.can_derive_builtin_traits(&p.ty))
285                }
286                Some(Item::Struct(strukt)) => {
287                    strukt.fields.types().all(|ty| self.can_derive_builtin_traits(ty))
288                }
289                Some(Item::Udt(udt)) => self.can_derive_builtin_traits(&udt.ty),
290                Some(item) => abort!(item.span(), "Invalid type in struct field: {:?}", item),
291                _ => false,
292            },
293
294            _ => true,
295        }
296    }
297
298    /// Evaluates the size of the given array type.
299    pub fn eval_array_size(&self, array: &TypeArray) -> Option<ArraySize> {
300        let size = array.size.as_deref()?;
301        ArraySizeEvaluator::new(self).eval(size)
302    }
303}
304
/// Integer type used for evaluated array sizes.
type ArraySize = usize;
306
/// Evaluates array size expressions to an [`ArraySize`].
struct ArraySizeEvaluator<'a> {
    /// Expansion context, used to resolve identifiers to items.
    cx: &'a ExpCtxt<'a>,
    /// Current nesting depth of `try_eval` calls, used to enforce the recursion
    /// limit.
    depth: usize,
}
311
312impl<'a> ArraySizeEvaluator<'a> {
313    fn new(cx: &'a ExpCtxt<'a>) -> Self {
314        Self { cx, depth: 0 }
315    }
316
317    fn eval(&mut self, expr: &ast::Expr) -> Option<ArraySize> {
318        match self.try_eval(expr) {
319            Ok(value) => Some(value),
320            Err(err) => {
321                emit_error!(
322                    expr.span(), "evaluation of constant value failed";
323                    note = err.span() => err.kind.msg()
324                );
325                None
326            }
327        }
328    }
329
330    fn try_eval(&mut self, expr: &ast::Expr) -> Result<ArraySize, EvalError> {
331        self.depth += 1;
332        if self.depth > 32 {
333            return Err(EvalErrorKind::RecursionLimitReached.spanned(expr.span()));
334        }
335        let mut r = self.try_eval_expr(expr);
336        if let Err(e) = &mut r {
337            if e.span.is_none() {
338                e.span = Some(expr.span());
339            }
340        }
341        self.depth -= 1;
342        r
343    }
344
345    fn try_eval_expr(&mut self, expr: &ast::Expr) -> Result<ArraySize, EvalError> {
346        let expr = expr.peel_parens();
347        match expr {
348            ast::Expr::Lit(ast::Lit::Number(ast::LitNumber::Int(n))) => {
349                n.base10_digits().parse::<ArraySize>().map_err(|_| EE::ParseInt.into())
350            }
351            ast::Expr::Binary(bin) => {
352                let lhs = self.try_eval(&bin.left)?;
353                let rhs = self.try_eval(&bin.right)?;
354                self.eval_binop(bin.op, lhs, rhs)
355            }
356            ast::Expr::Ident(ident) => {
357                let name = ast::sol_path![ident.clone()];
358                let Some(item) = self.cx.try_item(&name) else {
359                    eprintln!("{}", std::backtrace::Backtrace::force_capture());
360                    eprintln!("{:#?}", self.cx.all_items);
361                    return Err(EE::CouldNotResolve.into());
362                };
363                let ast::Item::Variable(var) = item else {
364                    return Err(EE::NonConstantVar.into());
365                };
366                if !var.attributes.has_constant() {
367                    return Err(EE::NonConstantVar.into());
368                }
369                let Some((_, expr)) = var.initializer.as_ref() else {
370                    return Err(EE::NonConstantVar.into());
371                };
372                self.try_eval(expr)
373            }
374            ast::Expr::LitDenominated(ast::LitDenominated {
375                number: ast::LitNumber::Int(n),
376                denom,
377            }) => {
378                let n = n.base10_digits().parse::<ArraySize>().map_err(|_| EE::ParseInt)?;
379                let Ok(denom) = denom.value().try_into() else {
380                    return Err(EE::IntTooBig.into());
381                };
382                n.checked_mul(denom).ok_or_else(|| EE::ArithmeticOverflow.into())
383            }
384            ast::Expr::Unary(unary) => {
385                let value = self.try_eval(&unary.expr)?;
386                self.eval_unop(unary.op, value)
387            }
388            _ => Err(EE::UnsupportedExpr.into()),
389        }
390    }
391
392    fn eval_binop(
393        &mut self,
394        bin: ast::BinOp,
395        lhs: ArraySize,
396        rhs: ArraySize,
397    ) -> Result<ArraySize, EvalError> {
398        match bin {
399            ast::BinOp::Shr(..) => rhs
400                .try_into()
401                .ok()
402                .and_then(|rhs| lhs.checked_shr(rhs))
403                .ok_or_else(|| EE::ArithmeticOverflow.into()),
404            ast::BinOp::Shl(..) => rhs
405                .try_into()
406                .ok()
407                .and_then(|rhs| lhs.checked_shl(rhs))
408                .ok_or_else(|| EE::ArithmeticOverflow.into()),
409            ast::BinOp::BitAnd(..) => Ok(lhs & rhs),
410            ast::BinOp::BitOr(..) => Ok(lhs | rhs),
411            ast::BinOp::BitXor(..) => Ok(lhs ^ rhs),
412            ast::BinOp::Add(..) => {
413                lhs.checked_add(rhs).ok_or_else(|| EE::ArithmeticOverflow.into())
414            }
415            ast::BinOp::Sub(..) => {
416                lhs.checked_sub(rhs).ok_or_else(|| EE::ArithmeticOverflow.into())
417            }
418            ast::BinOp::Pow(..) => rhs
419                .try_into()
420                .ok()
421                .and_then(|rhs| lhs.checked_pow(rhs))
422                .ok_or_else(|| EE::ArithmeticOverflow.into()),
423            ast::BinOp::Mul(..) => {
424                lhs.checked_mul(rhs).ok_or_else(|| EE::ArithmeticOverflow.into())
425            }
426            ast::BinOp::Div(..) => lhs.checked_div(rhs).ok_or_else(|| EE::DivisionByZero.into()),
427            ast::BinOp::Rem(..) => lhs.checked_div(rhs).ok_or_else(|| EE::DivisionByZero.into()),
428            _ => Err(EE::UnsupportedExpr.into()),
429        }
430    }
431
432    fn eval_unop(&mut self, unop: ast::UnOp, value: ArraySize) -> Result<ArraySize, EvalError> {
433        match unop {
434            ast::UnOp::Neg(..) => value.checked_neg().ok_or_else(|| EE::ArithmeticOverflow.into()),
435            ast::UnOp::BitNot(..) | ast::UnOp::Not(..) => Ok(!value),
436            _ => Err(EE::UnsupportedUnaryOp.into()),
437        }
438    }
439}
440
/// Error produced while evaluating an array size expression.
struct EvalError {
    /// What went wrong.
    kind: EvalErrorKind,
    /// Location the error is attributed to, if one has been recorded.
    span: Option<Span>,
}
445
446impl From<EvalErrorKind> for EvalError {
447    fn from(kind: EvalErrorKind) -> Self {
448        Self { kind, span: None }
449    }
450}
451
452impl EvalError {
453    fn span(&self) -> Span {
454        self.span.unwrap_or_else(Span::call_site)
455    }
456}
457
/// The reasons why evaluating an array size expression can fail.
///
/// The user-facing text for each variant is given by `EvalErrorKind::msg`.
enum EvalErrorKind {
    /// Evaluation nested deeper than the evaluator allows.
    RecursionLimitReached,
    /// A checked arithmetic operation did not produce a value.
    ArithmeticOverflow,
    /// An integer literal could not be parsed as an `ArraySize`.
    ParseInt,
    /// A denomination value could not be converted to an `ArraySize`.
    IntTooBig,
    /// The right-hand operand of a division was zero.
    DivisionByZero,
    /// The unary operator is not supported by the evaluator.
    UnsupportedUnaryOp,
    /// The expression or binary operator is not supported by the evaluator.
    UnsupportedExpr,
    /// An identifier did not resolve to any item.
    CouldNotResolve,
    /// An identifier resolved to something other than a constant variable with
    /// an initializer.
    NonConstantVar,
}
469use EvalErrorKind as EE;
470
471impl EvalErrorKind {
472    fn spanned(self, span: Span) -> EvalError {
473        EvalError { kind: self, span: Some(span) }
474    }
475
476    fn msg(&self) -> &'static str {
477        match self {
478            Self::RecursionLimitReached => "recursion limit reached",
479            Self::ArithmeticOverflow => "arithmetic overflow",
480            Self::ParseInt => "failed to parse integer",
481            Self::IntTooBig => "integer value is too big",
482            Self::DivisionByZero => "division by zero",
483            Self::UnsupportedUnaryOp => "unsupported unary operation",
484            Self::UnsupportedExpr => "unsupported expression",
485            Self::CouldNotResolve => "could not resolve identifier",
486            Self::NonConstantVar => "only constant variables are allowed",
487        }
488    }
489}
490
/// Implements [`fmt::Display`] which formats a [`Type`] to its canonical
/// representation. This is then used in function, error, and event selector
/// generation.
///
/// Custom types are resolved through the expansion context before printing,
/// and array sizes are printed as their evaluated value.
pub(crate) struct TypePrinter<'ast> {
    /// Expansion context used to resolve custom types and evaluate array sizes.
    cx: &'ast ExpCtxt<'ast>,
    /// The type to print.
    ty: &'ast Type,
}
498
499impl<'ast> TypePrinter<'ast> {
500    pub(crate) fn new(cx: &'ast ExpCtxt<'ast>, ty: &'ast Type) -> Self {
501        Self { cx, ty }
502    }
503}
504
505impl fmt::Display for TypePrinter<'_> {
506    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507        match self.ty {
508            Type::Int(_, None) => f.write_str("int256"),
509            Type::Uint(_, None) => f.write_str("uint256"),
510
511            Type::Array(array) => {
512                Self::new(self.cx, &array.ty).fmt(f)?;
513                f.write_str("[")?;
514                if let Some(size) = self.cx.eval_array_size(array) {
515                    size.fmt(f)?;
516                }
517                f.write_str("]")
518            }
519            Type::Tuple(tuple) => {
520                f.write_str("(")?;
521                for (i, ty) in tuple.types.iter().enumerate() {
522                    if i > 0 {
523                        f.write_str(",")?;
524                    }
525                    Self::new(self.cx, ty).fmt(f)?;
526                }
527                f.write_str(")")
528            }
529
530            Type::Custom(name) => Self::new(self.cx, self.cx.custom_type(name)).fmt(f),
531
532            ty => ty.fmt(f),
533        }
534    }
535}