strum_macros/macros/
enum_table.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{spanned::Spanned, Data, DeriveInput, Fields};
4
5use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties};
6
7pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8    let name = &ast.ident;
9    let gen = &ast.generics;
10    let vis = &ast.vis;
11    let mut doc_comment = format!("A map over the variants of `{}`", name);
12
13    if gen.lifetimes().count() > 0 {
14        return Err(syn::Error::new(
15            Span::call_site(),
16            "`EnumTable` doesn't support enums with lifetimes.",
17        ));
18    }
19
20    let variants = match &ast.data {
21        Data::Enum(v) => &v.variants,
22        _ => return Err(non_enum_error()),
23    };
24
25    let table_name = format_ident!("{}Table", name);
26
27    // the identifiers of each variant, in PascalCase
28    let mut pascal_idents = Vec::new();
29    // the identifiers of each struct field, in snake_case
30    let mut snake_idents = Vec::new();
31    // match arms in the form `MyEnumTable::Variant => &self.variant,`
32    let mut get_matches = Vec::new();
33    // match arms in the form `MyEnumTable::Variant => &mut self.variant,`
34    let mut get_matches_mut = Vec::new();
35    // match arms in the form `MyEnumTable::Variant => self.variant = new_value`
36    let mut set_matches = Vec::new();
37    // struct fields of the form `variant: func(MyEnum::Variant),*
38    let mut closure_fields = Vec::new();
39    // struct fields of the form `variant: func(MyEnum::Variant, self.variant),`
40    let mut transform_fields = Vec::new();
41
42    // identifiers for disabled variants
43    let mut disabled_variants = Vec::new();
44    // match arms for disabled variants
45    let mut disabled_matches = Vec::new();
46
47    for variant in variants {
48        // skip disabled variants
49        if variant.get_variant_properties()?.disabled.is_some() {
50            let disabled_ident = &variant.ident;
51            let panic_message = format!(
52                "Can't use `{}` with `{}` - variant is disabled for Strum features",
53                disabled_ident, table_name
54            );
55            disabled_variants.push(disabled_ident);
56            disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),));
57            continue;
58        }
59
60        // Error on variants with data
61        if variant.fields != Fields::Unit {
62            return Err(syn::Error::new(
63                variant.fields.span(),
64                "`EnumTable` doesn't support enums with non-unit variants",
65            ));
66        };
67
68        let pascal_case = &variant.ident;
69        let snake_case = format_ident!("_{}", snakify(&pascal_case.to_string()));
70
71        get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,});
72        get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,});
73        set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,});
74        closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),});
75        transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),});
76        pascal_idents.push(pascal_case);
77        snake_idents.push(snake_case);
78    }
79
80    // Error on empty enums
81    if pascal_idents.is_empty() {
82        return Err(syn::Error::new(
83            variants.span(),
84            "`EnumTable` requires at least one non-disabled variant",
85        ));
86    }
87
88    // if the index operation can panic, add that to the documentation
89    if !disabled_variants.is_empty() {
90        doc_comment.push_str(&format!(
91            "\n# Panics\nIndexing `{}` with any of the following variants will cause a panic:",
92            table_name
93        ));
94        for variant in disabled_variants {
95            doc_comment.push_str(&format!("\n\n- `{}::{}`", name, variant));
96        }
97    }
98
99    let doc_new = format!(
100        "Create a new {} with a value for each variant of {}",
101        table_name, name
102    );
103    let doc_closure = format!(
104        "Create a new {} by running a function on each variant of `{}`",
105        table_name, name
106    );
107    let doc_transform = format!("Create a new `{}` by running a function on each variant of `{}` and the corresponding value in the current `{0}`", table_name, name);
108    let doc_filled = format!(
109        "Create a new `{}` with the same value in each field.",
110        table_name
111    );
112    let doc_option_all = format!("Converts `{}<Option<T>>` into `Option<{0}<T>>`. Returns `Some` if all fields are `Some`, otherwise returns `None`.", table_name);
113    let doc_result_all_ok = format!("Converts `{}<Result<T, E>>` into `Result<{0}<T>, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`.", table_name);
114
115    Ok(quote! {
116        #[doc = #doc_comment]
117        #[allow(
118            missing_copy_implementations,
119        )]
120        #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
121        #vis struct #table_name<T> {
122            #(#snake_idents: T,)*
123        }
124
125        impl<T: Clone> #table_name<T> {
126            #[doc = #doc_filled]
127            #vis fn filled(value: T) -> #table_name<T> {
128                #table_name {
129                    #(#snake_idents: value.clone(),)*
130                }
131            }
132        }
133
134        impl<T> #table_name<T> {
135            #[doc = #doc_new]
136            #vis fn new(
137                #(#snake_idents: T,)*
138            ) -> #table_name<T> {
139                #table_name {
140                    #(#snake_idents,)*
141                }
142            }
143
144            #[doc = #doc_closure]
145            #vis fn from_closure<F: Fn(#name)->T>(func: F) -> #table_name<T> {
146              #table_name {
147                #(#closure_fields)*
148              }
149            }
150
151            #[doc = #doc_transform]
152            #vis fn transform<U, F: Fn(#name, &T)->U>(&self, func: F) -> #table_name<U> {
153              #table_name {
154                #(#transform_fields)*
155              }
156            }
157
158        }
159
160        impl<T> ::core::ops::Index<#name> for #table_name<T> {
161            type Output = T;
162
163            fn index(&self, idx: #name) -> &T {
164                match idx {
165                    #(#get_matches)*
166                    #(#disabled_matches)*
167                }
168            }
169        }
170
171        impl<T> ::core::ops::IndexMut<#name> for #table_name<T> {
172            fn index_mut(&mut self, idx: #name) -> &mut T {
173                match idx {
174                    #(#get_matches_mut)*
175                    #(#disabled_matches)*
176                }
177            }
178        }
179
180        impl<T> #table_name<::core::option::Option<T>> {
181            #[doc = #doc_option_all]
182            #vis fn all(self) -> ::core::option::Option<#table_name<T>> {
183                if let #table_name {
184                    #(#snake_idents: ::core::option::Option::Some(#snake_idents),)*
185                } = self {
186                    ::core::option::Option::Some(#table_name {
187                        #(#snake_idents,)*
188                    })
189                } else {
190                    ::core::option::Option::None
191                }
192            }
193        }
194
195        impl<T, E> #table_name<::core::result::Result<T, E>> {
196            #[doc = #doc_result_all_ok]
197            #vis fn all_ok(self) -> ::core::result::Result<#table_name<T>, E> {
198                ::core::result::Result::Ok(#table_name {
199                    #(#snake_idents: self.#snake_idents?,)*
200                })
201            }
202        }
203    })
204}