1use super::ExpCtxt;
4use ast::{Item, Parameters, Spanned, Type, TypeArray};
5use proc_macro2::{Ident, Literal, Span, TokenStream};
6use proc_macro_error2::{abort, emit_error};
7use quote::{quote_spanned, ToTokens};
8use std::{fmt, num::NonZeroU16};
9
/// Largest fixed array length for which the standard library implements
/// `Default` for `[T; N]`; `can_derive_default` rejects longer fixed arrays.
const MAX_SUPPORTED_ARRAY_LEN: usize = 32;
/// Largest tuple arity for which the standard library implements its common
/// traits (e.g. `Default`, `Debug`, `PartialEq`); the `can_derive_*` checks
/// reject longer tuples.
const MAX_SUPPORTED_TUPLE_LEN: usize = 12;
12
13impl ExpCtxt<'_> {
14 pub fn expand_type(&self, ty: &Type) -> TokenStream {
17 let mut tokens = TokenStream::new();
18 self.expand_type_to(ty, &mut tokens);
19 tokens
20 }
21
22 pub fn expand_rust_type(&self, ty: &Type) -> TokenStream {
28 let mut tokens = TokenStream::new();
29 self.expand_rust_type_to(ty, &mut tokens);
30 tokens
31 }
32
33 pub fn expand_type_to(&self, ty: &Type, tokens: &mut TokenStream) {
38 let alloy_sol_types = &self.crates.sol_types;
39 let tts = match *ty {
40 Type::Address(span, _) => quote_spanned! {span=> #alloy_sol_types::sol_data::Address },
41 Type::Bool(span) => quote_spanned! {span=> #alloy_sol_types::sol_data::Bool },
42 Type::String(span) => quote_spanned! {span=> #alloy_sol_types::sol_data::String },
43 Type::Bytes(span) => quote_spanned! {span=> #alloy_sol_types::sol_data::Bytes },
44
45 Type::FixedBytes(span, size) => {
46 assert!(size.get() <= 32);
47 let size = Literal::u16_unsuffixed(size.get());
48 quote_spanned! {span=> #alloy_sol_types::sol_data::FixedBytes<#size> }
49 }
50 Type::Int(span, size) | Type::Uint(span, size) => {
51 let name = match ty {
52 Type::Int(..) => "Int",
53 Type::Uint(..) => "Uint",
54 _ => unreachable!(),
55 };
56 let name = Ident::new(name, span);
57
58 let size = size.map_or(256, NonZeroU16::get);
59 assert!(size <= 256 && size % 8 == 0);
60 let size = Literal::u16_unsuffixed(size);
61
62 quote_spanned! {span=> #alloy_sol_types::sol_data::#name<#size> }
63 }
64
65 Type::Tuple(ref tuple) => {
66 return tuple.paren_token.surround(tokens, |tokens| {
67 for pair in tuple.types.pairs() {
68 let (ty, comma) = pair.into_tuple();
69 self.expand_type_to(ty, tokens);
70 comma.to_tokens(tokens);
71 }
72 })
73 }
74 Type::Array(ref array) => {
75 let ty = self.expand_type(&array.ty);
76 let span = array.span();
77 if let Some(size) = self.eval_array_size(array) {
78 quote_spanned! {span=> #alloy_sol_types::sol_data::FixedArray<#ty, #size> }
79 } else {
80 quote_spanned! {span=> #alloy_sol_types::sol_data::Array<#ty> }
81 }
82 }
83 Type::Function(ref function) => quote_spanned! {function.span()=>
84 #alloy_sol_types::sol_data::Function
85 },
86 Type::Mapping(ref mapping) => quote_spanned! {mapping.span()=>
87 ::core::compile_error!("Mapping types are not supported here")
88 },
89
90 Type::Custom(ref custom) => {
91 if let Some(Item::Contract(c)) = self.try_item(custom) {
92 quote_spanned! {c.span()=> #alloy_sol_types::sol_data::Address }
93 } else {
94 let segments = custom.iter();
95 quote_spanned! {custom.span()=> #(#segments)::* }
96 }
97 }
98 };
99 tokens.extend(tts);
100 }
101
102 pub(crate) fn expand_rust_type_to(&self, ty: &Type, tokens: &mut TokenStream) {
107 let alloy_sol_types = &self.crates.sol_types;
108 let tts = match *ty {
109 Type::Address(span, _) => quote_spanned! {span=> #alloy_sol_types::private::Address },
110 Type::Bool(span) => return Ident::new("bool", span).to_tokens(tokens),
111 Type::String(span) => quote_spanned! {span=> #alloy_sol_types::private::String },
112 Type::Bytes(span) => quote_spanned! {span=> #alloy_sol_types::private::Bytes },
113
114 Type::FixedBytes(span, size) => {
115 assert!(size.get() <= 32);
116 let size = Literal::u16_unsuffixed(size.get());
117 quote_spanned! {span=> #alloy_sol_types::private::FixedBytes<#size> }
118 }
119 Type::Int(span, size) | Type::Uint(span, size) => {
120 let size = size.map_or(256, NonZeroU16::get);
121 let primitive = matches!(size, 8 | 16 | 32 | 64 | 128);
122 if primitive {
123 let prefix = match ty {
124 Type::Int(..) => "i",
125 Type::Uint(..) => "u",
126 _ => unreachable!(),
127 };
128 return Ident::new(&format!("{prefix}{size}"), span).to_tokens(tokens);
129 }
130 let prefix = match ty {
131 Type::Int(..) => "I",
132 Type::Uint(..) => "U",
133 _ => unreachable!(),
134 };
135 let name = Ident::new(&format!("{prefix}{size}"), span);
136 quote_spanned! {span=> #alloy_sol_types::private::primitives::aliases::#name }
137 }
138
139 Type::Tuple(ref tuple) => {
140 return tuple.paren_token.surround(tokens, |tokens| {
141 for pair in tuple.types.pairs() {
142 let (ty, comma) = pair.into_tuple();
143 self.expand_rust_type_to(ty, tokens);
144 comma.to_tokens(tokens);
145 }
146 })
147 }
148 Type::Array(ref array) => {
149 let ty = self.expand_rust_type(&array.ty);
150 let span = array.span();
151 if let Some(size) = self.eval_array_size(array) {
152 quote_spanned! {span=> [#ty; #size] }
153 } else {
154 quote_spanned! {span=> #alloy_sol_types::private::Vec<#ty> }
155 }
156 }
157 Type::Function(ref function) => quote_spanned! {function.span()=>
158 #alloy_sol_types::private::Function
159 },
160 Type::Mapping(ref mapping) => quote_spanned! {mapping.span()=>
161 ::core::compile_error!("Mapping types are not supported here")
162 },
163
164 Type::Custom(_) => {
166 let span = ty.span();
167 let ty = self.expand_type(ty);
168 quote_spanned! {span=> <#ty as #alloy_sol_types::SolType>::RustType }
169 }
170 };
171 tokens.extend(tts);
172 }
173
174 pub(crate) fn params_base_data_size<P>(&self, params: &Parameters<P>) -> usize {
178 params.iter().map(|param| self.type_base_data_size(¶m.ty)).sum()
179 }
180
181 pub(crate) fn type_base_data_size(&self, ty: &Type) -> usize {
187 match ty {
188 Type::Address(..)
190 | Type::Bool(_)
191 | Type::Int(..)
192 | Type::Uint(..)
193 | Type::FixedBytes(..)
194 | Type::Function(_) => 32,
195
196 Type::String(_) | Type::Bytes(_) | Type::Array(TypeArray { size: None, .. }) => 64,
198
199 Type::Array(a @ TypeArray { ty: inner, size: Some(_), .. }) => {
201 let Some(size) = self.eval_array_size(a) else { return 0 };
202 self.type_base_data_size(inner).checked_mul(size).unwrap_or(0)
203 }
204
205 Type::Tuple(tuple) => tuple.types.iter().map(|ty| self.type_base_data_size(ty)).sum(),
207
208 Type::Custom(name) => match self.try_item(name) {
209 Some(Item::Contract(_)) | Some(Item::Enum(_)) => 32,
210 Some(Item::Error(error)) => {
211 error.parameters.types().map(|ty| self.type_base_data_size(ty)).sum()
212 }
213 Some(Item::Event(event)) => {
214 event.parameters.iter().map(|p| self.type_base_data_size(&p.ty)).sum()
215 }
216 Some(Item::Struct(strukt)) => {
217 strukt.fields.types().map(|ty| self.type_base_data_size(ty)).sum()
218 }
219 Some(Item::Udt(udt)) => self.type_base_data_size(&udt.ty),
220 Some(item) => abort!(item.span(), "Invalid type in struct field: {:?}", item),
221 None => 0,
222 },
223
224 Type::Mapping(_) => 0,
226 }
227 }
228
229 pub(crate) fn can_derive_default(&self, ty: &Type) -> bool {
231 match ty {
232 Type::Array(a) => {
233 self.eval_array_size(a).map_or(true, |sz| sz <= MAX_SUPPORTED_ARRAY_LEN)
234 && self.can_derive_default(&a.ty)
235 }
236 Type::Tuple(tuple) => {
237 if tuple.types.len() > MAX_SUPPORTED_TUPLE_LEN {
238 false
239 } else {
240 tuple.types.iter().all(|ty| self.can_derive_default(ty))
241 }
242 }
243
244 Type::Custom(name) => match self.try_item(name) {
245 Some(Item::Contract(_)) => true,
246 Some(Item::Enum(_)) => false,
247 Some(Item::Error(error)) => {
248 error.parameters.types().all(|ty| self.can_derive_default(ty))
249 }
250 Some(Item::Event(event)) => {
251 event.parameters.iter().all(|p| self.can_derive_default(&p.ty))
252 }
253 Some(Item::Struct(strukt)) => {
254 strukt.fields.types().all(|ty| self.can_derive_default(ty))
255 }
256 Some(Item::Udt(udt)) => self.can_derive_default(&udt.ty),
257 Some(item) => abort!(item.span(), "Invalid type in struct field: {:?}", item),
258 _ => false,
259 },
260
261 _ => true,
262 }
263 }
264
265 pub(crate) fn can_derive_builtin_traits(&self, ty: &Type) -> bool {
268 match ty {
269 Type::Array(a) => self.can_derive_builtin_traits(&a.ty),
270 Type::Tuple(tuple) => {
271 if tuple.types.len() > MAX_SUPPORTED_TUPLE_LEN {
272 false
273 } else {
274 tuple.types.iter().all(|ty| self.can_derive_builtin_traits(ty))
275 }
276 }
277
278 Type::Custom(name) => match self.try_item(name) {
279 Some(Item::Contract(_)) | Some(Item::Enum(_)) => true,
280 Some(Item::Error(error)) => {
281 error.parameters.types().all(|ty| self.can_derive_builtin_traits(ty))
282 }
283 Some(Item::Event(event)) => {
284 event.parameters.iter().all(|p| self.can_derive_builtin_traits(&p.ty))
285 }
286 Some(Item::Struct(strukt)) => {
287 strukt.fields.types().all(|ty| self.can_derive_builtin_traits(ty))
288 }
289 Some(Item::Udt(udt)) => self.can_derive_builtin_traits(&udt.ty),
290 Some(item) => abort!(item.span(), "Invalid type in struct field: {:?}", item),
291 _ => false,
292 },
293
294 _ => true,
295 }
296 }
297
298 pub fn eval_array_size(&self, array: &TypeArray) -> Option<ArraySize> {
300 let size = array.size.as_deref()?;
301 ArraySizeEvaluator::new(self).eval(size)
302 }
303}
304
/// Integer type that evaluated array lengths are computed in.
type ArraySize = usize;
306
/// Evaluates constant array-length expressions (the `N` in `T[N]`), resolving
/// identifiers through the expansion context.
struct ArraySizeEvaluator<'a> {
    /// Context used to resolve identifiers to items.
    cx: &'a ExpCtxt<'a>,
    /// Current nesting depth of `try_eval` calls, used to bound recursion.
    depth: usize,
}
311
312impl<'a> ArraySizeEvaluator<'a> {
313 fn new(cx: &'a ExpCtxt<'a>) -> Self {
314 Self { cx, depth: 0 }
315 }
316
317 fn eval(&mut self, expr: &ast::Expr) -> Option<ArraySize> {
318 match self.try_eval(expr) {
319 Ok(value) => Some(value),
320 Err(err) => {
321 emit_error!(
322 expr.span(), "evaluation of constant value failed";
323 note = err.span() => err.kind.msg()
324 );
325 None
326 }
327 }
328 }
329
330 fn try_eval(&mut self, expr: &ast::Expr) -> Result<ArraySize, EvalError> {
331 self.depth += 1;
332 if self.depth > 32 {
333 return Err(EvalErrorKind::RecursionLimitReached.spanned(expr.span()));
334 }
335 let mut r = self.try_eval_expr(expr);
336 if let Err(e) = &mut r {
337 if e.span.is_none() {
338 e.span = Some(expr.span());
339 }
340 }
341 self.depth -= 1;
342 r
343 }
344
345 fn try_eval_expr(&mut self, expr: &ast::Expr) -> Result<ArraySize, EvalError> {
346 let expr = expr.peel_parens();
347 match expr {
348 ast::Expr::Lit(ast::Lit::Number(ast::LitNumber::Int(n))) => {
349 n.base10_digits().parse::<ArraySize>().map_err(|_| EE::ParseInt.into())
350 }
351 ast::Expr::Binary(bin) => {
352 let lhs = self.try_eval(&bin.left)?;
353 let rhs = self.try_eval(&bin.right)?;
354 self.eval_binop(bin.op, lhs, rhs)
355 }
356 ast::Expr::Ident(ident) => {
357 let name = ast::sol_path![ident.clone()];
358 let Some(item) = self.cx.try_item(&name) else {
359 eprintln!("{}", std::backtrace::Backtrace::force_capture());
360 eprintln!("{:#?}", self.cx.all_items);
361 return Err(EE::CouldNotResolve.into());
362 };
363 let ast::Item::Variable(var) = item else {
364 return Err(EE::NonConstantVar.into());
365 };
366 if !var.attributes.has_constant() {
367 return Err(EE::NonConstantVar.into());
368 }
369 let Some((_, expr)) = var.initializer.as_ref() else {
370 return Err(EE::NonConstantVar.into());
371 };
372 self.try_eval(expr)
373 }
374 ast::Expr::LitDenominated(ast::LitDenominated {
375 number: ast::LitNumber::Int(n),
376 denom,
377 }) => {
378 let n = n.base10_digits().parse::<ArraySize>().map_err(|_| EE::ParseInt)?;
379 let Ok(denom) = denom.value().try_into() else {
380 return Err(EE::IntTooBig.into());
381 };
382 n.checked_mul(denom).ok_or_else(|| EE::ArithmeticOverflow.into())
383 }
384 ast::Expr::Unary(unary) => {
385 let value = self.try_eval(&unary.expr)?;
386 self.eval_unop(unary.op, value)
387 }
388 _ => Err(EE::UnsupportedExpr.into()),
389 }
390 }
391
392 fn eval_binop(
393 &mut self,
394 bin: ast::BinOp,
395 lhs: ArraySize,
396 rhs: ArraySize,
397 ) -> Result<ArraySize, EvalError> {
398 match bin {
399 ast::BinOp::Shr(..) => rhs
400 .try_into()
401 .ok()
402 .and_then(|rhs| lhs.checked_shr(rhs))
403 .ok_or_else(|| EE::ArithmeticOverflow.into()),
404 ast::BinOp::Shl(..) => rhs
405 .try_into()
406 .ok()
407 .and_then(|rhs| lhs.checked_shl(rhs))
408 .ok_or_else(|| EE::ArithmeticOverflow.into()),
409 ast::BinOp::BitAnd(..) => Ok(lhs & rhs),
410 ast::BinOp::BitOr(..) => Ok(lhs | rhs),
411 ast::BinOp::BitXor(..) => Ok(lhs ^ rhs),
412 ast::BinOp::Add(..) => {
413 lhs.checked_add(rhs).ok_or_else(|| EE::ArithmeticOverflow.into())
414 }
415 ast::BinOp::Sub(..) => {
416 lhs.checked_sub(rhs).ok_or_else(|| EE::ArithmeticOverflow.into())
417 }
418 ast::BinOp::Pow(..) => rhs
419 .try_into()
420 .ok()
421 .and_then(|rhs| lhs.checked_pow(rhs))
422 .ok_or_else(|| EE::ArithmeticOverflow.into()),
423 ast::BinOp::Mul(..) => {
424 lhs.checked_mul(rhs).ok_or_else(|| EE::ArithmeticOverflow.into())
425 }
426 ast::BinOp::Div(..) => lhs.checked_div(rhs).ok_or_else(|| EE::DivisionByZero.into()),
427 ast::BinOp::Rem(..) => lhs.checked_div(rhs).ok_or_else(|| EE::DivisionByZero.into()),
428 _ => Err(EE::UnsupportedExpr.into()),
429 }
430 }
431
432 fn eval_unop(&mut self, unop: ast::UnOp, value: ArraySize) -> Result<ArraySize, EvalError> {
433 match unop {
434 ast::UnOp::Neg(..) => value.checked_neg().ok_or_else(|| EE::ArithmeticOverflow.into()),
435 ast::UnOp::BitNot(..) | ast::UnOp::Not(..) => Ok(!value),
436 _ => Err(EE::UnsupportedUnaryOp.into()),
437 }
438 }
439}
440
441struct EvalError {
442 kind: EvalErrorKind,
443 span: Option<Span>,
444}
445
446impl From<EvalErrorKind> for EvalError {
447 fn from(kind: EvalErrorKind) -> Self {
448 Self { kind, span: None }
449 }
450}
451
452impl EvalError {
453 fn span(&self) -> Span {
454 self.span.unwrap_or_else(Span::call_site)
455 }
456}
457
/// Reasons constant array-length evaluation can fail.
enum EvalErrorKind {
    /// Expression or constant-reference nesting exceeded the evaluator's depth limit.
    RecursionLimitReached,
    /// An arithmetic result (or shift/exponent amount) was out of range for `ArraySize`.
    ArithmeticOverflow,
    /// An integer literal could not be parsed as an `ArraySize`.
    ParseInt,
    /// A number denomination's multiplier did not fit in `ArraySize`.
    IntTooBig,
    /// Division (or remainder) by zero.
    DivisionByZero,
    /// A unary operator the evaluator does not handle.
    UnsupportedUnaryOp,
    /// An expression kind (or binary operator) the evaluator does not handle.
    UnsupportedExpr,
    /// An identifier did not resolve to any item.
    CouldNotResolve,
    /// An identifier resolved to something other than an initialized `constant` variable.
    NonConstantVar,
}
469use EvalErrorKind as EE;
470
471impl EvalErrorKind {
472 fn spanned(self, span: Span) -> EvalError {
473 EvalError { kind: self, span: Some(span) }
474 }
475
476 fn msg(&self) -> &'static str {
477 match self {
478 Self::RecursionLimitReached => "recursion limit reached",
479 Self::ArithmeticOverflow => "arithmetic overflow",
480 Self::ParseInt => "failed to parse integer",
481 Self::IntTooBig => "integer value is too big",
482 Self::DivisionByZero => "division by zero",
483 Self::UnsupportedUnaryOp => "unsupported unary operation",
484 Self::UnsupportedExpr => "unsupported expression",
485 Self::CouldNotResolve => "could not resolve identifier",
486 Self::NonConstantVar => "only constant variables are allowed",
487 }
488 }
489}
490
/// Display adapter that prints a [`Type`], using the expansion context to resolve
/// custom types and evaluate array lengths.
pub(crate) struct TypePrinter<'ast> {
    /// Context used for custom-type resolution and array-length evaluation.
    cx: &'ast ExpCtxt<'ast>,
    /// The type to print.
    ty: &'ast Type,
}

impl<'ast> TypePrinter<'ast> {
    /// Creates a printer for `ty` within `cx`.
    pub(crate) fn new(cx: &'ast ExpCtxt<'ast>, ty: &'ast Type) -> Self {
        Self { cx, ty }
    }
}
504
505impl fmt::Display for TypePrinter<'_> {
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 match self.ty {
508 Type::Int(_, None) => f.write_str("int256"),
509 Type::Uint(_, None) => f.write_str("uint256"),
510
511 Type::Array(array) => {
512 Self::new(self.cx, &array.ty).fmt(f)?;
513 f.write_str("[")?;
514 if let Some(size) = self.cx.eval_array_size(array) {
515 size.fmt(f)?;
516 }
517 f.write_str("]")
518 }
519 Type::Tuple(tuple) => {
520 f.write_str("(")?;
521 for (i, ty) in tuple.types.iter().enumerate() {
522 if i > 0 {
523 f.write_str(",")?;
524 }
525 Self::new(self.cx, ty).fmt(f)?;
526 }
527 f.write_str(")")
528 }
529
530 Type::Custom(name) => Self::new(self.cx, self.cx.custom_type(name)).fmt(f),
531
532 ty => ty.fmt(f),
533 }
534 }
535}