strum_macros/macros/enum_iter.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Ident};
4
5use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6
7pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8    let name = &ast.ident;
9    let gen = &ast.generics;
10    let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11    let vis = &ast.vis;
12    let type_properties = ast.get_type_properties()?;
13    let strum_module_path = type_properties.crate_module_path();
14    let doc_comment = format!("An iterator over the variants of [{}]", name);
15
16    if gen.lifetimes().count() > 0 {
17        return Err(syn::Error::new(
18            Span::call_site(),
19            "This macro doesn't support enums with lifetimes. \
20             The resulting enums would be unbounded.",
21        ));
22    }
23
24    let phantom_data = if gen.type_params().count() > 0 {
25        let g = gen.type_params().map(|param| &param.ident);
26        quote! { < ( #(#g),* ) > }
27    } else {
28        quote! { < () > }
29    };
30
31    let variants = match &ast.data {
32        Data::Enum(v) => &v.variants,
33        _ => return Err(non_enum_error()),
34    };
35
36    let mut arms = Vec::new();
37    let mut idx = 0usize;
38    for variant in variants {
39        if variant.get_variant_properties()?.disabled.is_some() {
40            continue;
41        }
42
43        let ident = &variant.ident;
44        let params = match &variant.fields {
45            Fields::Unit => quote! {},
46            Fields::Unnamed(fields) => {
47                let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
48                    .take(fields.unnamed.len());
49                quote! { (#(#defaults),*) }
50            }
51            Fields::Named(fields) => {
52                let fields = fields
53                    .named
54                    .iter()
55                    .map(|field| field.ident.as_ref().unwrap());
56                quote! { {#(#fields: ::core::default::Default::default()),*} }
57            }
58        };
59
60        arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
61        idx += 1;
62    }
63
64    let variant_count = arms.len();
65    arms.push(quote! { _ => ::core::option::Option::None });
66    let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
67
68    // Create a string literal "MyEnumIter" to use in the debug impl.
69    let iter_name_debug_struct =
70        syn::parse_str::<syn::LitStr>(&format!("\"{}\"", iter_name)).unwrap();
71
72    Ok(quote! {
73        #[doc = #doc_comment]
74        #[allow(
75            missing_copy_implementations,
76        )]
77        #vis struct #iter_name #impl_generics {
78            idx: usize,
79            back_idx: usize,
80            marker: ::core::marker::PhantomData #phantom_data,
81        }
82
83        impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause {
84            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
85                // We don't know if the variants implement debug themselves so the only thing we
86                // can really show is how many elements are left.
87                f.debug_struct(#iter_name_debug_struct)
88                    .field("len", &self.len())
89                    .finish()
90            }
91        }
92
93        impl #impl_generics #iter_name #ty_generics #where_clause {
94            fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> {
95                match idx {
96                    #(#arms),*
97                }
98            }
99        }
100
101        impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
102            type Iterator = #iter_name #ty_generics;
103            fn iter() -> #iter_name #ty_generics {
104                #iter_name {
105                    idx: 0,
106                    back_idx: 0,
107                    marker: ::core::marker::PhantomData,
108                }
109            }
110        }
111
112        impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
113            type Item = #name #ty_generics;
114
115            fn next(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
116                self.nth(0)
117            }
118
119            fn size_hint(&self) -> (usize, ::core::option::Option<usize>) {
120                let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
121                (t, Some(t))
122            }
123
124            fn nth(&mut self, n: usize) -> ::core::option::Option<<Self as Iterator>::Item> {
125                let idx = self.idx + n + 1;
126                if idx + self.back_idx > #variant_count {
127                    // We went past the end of the iterator. Freeze idx at #variant_count
128                    // so that it doesn't overflow if the user calls this repeatedly.
129                    // See PR #76 for context.
130                    self.idx = #variant_count;
131                    ::core::option::Option::None
132                } else {
133                    self.idx = idx;
134                    #iter_name::get(self, idx - 1)
135                }
136            }
137        }
138
139        impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
140            fn len(&self) -> usize {
141                self.size_hint().0
142            }
143        }
144
145        impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
146            fn next_back(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
147                let back_idx = self.back_idx + 1;
148
149                if self.idx + back_idx > #variant_count {
150                    // We went past the end of the iterator. Freeze back_idx at #variant_count
151                    // so that it doesn't overflow if the user calls this repeatedly.
152                    // See PR #76 for context.
153                    self.back_idx = #variant_count;
154                    ::core::option::Option::None
155                } else {
156                    self.back_idx = back_idx;
157                    #iter_name::get(self, #variant_count - self.back_idx)
158                }
159            }
160        }
161
162        impl #impl_generics ::core::iter::FusedIterator for #iter_name #ty_generics #where_clause { }
163
164        impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
165            fn clone(&self) -> #iter_name #ty_generics {
166                #iter_name {
167                    idx: self.idx,
168                    back_idx: self.back_idx,
169                    marker: self.marker.clone(),
170                }
171            }
172        }
173    })
174}