// bon_macros/util/ty/match_types.rs

1use crate::util::iterator::IntoIteratorExt;
2use crate::util::prelude::*;
3use syn::spanned::Spanned;
4
5pub(crate) fn match_return_types(
6    scrutinee: &syn::ReturnType,
7    pattern: &syn::ReturnType,
8) -> Result<bool> {
9    match (scrutinee, pattern) {
10        (syn::ReturnType::Default, syn::ReturnType::Default) => Ok(true),
11        (syn::ReturnType::Default, syn::ReturnType::Type(_, pattern)) => {
12            match_types(&syn::parse_quote!(()), pattern)
13        }
14        (syn::ReturnType::Type(_, scrutinee), syn::ReturnType::Default) => {
15            Ok(**scrutinee == syn::parse_quote!(()))
16        }
17        (syn::ReturnType::Type(_, scrutinee), syn::ReturnType::Type(_, pattern)) => {
18            match_types(scrutinee, pattern)
19        }
20    }
21}
22
23fn match_paths(scrutinee: &syn::Path, pattern: &syn::Path) -> Result<bool> {
24    let verdict = scrutinee.leading_colon == pattern.leading_colon
25        && scrutinee
26            .segments
27            .iter()
28            .try_equals_with(&pattern.segments, |scrutinee, pattern| {
29                let verdict = scrutinee.ident == pattern.ident
30                    && match_path_args(&scrutinee.arguments, &pattern.arguments)?;
31
32                Ok(verdict)
33            })?;
34
35    Ok(verdict)
36}
37
38fn match_path_args(scrutinee: &syn::PathArguments, pattern: &syn::PathArguments) -> Result<bool> {
39    use syn::PathArguments::*;
40
41    let verdict = match (scrutinee, pattern) {
42        (None, None) => true,
43        (AngleBracketed(scrutinee), AngleBracketed(pattern)) => {
44            match_angle_bracketed_generic_args(scrutinee, pattern)?
45        }
46        (Parenthesized(scrutinee), Parenthesized(pattern)) => {
47            scrutinee
48                .inputs
49                .iter()
50                .try_equals_with(&pattern.inputs, match_types)?
51                && match_return_types(&scrutinee.output, &pattern.output)?
52        }
53        _ => false,
54    };
55
56    Ok(verdict)
57}
58
59fn match_angle_bracketed_generic_args(
60    scrutinee: &syn::AngleBracketedGenericArguments,
61    pattern: &syn::AngleBracketedGenericArguments,
62) -> Result<bool> {
63    scrutinee
64        .args
65        .iter()
66        .try_equals_with(&pattern.args, match_generic_args)
67}
68
69fn match_option<T>(
70    scrutinee: Option<&T>,
71    pattern: Option<&T>,
72    compare: impl Fn(&T, &T) -> Result<bool>,
73) -> Result<bool> {
74    match (scrutinee, &pattern) {
75        (None, None) => Ok(true),
76        (Some(scrutinee), Some(pattern)) => compare(scrutinee, pattern),
77        _ => Ok(false),
78    }
79}
80
81fn match_generic_args(
82    scrutinee: &syn::GenericArgument,
83    pattern: &syn::GenericArgument,
84) -> Result<bool> {
85    use syn::GenericArgument::*;
86
87    let verdict = match pattern {
88        Lifetime(pattern) => {
89            if pattern.ident != "_" {
90                return Err(unsupported_syntax_error(
91                    pattern,
92                    "Lifetimes are ignored during type pattern matching. \
93                    Use an anonymous lifetime (`'_`) in this position instead. \
94                    Named lifetime syntax in generic parameters",
95                ));
96            }
97
98            matches!(scrutinee, Lifetime(_))
99        }
100        Type(pattern) => {
101            let scrutinee = match scrutinee {
102                Type(scrutinee) => scrutinee,
103                _ => return Ok(false),
104            };
105            match_types(scrutinee, pattern)?
106        }
107        Const(pattern) => {
108            let scrutinee = match scrutinee {
109                Const(scrutinee) => scrutinee,
110                _ => return Ok(false),
111            };
112            match_exprs(scrutinee, pattern)
113        }
114        AssocType(pattern) => {
115            let scrutinee = match scrutinee {
116                AssocType(scrutinee) => scrutinee,
117                _ => return Ok(false),
118            };
119            scrutinee.ident == pattern.ident
120                && match_types(&scrutinee.ty, &pattern.ty)?
121                && match_option(
122                    scrutinee.generics.as_ref(),
123                    pattern.generics.as_ref(),
124                    match_angle_bracketed_generic_args,
125                )?
126        }
127        AssocConst(pattern) => {
128            let scrutinee = match scrutinee {
129                AssocConst(scrutinee) => scrutinee,
130                _ => return Ok(false),
131            };
132
133            scrutinee.ident == pattern.ident
134                && match_option(
135                    scrutinee.generics.as_ref(),
136                    pattern.generics.as_ref(),
137                    match_angle_bracketed_generic_args,
138                )?
139                && match_exprs(&scrutinee.value, &pattern.value)
140        }
141
142        _ => return Err(unsupported_syntax_error(&pattern, "this syntax")),
143    };
144
145    Ok(verdict)
146}
147
148fn match_exprs(scrutinee: &syn::Expr, pattern: &syn::Expr) -> bool {
149    matches!(pattern, syn::Expr::Infer(_)) || scrutinee == pattern
150}
151
152pub(crate) fn match_types(scrutinee: &syn::Type, pattern: &syn::Type) -> Result<bool> {
153    use syn::Type::*;
154
155    let pattern = pattern.peel();
156
157    if let Infer(_) = pattern {
158        return Ok(true);
159    }
160
161    let scrutinee = scrutinee.peel();
162
163    let verdict = match pattern {
164        Array(pattern) => {
165            let scrutinee = match scrutinee {
166                Array(scrutinee) => scrutinee,
167                _ => return Ok(false),
168            };
169
170            match_types(&scrutinee.elem, &pattern.elem)?
171                && match_exprs(&scrutinee.len, &pattern.len)
172        }
173        Path(pattern) => {
174            if let Some(qself) = &pattern.qself {
175                return Err(unsupported_syntax_error(qself, "<T as Trait> syntax"));
176            }
177
178            let scrutinee = match scrutinee {
179                Path(scrutinee) => scrutinee,
180                _ => return Ok(false),
181            };
182
183            scrutinee.qself.is_none() && match_paths(&scrutinee.path, &pattern.path)?
184        }
185        Ptr(pattern) => {
186            let scrutinee = match scrutinee {
187                Ptr(scrutinee) => scrutinee,
188                _ => return Ok(false),
189            };
190            scrutinee.const_token == pattern.const_token
191                && scrutinee.mutability == pattern.mutability
192                && match_types(&scrutinee.elem, &pattern.elem)?
193        }
194        Reference(pattern) => {
195            if let Some(lifetime) = &pattern.lifetime {
196                return Err(unsupported_syntax_error(
197                    lifetime,
198                    "Lifetimes are ignored during type pattern matching. \
199                    Explicit lifetime syntax",
200                ));
201            }
202
203            let scrutinee = match scrutinee {
204                Reference(scrutinee) => scrutinee,
205                _ => return Ok(false),
206            };
207
208            scrutinee.mutability == pattern.mutability
209                && match_types(&scrutinee.elem, &pattern.elem)?
210        }
211        Slice(pattern) => {
212            let scrutinee = match scrutinee {
213                Slice(scrutinee) => scrutinee,
214                _ => return Ok(false),
215            };
216            match_types(&scrutinee.elem, &pattern.elem)?
217        }
218        Tuple(pattern) => {
219            let scrutinee = match scrutinee {
220                Tuple(scrutinee) => scrutinee,
221                _ => return Ok(false),
222            };
223            scrutinee
224                .elems
225                .iter()
226                .try_equals_with(&pattern.elems, match_types)?
227        }
228
229        Never(_) => matches!(scrutinee, Never(_)),
230
231        _ => return Err(unsupported_syntax_error(&pattern, "this syntax")),
232    };
233
234    Ok(verdict)
235}
236
237fn unsupported_syntax_error(spanned: &impl Spanned, syntax: &str) -> Error {
238    err!(
239        spanned,
240        "{syntax} is not supported in type patterns yet. If you have \
241        a use case for this, please open an issue at \
242        https://github.com/elastio/bon/issues.",
243    )
244}
245
#[cfg(test)]
mod tests {
    // One less `&` character to type in assertions
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use syn::parse_quote as pq;

    /// Asserts that `scrutinee` matches `pattern`, and (as a sanity check)
    /// that it also matches the catch-all `_` pattern.
    #[track_caller]
    fn assert_match_types(scrutinee: syn::Type, pattern: syn::Type) {
        // Make sure pure wildcard matches everything
        assert!(scrutinee.matches(&pq!(_)).unwrap());

        assert!(scrutinee.matches(&pattern).unwrap());
    }

    /// Asserts that `scrutinee` does not match `pattern`, without an error.
    #[track_caller]
    fn assert_not_match_types(scrutinee: syn::Type, pattern: syn::Type) {
        assert!(!scrutinee.matches(&pattern).unwrap());
    }

    /// Asserts that matching `scrutinee` against `pattern` fails with the
    /// "unsupported syntax" error. Needed for pattern syntax that is only
    /// rejected once the matcher reaches it, e.g. named lifetimes inside
    /// generic arguments of an otherwise matching path.
    #[track_caller]
    fn assert_unsupported_with(scrutinee: syn::Type, pattern: syn::Type) {
        let result = scrutinee.matches(&pattern);
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("is not supported in type patterns yet"),
            "Error: {err}"
        );
    }

    /// Same as `assert_unsupported_with`, using `()` as the scrutinee. Only
    /// suitable for patterns that are rejected regardless of the scrutinee.
    #[track_caller]
    fn assert_unsupported(pattern: syn::Type) {
        assert_unsupported_with(pq!(()), pattern);
    }

    #[test]
    fn array() {
        assert_match_types(pq!([u8; 4]), pq!([u8; 4]));
        assert_match_types(pq!([u8; 4]), pq!([_; 4]));
        assert_match_types(pq!([u8; 4]), pq!([_; _]));
        assert_match_types(pq!([u8; 4]), pq!([u8; _]));

        assert_not_match_types(pq!([u8; 4]), pq!([u8; 5]));
        assert_not_match_types(pq!([u8; 4]), pq!([_; 5]));

        assert_not_match_types(pq!([u8; 4]), pq!([u16; 4]));
        assert_not_match_types(pq!([u8; 4]), pq!([u16; _]));

        assert_not_match_types(pq!([u8; 4]), pq!([_]));
        assert_not_match_types(pq!([u8; 4]), pq!([u8]));
    }

    #[test]
    fn path() {
        assert_match_types(pq!(bool), pq!(bool));
        assert_match_types(pq!(foo::Bar), pq!(foo::Bar));
        assert_match_types(pq!(crate::foo::Bar), pq!(crate::foo::Bar));
        assert_match_types(pq!(super::foo::Bar), pq!(super::foo::Bar));

        assert_not_match_types(pq!(::Bar), pq!(Bar));
        assert_not_match_types(pq!(Bar), pq!(::Bar));
        assert_not_match_types(pq!(super::foo::Bar), pq!(crate::foo::Bar));

        assert_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<_>));
        assert_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<u32>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<u32, _>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<_, bool>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<u32, bool>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<_, _>));

        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<bool>));
        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar));
        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<u32, _>));
        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<_, _>));
        assert_not_match_types(pq!(foo::Foo<u32>), pq!(foo::Bar<u32>));
    }

    #[test]
    fn parenthesized_path_args() {
        assert_match_types(pq!(Fn(u8) -> bool), pq!(Fn(u8) -> bool));
        assert_match_types(pq!(Fn(u8) -> bool), pq!(Fn(_) -> _));

        // An omitted return type is equivalent to `-> ()`
        assert_match_types(pq!(Fn(u8)), pq!(Fn(u8) -> ()));
        assert_match_types(pq!(Fn(u8)), pq!(Fn(u8) -> _));
        assert_match_types(pq!(Fn(u8) -> ()), pq!(Fn(u8)));

        assert_not_match_types(pq!(Fn(u8) -> bool), pq!(Fn(u8)));
        assert_not_match_types(pq!(Fn(u8) -> bool), pq!(Fn(u16) -> bool));
        assert_not_match_types(pq!(Fn(u8)), pq!(Fn(u8) -> bool));
        assert_not_match_types(pq!(Fn(u8)), pq!(Fn(u8, u8)));
    }

    #[test]
    fn assoc_type() {
        assert_match_types(pq!(Iterator<Item = u8>), pq!(Iterator<Item = u8>));
        assert_match_types(pq!(Iterator<Item = u8>), pq!(Iterator<Item = _>));

        assert_not_match_types(pq!(Iterator<Item = u8>), pq!(Iterator<Item = u16>));
        assert_not_match_types(pq!(Iterator<Item = u8>), pq!(Iterator<Other = u8>));
        assert_not_match_types(pq!(Iterator<Item = u8>), pq!(Iterator<u8>));
    }

    #[test]
    fn const_generic_args() {
        assert_match_types(pq!(Foo<4>), pq!(Foo<4>));
        assert_match_types(pq!(Foo<{ N + 1 }>), pq!(Foo<{ N + 1 }>));

        assert_not_match_types(pq!(Foo<4>), pq!(Foo<5>));
        assert_not_match_types(pq!(Foo<4>), pq!(Foo<u32>));
    }

    #[test]
    fn lifetime_generic_args() {
        assert_match_types(pq!(Foo<'a>), pq!(Foo<'_>));
        assert_match_types(pq!(Foo<'static, u32>), pq!(Foo<'_, u32>));
        assert_match_types(pq!(Foo<'static, u32>), pq!(Foo<'_, _>));

        assert_not_match_types(pq!(Foo<u32>), pq!(Foo<'_>));
        assert_not_match_types(pq!(Foo<'a>), pq!(Foo<'_, '_>));

        // Named lifetimes are rejected once the matcher reaches them
        assert_unsupported_with(pq!(Foo<'a>), pq!(Foo<'a>));
        assert_unsupported_with(pq!(Foo<'static>), pq!(Foo<'static>));
    }

    #[test]
    fn ptr() {
        assert_match_types(pq!(*const u8), pq!(*const u8));
        assert_match_types(pq!(*const u8), pq!(*const _));
        assert_match_types(pq!(*mut u8), pq!(*mut u8));
        assert_match_types(pq!(*mut u8), pq!(*mut _));

        assert_not_match_types(pq!(*const u8), pq!(*mut u8));
        assert_not_match_types(pq!(*const u8), pq!(*mut _));
        assert_not_match_types(pq!(*mut u8), pq!(*const u8));
        assert_not_match_types(pq!(*mut u8), pq!(*const _));
    }

    #[test]
    fn reference() {
        assert_match_types(pq!(&u8), pq!(&u8));
        assert_match_types(pq!(&u8), pq!(&_));
        assert_match_types(pq!(&mut u8), pq!(&mut u8));
        assert_match_types(pq!(&mut u8), pq!(&mut _));

        assert_match_types(pq!(&'a u8), pq!(&_));
        assert_match_types(pq!(&'_ u8), pq!(&_));
        assert_match_types(pq!(&'static u8), pq!(&_));
        assert_match_types(pq!(&'a mut u8), pq!(&mut _));
        assert_match_types(pq!(&'_ mut u8), pq!(&mut _));
        assert_match_types(pq!(&'static mut u8), pq!(&mut _));

        assert_match_types(pq!(&'a u8), pq!(&u8));
        assert_match_types(pq!(&'_ u8), pq!(&u8));
        assert_match_types(pq!(&'static u8), pq!(&u8));
        assert_match_types(pq!(&'a mut u8), pq!(&mut u8));
        assert_match_types(pq!(&'_ mut u8), pq!(&mut u8));
        assert_match_types(pq!(&'static mut u8), pq!(&mut u8));

        assert_not_match_types(pq!(&u8), pq!(&mut u8));
        assert_not_match_types(pq!(&u8), pq!(&mut _));
        assert_not_match_types(pq!(&mut u8), pq!(&u8));
        assert_not_match_types(pq!(&mut u8), pq!(&_));
    }

    #[test]
    fn slice() {
        assert_match_types(pq!([u8]), pq!([u8]));
        assert_match_types(pq!([u8]), pq!([_]));
        assert_match_types(pq!(&[u8]), pq!(&[u8]));
        assert_match_types(pq!(&[u8]), pq!(&[_]));
        assert_match_types(pq!(&[u8]), pq!(&_));

        assert_not_match_types(pq!([u8]), pq!([u16]));
        assert_not_match_types(pq!([u8]), pq!([u8; 4]));
    }

    #[test]
    fn tuple() {
        assert_match_types(pq!((u8, bool)), pq!((u8, bool)));
        assert_match_types(pq!((u8, bool)), pq!((_, _)));
        assert_match_types(pq!((u8, bool)), pq!((u8, _)));
        assert_match_types(pq!((u8, bool)), pq!((_, bool)));

        assert_match_types(pq!(()), pq!(()));
        assert_match_types(pq!((u8,)), pq!((u8,)));
        assert_match_types(pq!((u8,)), pq!((_,)));

        assert_not_match_types(pq!((u8, bool)), pq!((bool, u8)));
        assert_not_match_types(pq!((u8, bool)), pq!((u8, bool, u8)));

        assert_not_match_types(pq!((u8,)), pq!(()));
        assert_not_match_types(pq!(()), pq!((u8,)));
    }

    #[test]
    fn never() {
        assert_match_types(pq!(!), pq!(!));
        assert_not_match_types(pq!(!), pq!(bool));
    }

    #[test]
    fn unsupported() {
        assert_unsupported(pq!(dyn Trait));
        assert_unsupported(pq!(dyn FnOnce()));

        assert_unsupported(pq!(impl Trait));
        assert_unsupported(pq!(impl FnOnce()));

        assert_unsupported(pq!(fn()));

        assert_unsupported(pq!(&'a _));
        assert_unsupported(pq!(&'_ _));
        assert_unsupported(pq!(&'static _));

        assert_unsupported(pq!(for<'a> Trait<'a>));
        assert_unsupported(pq!(for<'a> fn()));

        assert_unsupported(pq!(<T as Trait>::Foo));
    }
}