// alloy_sol_macro_expander/expand/contract.rs

1//! [`ItemContract`] expansion.
2
3use super::{anon_name, ExpCtxt};
4use crate::utils::ExprArray;
5use alloy_sol_macro_input::{docs_str, mk_doc, ContainsSolAttrs};
6use ast::{Item, ItemContract, ItemError, ItemEvent, ItemFunction, SolIdent, Spanned};
7use heck::ToSnakeCase;
8use proc_macro2::{Ident, TokenStream};
9use quote::{format_ident, quote};
10use syn::{parse_quote, Attribute, Result};
11
12/// Expands an [`ItemContract`]:
13///
14/// ```ignore (pseudo-code)
15/// pub mod #name {
16///     #(#items)*
17///
18///     pub enum #{name}Calls {
19///         ...
20///    }
21///
22///     pub enum #{name}Errors {
23///         ...
24///    }
25///
26///     pub enum #{name}Events {
27///         ...
28///    }
29/// }
30/// ```
31pub(super) fn expand(cx: &mut ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenStream> {
32    let ItemContract { name, body, .. } = contract;
33
34    let (sol_attrs, attrs) = contract.split_attrs()?;
35
36    let extra_methods = sol_attrs.extra_methods.or(cx.attrs.extra_methods).unwrap_or(false);
37    let rpc = sol_attrs.rpc.or(cx.attrs.rpc).unwrap_or(false);
38    let abi = sol_attrs.abi.or(cx.attrs.abi).unwrap_or(false);
39    let docs = sol_attrs.docs.or(cx.attrs.docs).unwrap_or(true);
40
41    let bytecode = sol_attrs.bytecode.map(|lit| {
42        let name = Ident::new("BYTECODE", lit.span());
43        let hex = lit.value();
44        let bytes = hex::decode(&hex).unwrap();
45        let lit_bytes = proc_macro2::Literal::byte_string(&bytes).with_span(lit.span());
46        quote! {
47            /// The creation / init bytecode of the contract.
48            ///
49            /// ```text
50            #[doc = #hex]
51            /// ```
52            #[rustfmt::skip]
53            #[allow(clippy::all)]
54            pub static #name: alloy_sol_types::private::Bytes =
55                alloy_sol_types::private::Bytes::from_static(#lit_bytes);
56        }
57    });
58    let deployed_bytecode = sol_attrs.deployed_bytecode.map(|lit| {
59        let name = Ident::new("DEPLOYED_BYTECODE", lit.span());
60        let hex = lit.value();
61        let bytes = hex::decode(&hex).unwrap();
62        let lit_bytes = proc_macro2::Literal::byte_string(&bytes).with_span(lit.span());
63        quote! {
64            /// The runtime bytecode of the contract, as deployed on the network.
65            ///
66            /// ```text
67            #[doc = #hex]
68            /// ```
69            #[rustfmt::skip]
70            #[allow(clippy::all)]
71            pub static #name: alloy_sol_types::private::Bytes =
72                alloy_sol_types::private::Bytes::from_static(#lit_bytes);
73        }
74    });
75
76    let mut constructor = None;
77    let mut fallback = None;
78    let mut receive = None;
79    let mut functions = Vec::with_capacity(contract.body.len());
80    let mut errors = Vec::with_capacity(contract.body.len());
81    let mut events = Vec::with_capacity(contract.body.len());
82
83    let (mut mod_attrs, item_attrs) =
84        attrs.into_iter().partition::<Vec<_>, _>(|a| a.path().is_ident("doc"));
85    mod_attrs.extend(item_attrs.iter().filter(|a| !a.path().is_ident("derive")).cloned());
86
87    let mut item_tokens = TokenStream::new();
88    for item in body {
89        match item {
90            Item::Function(function) => match function.kind {
91                ast::FunctionKind::Function(_) if function.name.is_some() => {
92                    functions.push(function.clone());
93                }
94                ast::FunctionKind::Function(_) => {}
95                ast::FunctionKind::Modifier(_) => {}
96                ast::FunctionKind::Constructor(_) => {
97                    if constructor.is_none() {
98                        constructor = Some(function);
99                    } else {
100                        let msg = "duplicate constructor";
101                        return Err(syn::Error::new(function.span(), msg));
102                    }
103                }
104                ast::FunctionKind::Fallback(_) => {
105                    if fallback.is_none() {
106                        fallback = Some(function);
107                    } else {
108                        let msg = "duplicate fallback function";
109                        return Err(syn::Error::new(function.span(), msg));
110                    }
111                }
112                ast::FunctionKind::Receive(_) => {
113                    if receive.is_none() {
114                        receive = Some(function);
115                    } else {
116                        let msg = "duplicate receive function";
117                        return Err(syn::Error::new(function.span(), msg));
118                    }
119                }
120            },
121            Item::Error(error) => errors.push(error),
122            Item::Event(event) => events.push(event),
123            Item::Variable(var_def) => {
124                if let Some(function) = super::var_def::var_as_function(cx, var_def)? {
125                    functions.push(function);
126                }
127            }
128            _ => {}
129        }
130
131        if item.attrs().is_none() || item_attrs.is_empty() {
132            // avoid cloning item if we don't have to
133            item_tokens.extend(cx.expand_item(item)?);
134        } else {
135            // prepend `item_attrs` to `item.attrs`
136            let mut item = item.clone();
137            item.attrs_mut().expect("is_none checked above").splice(0..0, item_attrs.clone());
138            item_tokens.extend(cx.expand_item(&item)?);
139        }
140    }
141
142    let enum_expander = CallLikeExpander { cx, contract_name: name.clone(), extra_methods };
143    // Remove any `Default` derives.
144    let mut enum_attrs = item_attrs;
145    for attr in &mut enum_attrs {
146        if !attr.path().is_ident("derive") {
147            continue;
148        }
149
150        let derives = alloy_sol_macro_input::parse_derives(attr);
151        let mut derives = derives.into_iter().collect::<Vec<_>>();
152        if derives.is_empty() {
153            continue;
154        }
155
156        let len = derives.len();
157        derives.retain(|derive| !derive.is_ident("Default"));
158        if derives.len() == len {
159            continue;
160        }
161
162        attr.meta = parse_quote! { derive(#(#derives),*) };
163    }
164
165    let functions_enum = (!functions.is_empty()).then(|| {
166        let mut attrs = enum_attrs.clone();
167        let doc_str = format!("Container for all the [`{name}`](self) function calls.");
168        attrs.push(parse_quote!(#[doc = #doc_str]));
169        enum_expander.expand(ToExpand::Functions(&functions), attrs)
170    });
171
172    let errors_enum = (!errors.is_empty()).then(|| {
173        let mut attrs = enum_attrs.clone();
174        let doc_str = format!("Container for all the [`{name}`](self) custom errors.");
175        attrs.push(parse_quote!(#[doc = #doc_str]));
176        enum_expander.expand(ToExpand::Errors(&errors), attrs)
177    });
178
179    let events_enum = (!events.is_empty()).then(|| {
180        let mut attrs = enum_attrs;
181        let doc_str = format!("Container for all the [`{name}`](self) events.");
182        attrs.push(parse_quote!(#[doc = #doc_str]));
183        enum_expander.expand(ToExpand::Events(&events), attrs)
184    });
185
186    let mod_descr_doc = (docs && docs_str(&mod_attrs).trim().is_empty())
187        .then(|| mk_doc("Module containing a contract's types and functions."));
188    let mod_iface_doc = (docs && !docs_str(&mod_attrs).contains("```solidity\n"))
189        .then(|| mk_doc(format!("\n\n```solidity\n{contract}\n```")));
190
191    let abi = abi.then(|| {
192        if_json! {
193            use crate::verbatim::verbatim;
194            use super::to_abi;
195
196            let crates = &cx.crates;
197            let constructor = verbatim(&constructor.map(|x| to_abi::constructor(x, cx)), crates);
198            let fallback = verbatim(&fallback.map(|x| to_abi::fallback(x, cx)), crates);
199            let receive = verbatim(&receive.map(|x| to_abi::receive(x, cx)), crates);
200            let functions_map = to_abi::functions_map(&functions, cx);
201            let events_map = to_abi::events_map(&events, cx);
202            let errors_map = to_abi::errors_map(&errors, cx);
203            quote! {
204                /// Contains [dynamic ABI definitions](alloy_sol_types::private::alloy_json_abi) for [this contract](self).
205                pub mod abi {
206                    use super::*;
207                    use alloy_sol_types::private::{alloy_json_abi as json, BTreeMap, Vec};
208
209                    /// Returns the ABI for [this contract](super).
210                    pub fn contract() -> json::JsonAbi {
211                        json::JsonAbi {
212                            constructor: constructor(),
213                            fallback: fallback(),
214                            receive: receive(),
215                            functions: functions(),
216                            events: events(),
217                            errors: errors(),
218                        }
219                    }
220
221                    /// Returns the [`Constructor`](json::Constructor) of [this contract](super), if any.
222                    pub fn constructor() -> Option<json::Constructor> {
223                        #constructor
224                    }
225
226                    /// Returns the [`Fallback`](json::Fallback) function of [this contract](super), if any.
227                    pub fn fallback() -> Option<json::Fallback> {
228                        #fallback
229                    }
230
231                    /// Returns the [`Receive`](json::Receive) function of [this contract](super), if any.
232                    pub fn receive() -> Option<json::Receive> {
233                        #receive
234                    }
235
236                    /// Returns a map of all the [`Function`](json::Function)s of [this contract](super).
237                    pub fn functions() -> BTreeMap<String, Vec<json::Function>> {
238                        #functions_map
239                    }
240
241                    /// Returns a map of all the [`Event`](json::Event)s of [this contract](super).
242                    pub fn events() -> BTreeMap<String, Vec<json::Event>> {
243                        #events_map
244                    }
245
246                    /// Returns a map of all the [`Error`](json::Error)s of [this contract](super).
247                    pub fn errors() -> BTreeMap<String, Vec<json::Error>> {
248                        #errors_map
249                    }
250                }
251            }
252        }
253    });
254
255    let rpc = rpc.then(|| {
256        let contract_name = name;
257        let name = format_ident!("{contract_name}Instance");
258        let name_s = name.to_string();
259        let methods = functions.iter().map(|f| call_builder_method(f, cx));
260        let new_fn_doc = format!(
261            "Creates a new wrapper around an on-chain [`{contract_name}`](self) contract instance.\n\
262             \n\
263             See the [wrapper's documentation](`{name}`) for more details."
264        );
265        let struct_doc = format!(
266            "A [`{contract_name}`](self) instance.\n\
267             \n\
268             Contains type-safe methods for interacting with an on-chain instance of the\n\
269             [`{contract_name}`](self) contract located at a given `address`, using a given\n\
270             provider `P`.\n\
271             \n\
272             If the contract bytecode is available (see the [`sol!`](alloy_sol_types::sol!)\n\
273             documentation on how to provide it), the `deploy` and `deploy_builder` methods can\n\
274             be used to deploy a new instance of the contract.\n\
275             \n\
276             See the [module-level documentation](self) for all the available methods."
277        );
278        let (deploy_fn, deploy_method) = bytecode.is_some().then(|| {
279            let deploy_doc_str =
280                "Deploys this contract using the given `provider` and constructor arguments, if any.\n\
281                 \n\
282                 Returns a new instance of the contract, if the deployment was successful.\n\
283                 \n\
284                 For more fine-grained control over the deployment process, use [`deploy_builder`] instead.";
285            let deploy_doc = mk_doc(deploy_doc_str);
286
287            let deploy_builder_doc_str =
288                "Creates a `RawCallBuilder` for deploying this contract using the given `provider`\n\
289                 and constructor arguments, if any.\n\
290                 \n\
291                 This is a simple wrapper around creating a `RawCallBuilder` with the data set to\n\
292                 the bytecode concatenated with the constructor's ABI-encoded arguments.";
293            let deploy_builder_doc = mk_doc(deploy_builder_doc_str);
294
295            let (params, args) = constructor.and_then(|c| {
296                if c.parameters.is_empty() {
297                    return None;
298                }
299
300                let names1 = c.parameters.names().enumerate().map(anon_name);
301                let names2 = names1.clone();
302                let tys = c.parameters.types().map(|ty| {
303                    cx.expand_rust_type(ty)
304                });
305                Some((quote!(#(#names1: #tys),*), quote!(#(#names2,)*)))
306            }).unzip();
307            let deploy_builder_data = if matches!(constructor, Some(c) if !c.parameters.is_empty()) {
308                quote! {
309                    [
310                        &BYTECODE[..],
311                        &alloy_sol_types::SolConstructor::abi_encode(&constructorCall { #args })[..]
312                    ].concat().into()
313                }
314            } else {
315                quote! {
316                    ::core::clone::Clone::clone(&BYTECODE)
317                }
318            };
319
320            (
321                quote! {
322                    #deploy_doc
323                    #[inline]
324                    pub fn deploy<T: alloy_contract::private::Transport + ::core::clone::Clone, P: alloy_contract::private::Provider<T, N>, N: alloy_contract::private::Network>(provider: P, #params)
325                        -> impl ::core::future::Future<Output = alloy_contract::Result<#name<T, P, N>>>
326                    {
327                        #name::<T, P, N>::deploy(provider, #args)
328                    }
329
330                    #deploy_builder_doc
331                    #[inline]
332                    pub fn deploy_builder<T: alloy_contract::private::Transport + ::core::clone::Clone, P: alloy_contract::private::Provider<T, N>, N: alloy_contract::private::Network>(provider: P, #params)
333                        -> alloy_contract::RawCallBuilder<T, P, N>
334                    {
335                        #name::<T, P, N>::deploy_builder(provider, #args)
336                    }
337                },
338                quote! {
339                    #deploy_doc
340                    #[inline]
341                    pub async fn deploy(provider: P, #params)
342                        -> alloy_contract::Result<#name<T, P, N>>
343                    {
344                        let call_builder = Self::deploy_builder(provider, #args);
345                        let contract_address = call_builder.deploy().await?;
346                        Ok(Self::new(contract_address, call_builder.provider))
347                    }
348
349                    #deploy_builder_doc
350                    #[inline]
351                    pub fn deploy_builder(provider: P, #params)
352                        -> alloy_contract::RawCallBuilder<T, P, N>
353                    {
354                        alloy_contract::RawCallBuilder::new_raw_deploy(provider, #deploy_builder_data)
355                    }
356                },
357            )
358        }).unzip();
359
360        let filter_methods = events.iter().map(|&e| {
361            let event_name = cx.overloaded_name(e.into());
362            let name = format_ident!("{event_name}_filter");
363            let doc = format!(
364                "Creates a new event filter for the [`{event_name}`] event.",
365            );
366            quote! {
367                #[doc = #doc]
368                pub fn #name(&self) -> alloy_contract::Event<T, &P, #event_name, N> {
369                    self.event_filter::<#event_name>()
370                }
371            }
372        });
373
374        let alloy_contract = &cx.crates.contract;
375        let generics_t_p_n = quote!(<T: alloy_contract::private::Transport + ::core::clone::Clone, P: alloy_contract::private::Provider<T, N>, N: alloy_contract::private::Network>);
376
377        quote! {
378            use #alloy_contract as alloy_contract;
379
380            #[doc = #new_fn_doc]
381            #[inline]
382            pub const fn new #generics_t_p_n(
383                address: alloy_sol_types::private::Address,
384                provider: P,
385            ) -> #name<T, P, N> {
386                #name::<T, P, N>::new(address, provider)
387            }
388
389            #deploy_fn
390
391            #[doc = #struct_doc]
392            #[derive(Clone)]
393            pub struct #name<T, P, N = alloy_contract::private::Ethereum> {
394                address: alloy_sol_types::private::Address,
395                provider: P,
396                _network_transport: ::core::marker::PhantomData<(N, T)>,
397            }
398
399            #[automatically_derived]
400            impl<T, P, N> ::core::fmt::Debug for #name<T, P, N> {
401                #[inline]
402                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
403                    f.debug_tuple(#name_s).field(&self.address).finish()
404                }
405            }
406
407            /// Instantiation and getters/setters.
408            #[automatically_derived]
409            impl #generics_t_p_n #name<T, P, N> {
410                #[doc = #new_fn_doc]
411                #[inline]
412                pub const fn new(address: alloy_sol_types::private::Address, provider: P) -> Self {
413                    Self { address, provider, _network_transport: ::core::marker::PhantomData }
414                }
415
416                #deploy_method
417
418                /// Returns a reference to the address.
419                #[inline]
420                pub const fn address(&self) -> &alloy_sol_types::private::Address {
421                    &self.address
422                }
423
424                /// Sets the address.
425                #[inline]
426                pub fn set_address(&mut self, address: alloy_sol_types::private::Address) {
427                    self.address = address;
428                }
429
430                /// Sets the address and returns `self`.
431                pub fn at(mut self, address: alloy_sol_types::private::Address) -> Self {
432                    self.set_address(address);
433                    self
434                }
435
436                /// Returns a reference to the provider.
437                #[inline]
438                pub const fn provider(&self) -> &P {
439                    &self.provider
440                }
441            }
442
443            impl<T, P: ::core::clone::Clone, N> #name<T, &P, N> {
444                /// Clones the provider and returns a new instance with the cloned provider.
445                #[inline]
446                pub fn with_cloned_provider(self) -> #name<T, P, N> {
447                    #name { address: self.address, provider: ::core::clone::Clone::clone(&self.provider), _network_transport: ::core::marker::PhantomData }
448                }
449            }
450
451            /// Function calls.
452            #[automatically_derived]
453            impl #generics_t_p_n #name<T, P, N> {
454                /// Creates a new call builder using this contract instance's provider and address.
455                ///
456                /// Note that the call can be any function call, not just those defined in this
457                /// contract. Prefer using the other methods for building type-safe contract calls.
458                pub fn call_builder<C: alloy_sol_types::SolCall>(&self, call: &C)
459                    -> alloy_contract::SolCallBuilder<T, &P, C, N>
460                {
461                    alloy_contract::SolCallBuilder::new_sol(&self.provider, &self.address, call)
462                }
463
464                #(#methods)*
465            }
466
467            /// Event filters.
468            #[automatically_derived]
469            impl #generics_t_p_n #name<T, P, N> {
470                /// Creates a new event filter using this contract instance's provider and address.
471                ///
472                /// Note that the type can be any event, not just those defined in this contract.
473                /// Prefer using the other methods for building type-safe event filters.
474                pub fn event_filter<E: alloy_sol_types::SolEvent>(&self)
475                    -> alloy_contract::Event<T, &P, E, N>
476                {
477                    alloy_contract::Event::new_sol(&self.provider, &self.address)
478                }
479
480                #(#filter_methods)*
481            }
482        }
483    });
484
485    let alloy_sol_types = &cx.crates.sol_types;
486
487    let tokens = quote! {
488        #mod_descr_doc
489        #(#mod_attrs)*
490        #mod_iface_doc
491        #[allow(non_camel_case_types, non_snake_case, clippy::pub_underscore_fields, clippy::style, clippy::empty_structs_with_brackets)]
492        pub mod #name {
493            use super::*;
494            use #alloy_sol_types as alloy_sol_types;
495
496            #bytecode
497            #deployed_bytecode
498
499            #item_tokens
500
501            #functions_enum
502            #errors_enum
503            #events_enum
504
505            #abi
506
507            #rpc
508        }
509    };
510    Ok(tokens)
511}
512
513// note that item impls generated here do not need to be wrapped in an anonymous
514// constant (`const _: () = { ... };`) because they are in one already
515
516/// Expands a `SolInterface` enum:
517///
518/// ```ignore (pseudo-code)
519/// #name = #{contract_name}Calls | #{contract_name}Errors | #{contract_name}Events;
520///
521/// pub enum #name {
522///    #(#variants(#types),)*
523/// }
524///
525/// impl SolInterface for #name {
526///     ...
527/// }
528///
529/// impl #name {
530///     pub const SELECTORS: &'static [[u8; _]] = &[...];
531/// }
532///
533/// #if extra_methods
534/// #(
535///     impl From<#types> for #name { ... }
536///     impl TryFrom<#name> for #types { ... }
537/// )*
538///
539/// impl #name {
540///     #(
541///         pub fn #is_variant,#as_variant,#as_variant_mut(...) -> ... { ... }
542///     )*
543/// }
544/// #endif
545/// ```
546struct CallLikeExpander<'a> {
547    cx: &'a ExpCtxt<'a>,
548    contract_name: SolIdent,
549    extra_methods: bool,
550}
551
552#[derive(Clone, Debug)]
553struct ExpandData {
554    name: Ident,
555    variants: Vec<Ident>,
556    types: Option<Vec<Ident>>,
557    min_data_len: usize,
558    trait_: Ident,
559    selectors: Vec<ExprArray<u8>>,
560}
561
562impl ExpandData {
563    fn types(&self) -> &Vec<Ident> {
564        let types = self.types.as_ref().unwrap_or(&self.variants);
565        assert_eq!(types.len(), self.variants.len());
566        types
567    }
568
569    fn sort_by_selector(&mut self) {
570        let len = self.selectors.len();
571        if len <= 1 {
572            return;
573        }
574
575        let prev = self.selectors.clone();
576        self.selectors.sort_unstable();
577        // Arbitrary max length.
578        if len <= 20 && prev == self.selectors {
579            return;
580        }
581
582        let old_variants = self.variants.clone();
583        let old_types = self.types.clone();
584        let new_idxs =
585            prev.iter().map(|selector| self.selectors.iter().position(|s| s == selector).unwrap());
586        for (old, new) in new_idxs.enumerate() {
587            if old == new {
588                continue;
589            }
590
591            self.variants[new] = old_variants[old].clone();
592            if let Some(types) = self.types.as_mut() {
593                types[new] = old_types.as_ref().unwrap()[old].clone();
594            }
595        }
596    }
597}
598
599enum ToExpand<'a> {
600    Functions(&'a [ItemFunction]),
601    Errors(&'a [&'a ItemError]),
602    Events(&'a [&'a ItemEvent]),
603}
604
605impl ToExpand<'_> {
606    fn to_data(&self, expander: &CallLikeExpander<'_>) -> ExpandData {
607        let &CallLikeExpander { cx, ref contract_name, .. } = expander;
608        match self {
609            Self::Functions(functions) => {
610                let variants: Vec<_> =
611                    functions.iter().map(|f| cx.overloaded_name(f.into()).0).collect();
612
613                let types: Vec<_> = variants.iter().map(|name| cx.raw_call_name(name)).collect();
614
615                ExpandData {
616                    name: format_ident!("{contract_name}Calls"),
617                    variants,
618                    types: Some(types),
619                    min_data_len: functions
620                        .iter()
621                        .map(|function| cx.params_base_data_size(&function.parameters))
622                        .min()
623                        .unwrap(),
624                    trait_: format_ident!("SolCall"),
625                    selectors: functions.iter().map(|f| cx.function_selector(f)).collect(),
626                }
627            }
628
629            Self::Errors(errors) => ExpandData {
630                name: format_ident!("{contract_name}Errors"),
631                variants: errors.iter().map(|error| error.name.0.clone()).collect(),
632                types: None,
633                min_data_len: errors
634                    .iter()
635                    .map(|error| cx.params_base_data_size(&error.parameters))
636                    .min()
637                    .unwrap(),
638                trait_: format_ident!("SolError"),
639                selectors: errors.iter().map(|e| cx.error_selector(e)).collect(),
640            },
641
642            Self::Events(events) => {
643                let variants: Vec<_> =
644                    events.iter().map(|&event| cx.overloaded_name(event.into()).0).collect();
645
646                ExpandData {
647                    name: format_ident!("{contract_name}Events"),
648                    variants,
649                    types: None,
650                    min_data_len: events
651                        .iter()
652                        .map(|event| cx.params_base_data_size(&event.params()))
653                        .min()
654                        .unwrap(),
655                    trait_: format_ident!("SolEvent"),
656                    selectors: events.iter().map(|e| cx.event_selector(e)).collect(),
657                }
658            }
659        }
660    }
661}
662
663impl CallLikeExpander<'_> {
664    fn expand(&self, to_expand: ToExpand<'_>, attrs: Vec<Attribute>) -> TokenStream {
665        let data = &to_expand.to_data(self);
666
667        let mut sorted_data = data.clone();
668        sorted_data.sort_by_selector();
669        #[cfg(debug_assertions)]
670        for (i, sv) in sorted_data.variants.iter().enumerate() {
671            let s = &sorted_data.selectors[i];
672
673            let normal_pos = data.variants.iter().position(|v| v == sv).unwrap();
674            let ns = &data.selectors[normal_pos];
675            assert_eq!(s, ns);
676        }
677
678        if let ToExpand::Events(events) = to_expand {
679            return self.expand_events(events, data, &sorted_data, attrs);
680        }
681
682        let def = self.generate_enum(data, &sorted_data, attrs);
683        let ExpandData { name, variants, min_data_len, trait_, .. } = data;
684        let types = data.types();
685        let name_s = name.to_string();
686        let count = data.variants.len();
687
688        let sorted_variants = &sorted_data.variants;
689        let sorted_types = sorted_data.types();
690
691        quote! {
692            #def
693
694            #[automatically_derived]
695            impl alloy_sol_types::SolInterface for #name {
696                const NAME: &'static str = #name_s;
697                const MIN_DATA_LENGTH: usize = #min_data_len;
698                const COUNT: usize = #count;
699
700                #[inline]
701                fn selector(&self) -> [u8; 4] {
702                    match self {#(
703                        Self::#variants(_) => <#types as alloy_sol_types::#trait_>::SELECTOR,
704                    )*}
705                }
706
707                #[inline]
708                fn selector_at(i: usize) -> ::core::option::Option<[u8; 4]> {
709                    Self::SELECTORS.get(i).copied()
710                }
711
712                #[inline]
713                fn valid_selector(selector: [u8; 4]) -> bool {
714                    Self::SELECTORS.binary_search(&selector).is_ok()
715                }
716
717                #[inline]
718                #[allow(non_snake_case)]
719                fn abi_decode_raw(
720                    selector: [u8; 4],
721                    data: &[u8],
722                    validate: bool
723                )-> alloy_sol_types::Result<Self> {
724                    static DECODE_SHIMS: &[fn(&[u8], bool) -> alloy_sol_types::Result<#name>] = &[
725                        #({
726                            fn #sorted_variants(data: &[u8], validate: bool) -> alloy_sol_types::Result<#name> {
727                                <#sorted_types as alloy_sol_types::#trait_>::abi_decode_raw(data, validate)
728                                    .map(#name::#sorted_variants)
729                            }
730                            #sorted_variants
731                        }),*
732                    ];
733
734                    let Ok(idx) = Self::SELECTORS.binary_search(&selector) else {
735                        return Err(alloy_sol_types::Error::unknown_selector(
736                            <Self as alloy_sol_types::SolInterface>::NAME,
737                            selector,
738                        ));
739                    };
740                    // `SELECTORS` and `DECODE_SHIMS` have the same length and are sorted in the same order.
741                    DECODE_SHIMS[idx](data, validate)
742                }
743
744                #[inline]
745                fn abi_encoded_size(&self) -> usize {
746                    match self {#(
747                        Self::#variants(inner) =>
748                            <#types as alloy_sol_types::#trait_>::abi_encoded_size(inner),
749                    )*}
750                }
751
752                #[inline]
753                fn abi_encode_raw(&self, out: &mut alloy_sol_types::private::Vec<u8>) {
754                    match self {#(
755                        Self::#variants(inner) =>
756                            <#types as alloy_sol_types::#trait_>::abi_encode_raw(inner, out),
757                    )*}
758                }
759            }
760        }
761    }
762
763    fn expand_events(
764        &self,
765        events: &[&ItemEvent],
766        data: &ExpandData,
767        sorted_data: &ExpandData,
768        attrs: Vec<Attribute>,
769    ) -> TokenStream {
770        let def = self.generate_enum(data, sorted_data, attrs);
771        let ExpandData { name, trait_, .. } = data;
772        let name_s = name.to_string();
773        let count = data.variants.len();
774
775        let has_anon = events.iter().any(|e| e.is_anonymous());
776        let has_non_anon = events.iter().any(|e| !e.is_anonymous());
777        assert!(has_anon || has_non_anon, "events shouldn't be empty");
778
779        let e_name = |&e: &&ItemEvent| self.cx.overloaded_name(e.into());
780        let err = quote! {
781            alloy_sol_types::private::Err(alloy_sol_types::Error::InvalidLog {
782                name: <Self as alloy_sol_types::SolEventInterface>::NAME,
783                log: alloy_sol_types::private::Box::new(alloy_sol_types::private::LogData::new_unchecked(
784                    topics.to_vec(),
785                    data.to_vec().into(),
786                )),
787            })
788        };
789        let non_anon_impl = has_non_anon.then(|| {
790            let variants = events.iter().filter(|e| !e.is_anonymous()).map(e_name);
791            let ret = has_anon.then(|| quote!(return));
792            let ret_err = (!has_anon).then_some(&err);
793            quote! {
794                match topics.first().copied() {
795                    #(
796                        Some(<#variants as alloy_sol_types::#trait_>::SIGNATURE_HASH) =>
797                            #ret <#variants as alloy_sol_types::#trait_>::decode_raw_log(topics, data, validate)
798                                .map(Self::#variants),
799                    )*
800                    _ => { #ret_err }
801                }
802            }
803        });
804        let anon_impl = has_anon.then(|| {
805            let variants = events.iter().filter(|e| e.is_anonymous()).map(e_name);
806            quote! {
807                #(
808                    if let Ok(res) = <#variants as alloy_sol_types::#trait_>::decode_raw_log(topics, data, validate) {
809                        return Ok(Self::#variants(res));
810                    }
811                )*
812                #err
813            }
814        });
815
816        let into_impl = {
817            let variants = events.iter().map(e_name);
818            let v2 = variants.clone();
819            quote! {
820                #[automatically_derived]
821                impl alloy_sol_types::private::IntoLogData for #name {
822                    fn to_log_data(&self) -> alloy_sol_types::private::LogData {
823                        match self {#(
824                            Self::#variants(inner) =>
825                            alloy_sol_types::private::IntoLogData::to_log_data(inner),
826                        )*}
827                    }
828
829                    fn into_log_data(self) -> alloy_sol_types::private::LogData {
830                        match self {#(
831                            Self::#v2(inner) =>
832                            alloy_sol_types::private::IntoLogData::into_log_data(inner),
833                        )*}
834                    }
835                }
836            }
837        };
838
839        quote! {
840            #def
841
842            #[automatically_derived]
843            impl alloy_sol_types::SolEventInterface for #name {
844                const NAME: &'static str = #name_s;
845                const COUNT: usize = #count;
846
847                fn decode_raw_log(topics: &[alloy_sol_types::Word], data: &[u8], validate: bool) -> alloy_sol_types::Result<Self> {
848                    #non_anon_impl
849                    #anon_impl
850                }
851            }
852
853            #into_impl
854        }
855    }
856
857    fn generate_enum(
858        &self,
859        data: &ExpandData,
860        sorted_data: &ExpandData,
861        mut attrs: Vec<Attribute>,
862    ) -> TokenStream {
863        let ExpandData { name, variants, .. } = data;
864        let types = data.types();
865
866        let selectors = &sorted_data.selectors;
867
868        let selector_len = selectors.first().unwrap().array.len();
869        assert!(selectors.iter().all(|s| s.array.len() == selector_len));
870        let selector_type = quote!([u8; #selector_len]);
871
872        self.cx.type_derives(&mut attrs, types.iter().cloned().map(ast::Type::custom), false);
873
874        let mut tokens = quote! {
875            #(#attrs)*
876            pub enum #name {
877                #(
878                    #[allow(missing_docs)]
879                    #variants(#types),
880                )*
881            }
882
883            #[automatically_derived]
884            impl #name {
885                /// All the selectors of this enum.
886                ///
887                /// Note that the selectors might not be in the same order as the variants.
888                /// No guarantees are made about the order of the selectors.
889                ///
890                /// Prefer using `SolInterface` methods instead.
891                // NOTE: This is currently sorted to allow for binary search in `SolInterface`.
892                pub const SELECTORS: &'static [#selector_type] = &[#(#selectors),*];
893            }
894        };
895
896        if self.extra_methods {
897            let conversions =
898                variants.iter().zip(types).map(|(v, t)| generate_variant_conversions(name, v, t));
899            let methods = variants.iter().zip(types).map(generate_variant_methods);
900            tokens.extend(conversions);
901            tokens.extend(quote! {
902                #[automatically_derived]
903                impl #name {
904                    #(#methods)*
905                }
906            });
907        }
908
909        tokens
910    }
911}
912
913fn generate_variant_conversions(name: &Ident, variant: &Ident, ty: &Ident) -> TokenStream {
914    quote! {
915        #[automatically_derived]
916        impl ::core::convert::From<#ty> for #name {
917            #[inline]
918            fn from(value: #ty) -> Self {
919                Self::#variant(value)
920            }
921        }
922
923        #[automatically_derived]
924        impl ::core::convert::TryFrom<#name> for #ty {
925            type Error = #name;
926
927            #[inline]
928            fn try_from(value: #name) -> ::core::result::Result<Self, #name> {
929                match value {
930                    #name::#variant(value) => ::core::result::Result::Ok(value),
931                    _ => ::core::result::Result::Err(value),
932                }
933            }
934        }
935    }
936}
937
938fn generate_variant_methods((variant, ty): (&Ident, &Ident)) -> TokenStream {
939    let name_snake = snakify(&variant.to_string());
940
941    let is_variant = format_ident!("is_{name_snake}");
942    let is_variant_doc =
943        format!("Returns `true` if `self` matches [`{variant}`](Self::{variant}).");
944
945    let as_variant = format_ident!("as_{name_snake}");
946    let as_variant_doc = format!(
947        "Returns an immutable reference to the inner [`{ty}`] if `self` matches [`{variant}`](Self::{variant})."
948    );
949
950    let as_variant_mut = format_ident!("as_{name_snake}_mut");
951    let as_variant_mut_doc = format!(
952        "Returns a mutable reference to the inner [`{ty}`] if `self` matches [`{variant}`](Self::{variant})."
953    );
954
955    quote! {
956        #[doc = #is_variant_doc]
957        #[inline]
958        pub const fn #is_variant(&self) -> bool {
959            ::core::matches!(self, Self::#variant(_))
960        }
961
962        #[doc = #as_variant_doc]
963        #[inline]
964        pub const fn #as_variant(&self) -> ::core::option::Option<&#ty> {
965            match self {
966                Self::#variant(inner) => ::core::option::Option::Some(inner),
967                _ => ::core::option::Option::None,
968            }
969        }
970
971        #[doc = #as_variant_mut_doc]
972        #[inline]
973        pub fn #as_variant_mut(&mut self) -> ::core::option::Option<&mut #ty> {
974            match self {
975                Self::#variant(inner) => ::core::option::Option::Some(inner),
976                _ => ::core::option::Option::None,
977            }
978        }
979    }
980}
981
982fn call_builder_method(f: &ItemFunction, cx: &ExpCtxt<'_>) -> TokenStream {
983    let name = cx.function_name(f);
984    let call_name = cx.call_name(f);
985    let param_names1 = f.parameters.names().enumerate().map(anon_name);
986    let param_names2 = param_names1.clone();
987    let param_tys = f.parameters.types().map(|ty| cx.expand_rust_type(ty));
988    let doc = format!("Creates a new call builder for the [`{name}`] function.");
989    quote! {
990        #[doc = #doc]
991        pub fn #name(&self, #(#param_names1: #param_tys),*) -> alloy_contract::SolCallBuilder<T, &P, #call_name, N> {
992            self.call_builder(&#call_name { #(#param_names2),* })
993        }
994    }
995}
996
997/// `heck` doesn't treat numbers as new words, and discards leading underscores.
998fn snakify(s: &str) -> String {
999    let leading_n = s.chars().take_while(|c| *c == '_').count();
1000    let (leading, s) = s.split_at(leading_n);
1001    let mut output: Vec<char> = leading.chars().chain(s.to_snake_case().chars()).collect();
1002
1003    let mut num_starts = vec![];
1004    for (pos, c) in output.iter().enumerate() {
1005        if pos != 0
1006            && c.is_ascii_digit()
1007            && !output[pos - 1].is_ascii_digit()
1008            && !output[pos - 1].is_ascii_punctuation()
1009        {
1010            num_starts.push(pos);
1011        }
1012    }
1013    // need to do in reverse, because after inserting, all chars after the point of
1014    // insertion are off
1015    for i in num_starts.into_iter().rev() {
1016        output.insert(i, '_');
1017    }
1018    output.into_iter().collect()
1019}