bon_macros/normalization/
impl_traits.rs

1use super::GenericsNamespace;
2use crate::util::prelude::*;
3use syn::visit_mut::VisitMut;
4
5pub(crate) struct NormalizeImplTraits<'a> {
6    namespace: &'a GenericsNamespace,
7}
8
9impl<'a> NormalizeImplTraits<'a> {
10    pub(crate) fn new(namespace: &'a GenericsNamespace) -> Self {
11        Self { namespace }
12    }
13}
14
15impl VisitMut for NormalizeImplTraits<'_> {
16    fn visit_impl_item_fn_mut(&mut self, fn_item: &mut syn::ImplItemFn) {
17        // We are interested only in signatures of functions. Don't recurse
18        // into the function's block.
19        self.visit_signature_mut(&mut fn_item.sig);
20    }
21
22    fn visit_signature_mut(&mut self, signature: &mut syn::Signature) {
23        let mut visitor = AssignTypeParams::new(self, &mut signature.generics);
24
25        for arg in &mut signature.inputs {
26            visitor.visit_fn_arg_mut(arg);
27        }
28    }
29}
30
31struct AssignTypeParams<'a> {
32    base: &'a NormalizeImplTraits<'a>,
33    generics: &'a mut syn::Generics,
34    next_type_param_index: usize,
35}
36
37impl<'a> AssignTypeParams<'a> {
38    fn new(base: &'a NormalizeImplTraits<'a>, generics: &'a mut syn::Generics) -> Self {
39        Self {
40            base,
41            generics,
42            next_type_param_index: 1,
43        }
44    }
45}
46
47impl VisitMut for AssignTypeParams<'_> {
48    fn visit_item_mut(&mut self, _item: &mut syn::Item) {
49        // Don't recurse into nested items because `impl Trait` isn't available there.
50    }
51
52    fn visit_signature_mut(&mut self, signature: &mut syn::Signature) {
53        for arg in &mut signature.inputs {
54            self.visit_type_mut(arg.ty_mut());
55        }
56    }
57
58    fn visit_type_mut(&mut self, ty: &mut syn::Type) {
59        syn::visit_mut::visit_type_mut(self, ty);
60
61        if !matches!(ty, syn::Type::ImplTrait(_)) {
62            return;
63        };
64
65        let index = self.next_type_param_index;
66        self.next_type_param_index += 1;
67
68        let type_param = self.base.namespace.unique_ident(format!("I{index}"));
69
70        let impl_trait = std::mem::replace(ty, syn::Type::Path(syn::parse_quote!(#type_param)));
71
72        let impl_trait = match impl_trait {
73            syn::Type::ImplTrait(impl_trait) => impl_trait,
74            _ => {
75                unreachable!("BUG: code higher validated that this is impl trait: {impl_trait:?}");
76            }
77        };
78
79        self.generics
80            .params
81            .push(syn::GenericParam::Type(syn::parse_quote!(#type_param)));
82
83        let bounds = impl_trait.bounds;
84
85        self.generics
86            .make_where_clause()
87            .predicates
88            .push(syn::parse_quote!(#type_param: #bounds));
89    }
90}