alloy_sol_types/types/interface/
mod.rs

1use crate::{alloc::string::ToString, Error, Panic, Result, Revert, SolError};
2use alloc::{string::String, vec::Vec};
3use core::{convert::Infallible, fmt, iter::FusedIterator, marker::PhantomData};
4
5mod event;
6pub use event::SolEventInterface;
7
8/// A collection of ABI-encodable call-like types. This currently includes
9/// [`SolCall`] and [`SolError`].
10///
11/// This trait assumes that the implementing type always has a selector, and
12/// thus encoded/decoded data is always at least 4 bytes long.
13///
14/// This trait is implemented for [`Infallible`] to represent an empty
15/// interface. This is used by [`GenericContractError`].
16///
17/// [`SolCall`]: crate::SolCall
18/// [`SolError`]: crate::SolError
19///
20/// # Implementer's Guide
21///
22/// It should not be necessary to implement this trait manually. Instead, use
23/// the [`sol!`](crate::sol!) procedural macro to parse Solidity syntax into
24/// types that implement this trait.
25pub trait SolInterface: Sized {
26    /// The name of this type.
27    const NAME: &'static str;
28
29    /// The minimum length of the data for this type.
30    ///
31    /// This does *not* include the selector's length (4).
32    const MIN_DATA_LENGTH: usize;
33
34    /// The number of variants.
35    const COUNT: usize;
36
37    /// The selector of this instance.
38    fn selector(&self) -> [u8; 4];
39
40    /// The selector of this type at the given index, used in
41    /// [`selectors`](Self::selectors).
42    ///
43    /// This **must** return `None` if `i >= Self::COUNT`, and `Some` with a
44    /// different selector otherwise.
45    fn selector_at(i: usize) -> Option<[u8; 4]>;
46
47    /// Returns `true` if the given selector is known to this type.
48    fn valid_selector(selector: [u8; 4]) -> bool;
49
50    /// Returns an error if the given selector is not known to this type.
51    fn type_check(selector: [u8; 4]) -> Result<()> {
52        if Self::valid_selector(selector) {
53            Ok(())
54        } else {
55            Err(Error::UnknownSelector { name: Self::NAME, selector: selector.into() })
56        }
57    }
58
59    /// ABI-decodes the given data into one of the variants of `self`.
60    fn abi_decode_raw(selector: [u8; 4], data: &[u8], validate: bool) -> Result<Self>;
61
62    /// The size of the encoded data, *without* any selectors.
63    fn abi_encoded_size(&self) -> usize;
64
65    /// ABI-encodes `self` into the given buffer, *without* any selectors.
66    fn abi_encode_raw(&self, out: &mut Vec<u8>);
67
68    /// Returns an iterator over the selectors of this type.
69    #[inline]
70    fn selectors() -> Selectors<Self> {
71        Selectors::new()
72    }
73
74    /// ABI-encodes `self` into the given buffer.
75    #[inline]
76    fn abi_encode(&self) -> Vec<u8> {
77        let mut out = Vec::with_capacity(4 + self.abi_encoded_size());
78        out.extend(self.selector());
79        self.abi_encode_raw(&mut out);
80        out
81    }
82
83    /// ABI-decodes the given data into one of the variants of `self`.
84    #[inline]
85    fn abi_decode(data: &[u8], validate: bool) -> Result<Self> {
86        if data.len() < Self::MIN_DATA_LENGTH.saturating_add(4) {
87            Err(crate::Error::type_check_fail(data, Self::NAME))
88        } else {
89            let (selector, data) = data.split_first_chunk().unwrap();
90            Self::abi_decode_raw(*selector, data, validate)
91        }
92    }
93}
94
95/// An empty [`SolInterface`] implementation. Used by [`GenericContractError`].
96impl SolInterface for Infallible {
97    // better than "Infallible" since it shows up in error messages
98    const NAME: &'static str = "GenericContractError";
99
100    // no selectors or data are valid
101    const MIN_DATA_LENGTH: usize = usize::MAX;
102    const COUNT: usize = 0;
103
104    #[inline]
105    fn selector(&self) -> [u8; 4] {
106        unreachable!()
107    }
108
109    #[inline]
110    fn selector_at(_i: usize) -> Option<[u8; 4]> {
111        None
112    }
113
114    #[inline]
115    fn valid_selector(_selector: [u8; 4]) -> bool {
116        false
117    }
118
119    #[inline]
120    fn abi_decode_raw(selector: [u8; 4], _data: &[u8], _validate: bool) -> Result<Self> {
121        Self::type_check(selector).map(|()| unreachable!())
122    }
123
124    #[inline]
125    fn abi_encoded_size(&self) -> usize {
126        unreachable!()
127    }
128
129    #[inline]
130    fn abi_encode_raw(&self, _out: &mut Vec<u8>) {
131        unreachable!()
132    }
133}
134
135/// A generic contract error.
136///
137/// Contains a [`Revert`] or [`Panic`] error.
138pub type GenericContractError = ContractError<Infallible>;
139
140/// A generic contract error.
141///
142/// Contains a [`Revert`] or [`Panic`] error, or a custom error.
143///
144/// If you want an empty [`CustomError`](ContractError::CustomError) variant,
145/// use [`GenericContractError`].
146#[derive(Clone, Debug, PartialEq, Eq)]
147pub enum ContractError<T> {
148    /// A contract's custom error.
149    CustomError(T),
150    /// A generic revert. See [`Revert`] for more information.
151    Revert(Revert),
152    /// A panic. See [`Panic`] for more information.
153    Panic(Panic),
154}
155
156impl<T: SolInterface> From<T> for ContractError<T> {
157    #[inline]
158    fn from(value: T) -> Self {
159        Self::CustomError(value)
160    }
161}
162
163impl<T> From<Revert> for ContractError<T> {
164    #[inline]
165    fn from(value: Revert) -> Self {
166        Self::Revert(value)
167    }
168}
169
170impl<T> TryFrom<ContractError<T>> for Revert {
171    type Error = ContractError<T>;
172
173    #[inline]
174    fn try_from(value: ContractError<T>) -> Result<Self, Self::Error> {
175        match value {
176            ContractError::Revert(inner) => Ok(inner),
177            _ => Err(value),
178        }
179    }
180}
181
182impl<T> From<Panic> for ContractError<T> {
183    #[inline]
184    fn from(value: Panic) -> Self {
185        Self::Panic(value)
186    }
187}
188
189impl<T> TryFrom<ContractError<T>> for Panic {
190    type Error = ContractError<T>;
191
192    #[inline]
193    fn try_from(value: ContractError<T>) -> Result<Self, Self::Error> {
194        match value {
195            ContractError::Panic(inner) => Ok(inner),
196            _ => Err(value),
197        }
198    }
199}
200
201impl<T: fmt::Display> fmt::Display for ContractError<T> {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        match self {
204            Self::CustomError(error) => error.fmt(f),
205            Self::Panic(panic) => panic.fmt(f),
206            Self::Revert(revert) => revert.fmt(f),
207        }
208    }
209}
210
211impl<T: core::error::Error + 'static> core::error::Error for ContractError<T> {
212    #[inline]
213    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
214        match self {
215            Self::CustomError(error) => Some(error),
216            Self::Panic(panic) => Some(panic),
217            Self::Revert(revert) => Some(revert),
218        }
219    }
220}
221
222impl<T: SolInterface> SolInterface for ContractError<T> {
223    const NAME: &'static str = "ContractError";
224
225    // revert is 64, panic is 32
226    const MIN_DATA_LENGTH: usize = if T::MIN_DATA_LENGTH < 32 { T::MIN_DATA_LENGTH } else { 32 };
227
228    const COUNT: usize = T::COUNT + 2;
229
230    #[inline]
231    fn selector(&self) -> [u8; 4] {
232        match self {
233            Self::CustomError(error) => error.selector(),
234            Self::Panic(_) => Panic::SELECTOR,
235            Self::Revert(_) => Revert::SELECTOR,
236        }
237    }
238
239    #[inline]
240    fn selector_at(i: usize) -> Option<[u8; 4]> {
241        if i < T::COUNT {
242            T::selector_at(i)
243        } else {
244            match i - T::COUNT {
245                0 => Some(Revert::SELECTOR),
246                1 => Some(Panic::SELECTOR),
247                _ => None,
248            }
249        }
250    }
251
252    #[inline]
253    fn valid_selector(selector: [u8; 4]) -> bool {
254        match selector {
255            Revert::SELECTOR | Panic::SELECTOR => true,
256            s => T::valid_selector(s),
257        }
258    }
259
260    #[inline]
261    fn abi_decode_raw(selector: [u8; 4], data: &[u8], validate: bool) -> Result<Self> {
262        match selector {
263            Revert::SELECTOR => Revert::abi_decode_raw(data, validate).map(Self::Revert),
264            Panic::SELECTOR => Panic::abi_decode_raw(data, validate).map(Self::Panic),
265            s => T::abi_decode_raw(s, data, validate).map(Self::CustomError),
266        }
267    }
268
269    #[inline]
270    fn abi_encoded_size(&self) -> usize {
271        match self {
272            Self::CustomError(error) => error.abi_encoded_size(),
273            Self::Panic(panic) => panic.abi_encoded_size(),
274            Self::Revert(revert) => revert.abi_encoded_size(),
275        }
276    }
277
278    #[inline]
279    fn abi_encode_raw(&self, out: &mut Vec<u8>) {
280        match self {
281            Self::CustomError(error) => error.abi_encode_raw(out),
282            Self::Panic(panic) => panic.abi_encode_raw(out),
283            Self::Revert(revert) => revert.abi_encode_raw(out),
284        }
285    }
286}
287
288impl<T> ContractError<T> {
289    /// Returns `true` if `self` matches [`CustomError`](Self::CustomError).
290    #[inline]
291    pub const fn is_custom_error(&self) -> bool {
292        matches!(self, Self::CustomError(_))
293    }
294
295    /// Returns an immutable reference to the inner custom error if `self`
296    /// matches [`CustomError`](Self::CustomError).
297    #[inline]
298    pub const fn as_custom_error(&self) -> Option<&T> {
299        match self {
300            Self::CustomError(inner) => Some(inner),
301            _ => None,
302        }
303    }
304
305    /// Returns a mutable reference to the inner custom error if `self`
306    /// matches [`CustomError`](Self::CustomError).
307    #[inline]
308    pub fn as_custom_error_mut(&mut self) -> Option<&mut T> {
309        match self {
310            Self::CustomError(inner) => Some(inner),
311            _ => None,
312        }
313    }
314
315    /// Returns `true` if `self` matches [`Revert`](Self::Revert).
316    #[inline]
317    pub const fn is_revert(&self) -> bool {
318        matches!(self, Self::Revert(_))
319    }
320
321    /// Returns an immutable reference to the inner [`Revert`] if `self` matches
322    /// [`Revert`](Self::Revert).
323    #[inline]
324    pub const fn as_revert(&self) -> Option<&Revert> {
325        match self {
326            Self::Revert(inner) => Some(inner),
327            _ => None,
328        }
329    }
330
331    /// Returns a mutable reference to the inner [`Revert`] if `self` matches
332    /// [`Revert`](Self::Revert).
333    #[inline]
334    pub fn as_revert_mut(&mut self) -> Option<&mut Revert> {
335        match self {
336            Self::Revert(inner) => Some(inner),
337            _ => None,
338        }
339    }
340
341    /// Returns `true` if `self` matches [`Panic`](Self::Panic).
342    #[inline]
343    pub const fn is_panic(&self) -> bool {
344        matches!(self, Self::Panic(_))
345    }
346
347    /// Returns an immutable reference to the inner [`Panic`] if `self` matches
348    /// [`Panic`](Self::Panic).
349    #[inline]
350    pub const fn as_panic(&self) -> Option<&Panic> {
351        match self {
352            Self::Panic(inner) => Some(inner),
353            _ => None,
354        }
355    }
356
357    /// Returns a mutable reference to the inner [`Panic`] if `self` matches
358    /// [`Panic`](Self::Panic).
359    #[inline]
360    pub fn as_panic_mut(&mut self) -> Option<&mut Panic> {
361        match self {
362            Self::Panic(inner) => Some(inner),
363            _ => None,
364        }
365    }
366}
367
368/// Represents the reason for a revert in a generic contract error.
369pub type GenericRevertReason = RevertReason<Infallible>;
370
371/// Represents the reason for a revert in a smart contract.
372///
373/// This enum captures two possible scenarios for a revert:
374///
375/// - [`ContractError`](RevertReason::ContractError): Contains detailed error information, such as a
376///   specific [`Revert`] or [`Panic`] error.
377///
378/// - [`RawString`](RevertReason::RawString): Represents a raw string message as the reason for the
379///   revert.
380#[derive(Clone, Debug, PartialEq, Eq)]
381pub enum RevertReason<T> {
382    /// A detailed contract error, including a specific revert or panic error.
383    ContractError(ContractError<T>),
384    /// Represents a raw string message as the reason for the revert.
385    RawString(String),
386}
387
388impl<T: fmt::Display> fmt::Display for RevertReason<T> {
389    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390        match self {
391            Self::ContractError(error) => error.fmt(f),
392            Self::RawString(raw_string) => f.write_str(raw_string),
393        }
394    }
395}
396
397/// Converts a `ContractError<T>` into a `RevertReason<T>`.
398impl<T> From<ContractError<T>> for RevertReason<T> {
399    fn from(error: ContractError<T>) -> Self {
400        Self::ContractError(error)
401    }
402}
403
404/// Converts a `Revert` into a `RevertReason<T>`.
405impl<T> From<Revert> for RevertReason<T> {
406    fn from(revert: Revert) -> Self {
407        Self::ContractError(ContractError::Revert(revert))
408    }
409}
410
411/// Converts a `String` into a `RevertReason<T>`.
412impl<T> From<String> for RevertReason<T> {
413    fn from(raw_string: String) -> Self {
414        Self::RawString(raw_string)
415    }
416}
417
418impl<T: SolInterface> RevertReason<T>
419where
420    Self: From<ContractError<Infallible>>,
421{
422    /// Decodes and retrieves the reason for a revert from the provided output data.
423    ///
424    /// This method attempts to decode the provided output data as a generic contract error
425    /// or a UTF-8 string (for Vyper reverts).
426    ///
427    /// If successful, it returns the decoded revert reason wrapped in an `Option`.
428    ///
429    /// If both attempts fail, it returns `None`.
430    pub fn decode(out: &[u8]) -> Option<Self> {
431        // Try to decode as a generic contract error.
432        if let Ok(error) = ContractError::<T>::abi_decode(out, false) {
433            return Some(error.into());
434        }
435
436        // If that fails, try to decode as a regular string.
437        if let Ok(decoded_string) = core::str::from_utf8(out) {
438            return Some(decoded_string.to_string().into());
439        }
440
441        // If both attempts fail, return None.
442        None
443    }
444}
445
446impl<T: SolInterface + fmt::Display> RevertReason<T> {
447    /// Returns the reason for a revert as a string.
448    #[allow(clippy::inherent_to_string_shadow_display)]
449    pub fn to_string(&self) -> String {
450        match self {
451            Self::ContractError(error) => error.to_string(),
452            Self::RawString(raw_string) => raw_string.clone(),
453        }
454    }
455}
456
457impl<T> RevertReason<T> {
458    /// Returns the raw string error message if this type is a [`RevertReason::RawString`]
459    pub fn as_raw_error(&self) -> Option<&str> {
460        match self {
461            Self::RawString(error) => Some(error.as_str()),
462            _ => None,
463        }
464    }
465
466    /// Returns the [`ContractError`] if this type is a [`RevertReason::ContractError`]
467    pub const fn as_contract_error(&self) -> Option<&ContractError<T>> {
468        match self {
469            Self::ContractError(error) => Some(error),
470            _ => None,
471        }
472    }
473
474    /// Returns `true` if `self` matches [`Revert`](ContractError::Revert).
475    pub const fn is_revert(&self) -> bool {
476        matches!(self, Self::ContractError(ContractError::Revert(_)))
477    }
478
479    /// Returns `true` if `self` matches [`Panic`](ContractError::Panic).
480    pub const fn is_panic(&self) -> bool {
481        matches!(self, Self::ContractError(ContractError::Panic(_)))
482    }
483
484    /// Returns `true` if `self` matches [`CustomError`](ContractError::CustomError).
485    pub const fn is_custom_error(&self) -> bool {
486        matches!(self, Self::ContractError(ContractError::CustomError(_)))
487    }
488}
489
/// Iterator over the function or error selectors of a [`SolInterface`] type.
///
/// This `struct` is created by the [`selectors`] method on [`SolInterface`].
/// See its documentation for more.
///
/// Iteration starts at index 0 and yields [`SolInterface::selector_at`] for each
/// successive index until it returns `None`.
///
/// [`selectors`]: SolInterface::selectors
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Selectors<T> {
    /// Index of the next selector to yield.
    index: usize,
    _marker: PhantomData<T>,
}
500
501impl<T> Clone for Selectors<T> {
502    fn clone(&self) -> Self {
503        Self { index: self.index, _marker: PhantomData }
504    }
505}
506
507impl<T> fmt::Debug for Selectors<T> {
508    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
509        f.debug_struct("Selectors").field("index", &self.index).finish()
510    }
511}
512
513impl<T> Selectors<T> {
514    #[inline]
515    const fn new() -> Self {
516        Self { index: 0, _marker: PhantomData }
517    }
518}
519
520impl<T: SolInterface> Iterator for Selectors<T> {
521    type Item = [u8; 4];
522
523    #[inline]
524    fn next(&mut self) -> Option<Self::Item> {
525        let selector = T::selector_at(self.index)?;
526        self.index += 1;
527        Some(selector)
528    }
529
530    #[inline]
531    fn size_hint(&self) -> (usize, Option<usize>) {
532        let exact = self.len();
533        (exact, Some(exact))
534    }
535
536    #[inline]
537    fn count(self) -> usize {
538        self.len()
539    }
540}
541
542impl<T: SolInterface> ExactSizeIterator for Selectors<T> {
543    #[inline]
544    fn len(&self) -> usize {
545        T::COUNT - self.index
546    }
547}
548
549impl<T: SolInterface> FusedIterator for Selectors<T> {}
550
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, U256};

    fn sel(s: &str) -> [u8; 4] {
        keccak256(s)[..4].try_into().unwrap()
    }

    #[test]
    fn generic_contract_error_enum() {
        assert_eq!(
            GenericContractError::selectors().collect::<Vec<_>>(),
            [sel("Error(string)"), sel("Panic(uint256)")]
        );
    }

    #[test]
    fn generic_contract_error_builtins_roundtrip() {
        // `Error(string)` round-trips through the generic interface.
        let revert = Revert { reason: "boom".into() };
        let data = revert.abi_encode();
        assert_eq!(data[..4], Revert::SELECTOR);
        assert_eq!(GenericContractError::Revert(revert.clone()).abi_encode(), data);
        assert_eq!(
            GenericContractError::abi_decode(&data, true),
            Ok(ContractError::Revert(revert))
        );

        // `Panic(uint256)` round-trips through the generic interface.
        let panic = Panic { code: U256::from(0x11) };
        let data = panic.abi_encode();
        assert_eq!(data[..4], Panic::SELECTOR);
        assert_eq!(GenericContractError::Panic(panic.clone()).abi_encode(), data);
        assert_eq!(
            GenericContractError::abi_decode(&data, true),
            Ok(ContractError::Panic(panic))
        );

        // Any other selector is unknown to the empty custom interface.
        let unknown = [0xffu8; 4 + 32];
        assert!(GenericContractError::abi_decode(&unknown, true).is_err());
        assert_eq!(GenericContractError::selectors().len(), GenericContractError::COUNT);
    }

    #[test]
    fn contract_error_enum_1() {
        crate::sol! {
            contract C {
                error Err1();
            }
        }

        assert_eq!(C::CErrors::COUNT, 1);
        assert_eq!(C::CErrors::MIN_DATA_LENGTH, 0);
        assert_eq!(ContractError::<C::CErrors>::COUNT, 1 + 2);
        assert_eq!(ContractError::<C::CErrors>::MIN_DATA_LENGTH, 0);

        assert_eq!(C::CErrors::SELECTORS, [sel("Err1()")]);
        assert_eq!(
            ContractError::<C::CErrors>::selectors().collect::<Vec<_>>(),
            vec![sel("Err1()"), sel("Error(string)"), sel("Panic(uint256)")],
        );

        for selector in C::CErrors::selectors() {
            assert!(C::CErrors::valid_selector(selector));
        }

        for selector in ContractError::<C::CErrors>::selectors() {
            assert!(ContractError::<C::CErrors>::valid_selector(selector));
        }
    }

    #[test]
    fn contract_error_enum_2() {
        crate::sol! {
            #[derive(Debug, PartialEq, Eq)]
            contract C {
                error Err1();
                error Err2(uint256);
                error Err3(string);
            }
        }

        assert_eq!(C::CErrors::COUNT, 3);
        assert_eq!(C::CErrors::MIN_DATA_LENGTH, 0);
        assert_eq!(ContractError::<C::CErrors>::COUNT, 2 + 3);
        assert_eq!(ContractError::<C::CErrors>::MIN_DATA_LENGTH, 0);

        // sorted by selector
        assert_eq!(
            C::CErrors::SELECTORS,
            [sel("Err3(string)"), sel("Err2(uint256)"), sel("Err1()")]
        );
        assert_eq!(
            ContractError::<C::CErrors>::selectors().collect::<Vec<_>>(),
            [
                sel("Err3(string)"),
                sel("Err2(uint256)"),
                sel("Err1()"),
                sel("Error(string)"),
                sel("Panic(uint256)"),
            ],
        );

        let err1 = || C::Err1 {};
        let errors_err1 = || C::CErrors::Err1(err1());
        let contract_error_err1 = || ContractError::<C::CErrors>::CustomError(errors_err1());
        let data = err1().abi_encode();
        assert_eq!(data[..4], C::Err1::SELECTOR);
        assert_eq!(errors_err1().abi_encode(), data);
        assert_eq!(contract_error_err1().abi_encode(), data);

        assert_eq!(C::Err1::abi_decode(&data, true), Ok(err1()));
        assert_eq!(C::CErrors::abi_decode(&data, true), Ok(errors_err1()));
        assert_eq!(ContractError::<C::CErrors>::abi_decode(&data, true), Ok(contract_error_err1()));

        let err2 = || C::Err2 { _0: U256::from(42) };
        let errors_err2 = || C::CErrors::Err2(err2());
        let contract_error_err2 = || ContractError::<C::CErrors>::CustomError(errors_err2());
        let data = err2().abi_encode();
        assert_eq!(data[..4], C::Err2::SELECTOR);
        assert_eq!(errors_err2().abi_encode(), data);
        assert_eq!(contract_error_err2().abi_encode(), data);

        assert_eq!(C::Err2::abi_decode(&data, true), Ok(err2()));
        assert_eq!(C::CErrors::abi_decode(&data, true), Ok(errors_err2()));
        assert_eq!(ContractError::<C::CErrors>::abi_decode(&data, true), Ok(contract_error_err2()));

        let err3 = || C::Err3 { _0: "hello".into() };
        let errors_err3 = || C::CErrors::Err3(err3());
        let contract_error_err3 = || ContractError::<C::CErrors>::CustomError(errors_err3());
        let data = err3().abi_encode();
        assert_eq!(data[..4], C::Err3::SELECTOR);
        assert_eq!(errors_err3().abi_encode(), data);
        assert_eq!(contract_error_err3().abi_encode(), data);

        assert_eq!(C::Err3::abi_decode(&data, true), Ok(err3()));
        assert_eq!(C::CErrors::abi_decode(&data, true), Ok(errors_err3()));
        assert_eq!(ContractError::<C::CErrors>::abi_decode(&data, true), Ok(contract_error_err3()));

        for selector in C::CErrors::selectors() {
            assert!(C::CErrors::valid_selector(selector));
        }

        for selector in ContractError::<C::CErrors>::selectors() {
            assert!(ContractError::<C::CErrors>::valid_selector(selector));
        }
    }
}