yoke_derive/
lib.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5//! Custom derives for `Yokeable` from the `yoke` crate.
6
7use proc_macro::TokenStream;
8use proc_macro2::{Span, TokenStream as TokenStream2};
9use quote::quote;
10use syn::spanned::Spanned;
11use syn::{parse_macro_input, parse_quote, DeriveInput, Ident, Lifetime, Type, WherePredicate};
12use synstructure::Structure;
13
14mod visitor;
15
16/// Custom derive for `yoke::Yokeable`,
17///
18/// If your struct contains `zerovec::ZeroMap`, then the compiler will not
19/// be able to guarantee the lifetime covariance due to the generic types on
20/// the `ZeroMap` itself. You must add the following attribute in order for
21/// the custom derive to work with `ZeroMap`.
22///
23/// ```rust,ignore
24/// #[derive(Yokeable)]
25/// #[yoke(prove_covariance_manually)]
26/// ```
27///
28/// Beyond this case, if the derive fails to compile due to lifetime issues, it
29/// means that the lifetime is not covariant and `Yokeable` is not safe to implement.
30#[proc_macro_derive(Yokeable, attributes(yoke))]
31pub fn yokeable_derive(input: TokenStream) -> TokenStream {
32    let input = parse_macro_input!(input as DeriveInput);
33    TokenStream::from(yokeable_derive_impl(&input))
34}
35
36fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 {
37    let tybounds = input
38        .generics
39        .type_params()
40        .map(|ty| {
41            // Strip out param defaults, we don't need them in the impl
42            let mut ty = ty.clone();
43            ty.eq_token = None;
44            ty.default = None;
45            ty
46        })
47        .collect::<Vec<_>>();
48    let typarams = tybounds
49        .iter()
50        .map(|ty| ty.ident.clone())
51        .collect::<Vec<_>>();
52    // We require all type parameters be 'static, otherwise
53    // the Yokeable impl becomes really unweildy to generate safely
54    let static_bounds: Vec<WherePredicate> = typarams
55        .iter()
56        .map(|ty| parse_quote!(#ty: 'static))
57        .collect();
58    let lts = input.generics.lifetimes().count();
59    if lts == 0 {
60        let name = &input.ident;
61        quote! {
62            // This is safe because there are no lifetime parameters.
63            unsafe impl<'a, #(#tybounds),*> yoke::Yokeable<'a> for #name<#(#typarams),*> where #(#static_bounds,)* Self: Sized {
64                type Output = Self;
65                #[inline]
66                fn transform(&self) -> &Self::Output {
67                    self
68                }
69                #[inline]
70                fn transform_owned(self) -> Self::Output {
71                    self
72                }
73                #[inline]
74                unsafe fn make(this: Self::Output) -> Self {
75                    this
76                }
77                #[inline]
78                fn transform_mut<F>(&'a mut self, f: F)
79                where
80                    F: 'static + for<'b> FnOnce(&'b mut Self::Output) {
81                    f(self)
82                }
83            }
84        }
85    } else {
86        if lts != 1 {
87            return syn::Error::new(
88                input.generics.span(),
89                "derive(Yokeable) cannot have multiple lifetime parameters",
90            )
91            .to_compile_error();
92        }
93        let name = &input.ident;
94        let manual_covariance = input.attrs.iter().any(|a| {
95            if let Ok(i) = a.parse_args::<Ident>() {
96                if i == "prove_covariance_manually" {
97                    return true;
98                }
99            }
100            false
101        });
102        if manual_covariance {
103            let mut structure = Structure::new(input);
104            let generics_env = typarams.iter().cloned().collect();
105            let static_bounds: Vec<WherePredicate> = typarams
106                .iter()
107                .map(|ty| parse_quote!(#ty: 'static))
108                .collect();
109            let mut yoke_bounds: Vec<WherePredicate> = vec![];
110            structure.bind_with(|_| synstructure::BindStyle::Move);
111            let owned_body = structure.each_variant(|vi| {
112                vi.construct(|f, i| {
113                    let binding = format!("__binding_{i}");
114                    let field = Ident::new(&binding, Span::call_site());
115                    let fty_static = replace_lifetime(&f.ty, static_lt());
116
117                    let (has_ty, has_lt) = visitor::check_type_for_parameters(&f.ty, &generics_env);
118                    if has_ty {
119                        // For types without type parameters, the compiler can figure out that the field implements
120                        // Yokeable on its own. However, if there are type parameters, there may be complex preconditions
121                        // to `FieldTy: Yokeable` that need to be satisfied. We get them to be satisfied by requiring
122                        // `FieldTy<'static>: Yokeable<FieldTy<'a>>`
123                        if has_lt {
124                            let fty_a = replace_lifetime(&f.ty, custom_lt("'a"));
125                            yoke_bounds.push(
126                                parse_quote!(#fty_static: yoke::Yokeable<'a, Output = #fty_a>),
127                            );
128                        } else {
129                            yoke_bounds.push(
130                                parse_quote!(#fty_static: yoke::Yokeable<'a, Output = #fty_static>),
131                            );
132                        }
133                    }
134                    if has_ty || has_lt {
135                        // By calling transform_owned on all fields, we manually prove
136                        // that the lifetimes are covariant, since this requirement
137                        // must already be true for the type that implements transform_owned().
138                        quote! {
139                            <#fty_static as yoke::Yokeable<'a>>::transform_owned(#field)
140                        }
141                    } else {
142                        // No nested lifetimes, so nothing to be done
143                        quote! { #field }
144                    }
145                })
146            });
147            let borrowed_body = structure.each(|binding| {
148                let f = binding.ast();
149                let field = &binding.binding;
150
151                let (has_ty, has_lt) = visitor::check_type_for_parameters(&f.ty, &generics_env);
152
153                if has_ty || has_lt {
154                    let fty_static = replace_lifetime(&f.ty, static_lt());
155                    let fty_a = replace_lifetime(&f.ty, custom_lt("'a"));
156                    // We also must assert that each individual field can `transform()` correctly
157                    //
158                    // Even though transform_owned() does such an assertion already, CoerceUnsized
159                    // can cause type transformations that allow it to succeed where this would fail.
160                    // We need to check both.
161                    //
162                    // https://github.com/unicode-org/icu4x/issues/2928
163                    quote! {
164                        let _: &#fty_a = &<#fty_static as yoke::Yokeable<'a>>::transform(#field);
165                    }
166                } else {
167                    // No nested lifetimes, so nothing to be done
168                    quote! {}
169                }
170            });
171            return quote! {
172                unsafe impl<'a, #(#tybounds),*> yoke::Yokeable<'a> for #name<'static, #(#typarams),*>
173                    where #(#static_bounds,)*
174                    #(#yoke_bounds,)* {
175                    type Output = #name<'a, #(#typarams),*>;
176                    #[inline]
177                    fn transform(&'a self) -> &'a Self::Output {
178                        // These are just type asserts, we don't need them for anything
179                        if false {
180                            match self {
181                                #borrowed_body
182                            }
183                        }
184                        unsafe {
185                            // safety: we have asserted covariance in
186                            // transform_owned
187                            ::core::mem::transmute(self)
188                        }
189                    }
190                    #[inline]
191                    fn transform_owned(self) -> Self::Output {
192                        match self { #owned_body }
193                    }
194                    #[inline]
195                    unsafe fn make(this: Self::Output) -> Self {
196                        use core::{mem, ptr};
197                        // unfortunately Rust doesn't think `mem::transmute` is possible since it's not sure the sizes
198                        // are the same
199                        debug_assert!(mem::size_of::<Self::Output>() == mem::size_of::<Self>());
200                        let ptr: *const Self = (&this as *const Self::Output).cast();
201                        #[allow(forgetting_copy_types, clippy::forget_copy, clippy::forget_non_drop)] // This is a noop if the struct is copy, which Clippy doesn't like
202                        mem::forget(this);
203                        ptr::read(ptr)
204                    }
205                    #[inline]
206                    fn transform_mut<F>(&'a mut self, f: F)
207                    where
208                        F: 'static + for<'b> FnOnce(&'b mut Self::Output) {
209                        unsafe { f(core::mem::transmute::<&'a mut Self, &'a mut Self::Output>(self)) }
210                    }
211                }
212            };
213        }
214        quote! {
215            // This is safe because as long as `transform()` compiles,
216            // we can be sure that `'a` is a covariant lifetime on `Self`
217            //
218            // This will not work for structs involving ZeroMap since
219            // the compiler does not know that ZeroMap is covariant.
220            //
221            // This custom derive can be improved to handle this case when
222            // necessary
223            unsafe impl<'a, #(#tybounds),*> yoke::Yokeable<'a> for #name<'static, #(#typarams),*> where #(#static_bounds,)* {
224                type Output = #name<'a, #(#typarams),*>;
225                #[inline]
226                fn transform(&'a self) -> &'a Self::Output {
227                    self
228                }
229                #[inline]
230                fn transform_owned(self) -> Self::Output {
231                    self
232                }
233                #[inline]
234                unsafe fn make(this: Self::Output) -> Self {
235                    use core::{mem, ptr};
236                    // unfortunately Rust doesn't think `mem::transmute` is possible since it's not sure the sizes
237                    // are the same
238                    debug_assert!(mem::size_of::<Self::Output>() == mem::size_of::<Self>());
239                    let ptr: *const Self = (&this as *const Self::Output).cast();
240                    #[allow(forgetting_copy_types, clippy::forget_copy, clippy::forget_non_drop)] // This is a noop if the struct is copy, which Clippy doesn't like
241                    mem::forget(this);
242                    ptr::read(ptr)
243                }
244                #[inline]
245                fn transform_mut<F>(&'a mut self, f: F)
246                where
247                    F: 'static + for<'b> FnOnce(&'b mut Self::Output) {
248                    unsafe { f(core::mem::transmute::<&'a mut Self, &'a mut Self::Output>(self)) }
249                }
250            }
251        }
252    }
253}
254
255fn static_lt() -> Lifetime {
256    Lifetime::new("'static", Span::call_site())
257}
258
259fn custom_lt(s: &str) -> Lifetime {
260    Lifetime::new(s, Span::call_site())
261}
262
263fn replace_lifetime(x: &Type, lt: Lifetime) -> Type {
264    use syn::fold::Fold;
265    struct ReplaceLifetime(Lifetime);
266
267    impl Fold for ReplaceLifetime {
268        fn fold_lifetime(&mut self, _: Lifetime) -> Lifetime {
269            self.0.clone()
270        }
271    }
272    ReplaceLifetime(lt).fold_type(x.clone())
273}