strum_macros/macros/strings/from_string.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields};
4
5use crate::helpers::{
6    non_enum_error, occurrence_error, HasInnerVariantProperties, HasStrumVariantProperties,
7    HasTypeProperties,
8};
9
10pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
11    let name = &ast.ident;
12    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
13    let variants = match &ast.data {
14        Data::Enum(v) => &v.variants,
15        _ => return Err(non_enum_error()),
16    };
17
18    let type_properties = ast.get_type_properties()?;
19    let strum_module_path = type_properties.crate_module_path();
20
21    let mut default_kw = None;
22    let mut default =
23        quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };
24
25    let mut phf_exact_match_arms = Vec::new();
26    let mut standard_match_arms = Vec::new();
27    for variant in variants {
28        let ident = &variant.ident;
29        let variant_properties = variant.get_variant_properties()?;
30
31        if variant_properties.disabled.is_some() {
32            continue;
33        }
34
35        if let Some(kw) = variant_properties.default {
36            if let Some(fst_kw) = default_kw {
37                return Err(occurrence_error(fst_kw, kw, "default"));
38            }
39
40            match &variant.fields {
41                Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
42                _ => {
43                    return Err(syn::Error::new_spanned(
44                        variant,
45                        "Default only works on newtype structs with a single String field",
46                    ))
47                }
48            }
49            default_kw = Some(kw);
50            default = quote! {
51                ::core::result::Result::Ok(#name::#ident(s.into()))
52            };
53            continue;
54        }
55
56        let params = match &variant.fields {
57            Fields::Unit => quote! {},
58            Fields::Unnamed(fields) => {
59                if let Some(ref value) = variant_properties.default_with {
60                    let func = proc_macro2::Ident::new(&value.value(), value.span());
61                    let defaults = vec![quote! { #func() }];
62                    quote! { (#(#defaults),*) }
63                } else {
64                    let defaults =
65                        ::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
66                    quote! { (#(#defaults),*) }
67                }
68            }
69            Fields::Named(fields) => {
70                let mut defaults = vec![];
71                for field in &fields.named {
72                    let meta = field.get_variant_inner_properties()?;
73                    let field = field.ident.as_ref().unwrap();
74
75                    if let Some(default_with) = meta.default_with {
76                        let func =
77                            proc_macro2::Ident::new(&default_with.value(), default_with.span());
78                        defaults.push(quote! {
79                            #field: #func()
80                        });
81                    } else {
82                        defaults.push(quote! { #field: Default::default() });
83                    }
84                }
85
86                quote! { {#(#defaults),*} }
87            }
88        };
89
90        let is_ascii_case_insensitive = variant_properties
91            .ascii_case_insensitive
92            .unwrap_or(type_properties.ascii_case_insensitive);
93
94        // If we don't have any custom variants, add the default serialized name.
95        for serialization in variant_properties.get_serializations(type_properties.case_style) {
96            if type_properties.use_phf {
97                phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });
98
99                if is_ascii_case_insensitive {
100                    // Store the lowercase and UPPERCASE variants in the phf map to capture
101                    let ser_string = serialization.value();
102
103                    let lower =
104                        syn::LitStr::new(&ser_string.to_ascii_lowercase(), serialization.span());
105                    let upper =
106                        syn::LitStr::new(&ser_string.to_ascii_uppercase(), serialization.span());
107                    phf_exact_match_arms.push(quote! { #lower => #name::#ident #params, });
108                    phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
109                    standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
110                }
111            } else {
112                standard_match_arms.push(if !is_ascii_case_insensitive {
113                    quote! { #serialization => #name::#ident #params, }
114                } else {
115                    quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }
116                });
117            }
118        }
119    }
120
121    let phf_body = if phf_exact_match_arms.is_empty() {
122        quote!()
123    } else {
124        quote! {
125            use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
126            static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
127                #(#phf_exact_match_arms)*
128            };
129            if let Some(value) = PHF.get(s).cloned() {
130                return ::core::result::Result::Ok(value);
131            }
132        }
133    };
134
135    let standard_match_body = if standard_match_arms.is_empty() {
136        default
137    } else {
138        quote! {
139            ::core::result::Result::Ok(match s {
140                #(#standard_match_arms)*
141                _ => return #default,
142            })
143        }
144    };
145
146    let from_str = quote! {
147        #[allow(clippy::use_self)]
148        impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
149            type Err = #strum_module_path::ParseError;
150            fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
151                #phf_body
152                #standard_match_body
153            }
154        }
155    };
156    let try_from_str = try_from_str(
157        name,
158        &impl_generics,
159        &ty_generics,
160        where_clause,
161        &strum_module_path,
162    );
163
164    Ok(quote! {
165        #from_str
166        #try_from_str
167    })
168}
169
170#[rustversion::before(1.34)]
171fn try_from_str(
172    _name: &proc_macro2::Ident,
173    _impl_generics: &syn::ImplGenerics,
174    _ty_generics: &syn::TypeGenerics,
175    _where_clause: Option<&syn::WhereClause>,
176    _strum_module_path: &syn::Path,
177) -> TokenStream {
178    Default::default()
179}
180
181#[rustversion::since(1.34)]
182fn try_from_str(
183    name: &proc_macro2::Ident,
184    impl_generics: &syn::ImplGenerics,
185    ty_generics: &syn::TypeGenerics,
186    where_clause: Option<&syn::WhereClause>,
187    strum_module_path: &syn::Path,
188) -> TokenStream {
189    quote! {
190        #[allow(clippy::use_self)]
191        impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
192            type Error = #strum_module_path::ParseError;
193            fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
194                ::core::str::FromStr::from_str(s)
195            }
196        }
197    }
198}