1use crate::util::iterator::IntoIteratorExt;
2use crate::util::prelude::*;
3use syn::spanned::Spanned;
4
5pub(crate) fn match_return_types(
6 scrutinee: &syn::ReturnType,
7 pattern: &syn::ReturnType,
8) -> Result<bool> {
9 match (scrutinee, pattern) {
10 (syn::ReturnType::Default, syn::ReturnType::Default) => Ok(true),
11 (syn::ReturnType::Default, syn::ReturnType::Type(_, pattern)) => {
12 match_types(&syn::parse_quote!(()), pattern)
13 }
14 (syn::ReturnType::Type(_, scrutinee), syn::ReturnType::Default) => {
15 Ok(**scrutinee == syn::parse_quote!(()))
16 }
17 (syn::ReturnType::Type(_, scrutinee), syn::ReturnType::Type(_, pattern)) => {
18 match_types(scrutinee, pattern)
19 }
20 }
21}
22
23fn match_paths(scrutinee: &syn::Path, pattern: &syn::Path) -> Result<bool> {
24 let verdict = scrutinee.leading_colon == pattern.leading_colon
25 && scrutinee
26 .segments
27 .iter()
28 .try_equals_with(&pattern.segments, |scrutinee, pattern| {
29 let verdict = scrutinee.ident == pattern.ident
30 && match_path_args(&scrutinee.arguments, &pattern.arguments)?;
31
32 Ok(verdict)
33 })?;
34
35 Ok(verdict)
36}
37
38fn match_path_args(scrutinee: &syn::PathArguments, pattern: &syn::PathArguments) -> Result<bool> {
39 use syn::PathArguments::*;
40
41 let verdict = match (scrutinee, pattern) {
42 (None, None) => true,
43 (AngleBracketed(scrutinee), AngleBracketed(pattern)) => {
44 match_angle_bracketed_generic_args(scrutinee, pattern)?
45 }
46 (Parenthesized(scrutinee), Parenthesized(pattern)) => {
47 scrutinee
48 .inputs
49 .iter()
50 .try_equals_with(&pattern.inputs, match_types)?
51 && match_return_types(&scrutinee.output, &pattern.output)?
52 }
53 _ => false,
54 };
55
56 Ok(verdict)
57}
58
59fn match_angle_bracketed_generic_args(
60 scrutinee: &syn::AngleBracketedGenericArguments,
61 pattern: &syn::AngleBracketedGenericArguments,
62) -> Result<bool> {
63 scrutinee
64 .args
65 .iter()
66 .try_equals_with(&pattern.args, match_generic_args)
67}
68
69fn match_option<T>(
70 scrutinee: Option<&T>,
71 pattern: Option<&T>,
72 compare: impl Fn(&T, &T) -> Result<bool>,
73) -> Result<bool> {
74 match (scrutinee, &pattern) {
75 (None, None) => Ok(true),
76 (Some(scrutinee), Some(pattern)) => compare(scrutinee, pattern),
77 _ => Ok(false),
78 }
79}
80
81fn match_generic_args(
82 scrutinee: &syn::GenericArgument,
83 pattern: &syn::GenericArgument,
84) -> Result<bool> {
85 use syn::GenericArgument::*;
86
87 let verdict = match pattern {
88 Lifetime(pattern) => {
89 if pattern.ident != "_" {
90 return Err(unsupported_syntax_error(
91 pattern,
92 "Lifetimes are ignored during type pattern matching. \
93 Use an anonymous lifetime (`'_`) in this position instead. \
94 Named lifetime syntax in generic parameters",
95 ));
96 }
97
98 matches!(scrutinee, Lifetime(_))
99 }
100 Type(pattern) => {
101 let scrutinee = match scrutinee {
102 Type(scrutinee) => scrutinee,
103 _ => return Ok(false),
104 };
105 match_types(scrutinee, pattern)?
106 }
107 Const(pattern) => {
108 let scrutinee = match scrutinee {
109 Const(scrutinee) => scrutinee,
110 _ => return Ok(false),
111 };
112 match_exprs(scrutinee, pattern)
113 }
114 AssocType(pattern) => {
115 let scrutinee = match scrutinee {
116 AssocType(scrutinee) => scrutinee,
117 _ => return Ok(false),
118 };
119 scrutinee.ident == pattern.ident
120 && match_types(&scrutinee.ty, &pattern.ty)?
121 && match_option(
122 scrutinee.generics.as_ref(),
123 pattern.generics.as_ref(),
124 match_angle_bracketed_generic_args,
125 )?
126 }
127 AssocConst(pattern) => {
128 let scrutinee = match scrutinee {
129 AssocConst(scrutinee) => scrutinee,
130 _ => return Ok(false),
131 };
132
133 scrutinee.ident == pattern.ident
134 && match_option(
135 scrutinee.generics.as_ref(),
136 pattern.generics.as_ref(),
137 match_angle_bracketed_generic_args,
138 )?
139 && match_exprs(&scrutinee.value, &pattern.value)
140 }
141
142 _ => return Err(unsupported_syntax_error(&pattern, "this syntax")),
143 };
144
145 Ok(verdict)
146}
147
148fn match_exprs(scrutinee: &syn::Expr, pattern: &syn::Expr) -> bool {
149 matches!(pattern, syn::Expr::Infer(_)) || scrutinee == pattern
150}
151
152pub(crate) fn match_types(scrutinee: &syn::Type, pattern: &syn::Type) -> Result<bool> {
153 use syn::Type::*;
154
155 let pattern = pattern.peel();
156
157 if let Infer(_) = pattern {
158 return Ok(true);
159 }
160
161 let scrutinee = scrutinee.peel();
162
163 let verdict = match pattern {
164 Array(pattern) => {
165 let scrutinee = match scrutinee {
166 Array(scrutinee) => scrutinee,
167 _ => return Ok(false),
168 };
169
170 match_types(&scrutinee.elem, &pattern.elem)?
171 && match_exprs(&scrutinee.len, &pattern.len)
172 }
173 Path(pattern) => {
174 if let Some(qself) = &pattern.qself {
175 return Err(unsupported_syntax_error(qself, "<T as Trait> syntax"));
176 }
177
178 let scrutinee = match scrutinee {
179 Path(scrutinee) => scrutinee,
180 _ => return Ok(false),
181 };
182
183 scrutinee.qself.is_none() && match_paths(&scrutinee.path, &pattern.path)?
184 }
185 Ptr(pattern) => {
186 let scrutinee = match scrutinee {
187 Ptr(scrutinee) => scrutinee,
188 _ => return Ok(false),
189 };
190 scrutinee.const_token == pattern.const_token
191 && scrutinee.mutability == pattern.mutability
192 && match_types(&scrutinee.elem, &pattern.elem)?
193 }
194 Reference(pattern) => {
195 if let Some(lifetime) = &pattern.lifetime {
196 return Err(unsupported_syntax_error(
197 lifetime,
198 "Lifetimes are ignored during type pattern matching. \
199 Explicit lifetime syntax",
200 ));
201 }
202
203 let scrutinee = match scrutinee {
204 Reference(scrutinee) => scrutinee,
205 _ => return Ok(false),
206 };
207
208 scrutinee.mutability == pattern.mutability
209 && match_types(&scrutinee.elem, &pattern.elem)?
210 }
211 Slice(pattern) => {
212 let scrutinee = match scrutinee {
213 Slice(scrutinee) => scrutinee,
214 _ => return Ok(false),
215 };
216 match_types(&scrutinee.elem, &pattern.elem)?
217 }
218 Tuple(pattern) => {
219 let scrutinee = match scrutinee {
220 Tuple(scrutinee) => scrutinee,
221 _ => return Ok(false),
222 };
223 scrutinee
224 .elems
225 .iter()
226 .try_equals_with(&pattern.elems, match_types)?
227 }
228
229 Never(_) => matches!(scrutinee, Never(_)),
230
231 _ => return Err(unsupported_syntax_error(&pattern, "this syntax")),
232 };
233
234 Ok(verdict)
235}
236
/// Builds the error reported when a type pattern uses syntax this matcher cannot handle.
///
/// `spanned` is the offending piece of the pattern; the error is attached to its span.
/// `syntax` is a short description that forms the start of the message, which then
/// continues with "is not supported in type patterns yet" and a link to the issue tracker.
// NOTE(review): the exact construction (span handling, formatting) is delegated to the
// crate's `err!` macro, whose definition is not visible here.
fn unsupported_syntax_error(spanned: &impl Spanned, syntax: &str) -> Error {
    err!(
        spanned,
        "{syntax} is not supported in type patterns yet. If you have \
        a use case for this, please open an issue at \
        https://github.com/elastio/bon/issues.",
    )
}
245
#[cfg(test)]
mod tests {
    // The helpers below take `syn::Type` by value so that call sites can pass `pq!(...)`
    // expressions directly, which is what `needless_pass_by_value` would complain about.
    #![allow(clippy::needless_pass_by_value)]

    use super::*;
    use syn::parse_quote as pq;

    /// Asserts that `scrutinee` matches both the `_` wildcard and `pattern`.
    #[track_caller]
    fn assert_match_types(scrutinee: syn::Type, pattern: syn::Type) {
        assert!(scrutinee.matches(&pq!(_)).unwrap());

        assert!(scrutinee.matches(&pattern).unwrap());
    }

    /// Asserts that `scrutinee` does not match `pattern` (without producing an error).
    #[track_caller]
    fn assert_not_match_types(scrutinee: syn::Type, pattern: syn::Type) {
        assert!(!scrutinee.matches(&pattern).unwrap());
    }

    /// Asserts that matching `()` against `pattern` is rejected as unsupported syntax.
    #[track_caller]
    fn assert_unsupported(pattern: syn::Type) {
        let scrutinee: syn::Type = syn::parse_quote!(());
        assert_unsupported_with(scrutinee, pattern);
    }

    /// Asserts that matching `scrutinee` against `pattern` is rejected as unsupported
    /// syntax. This is needed for pattern syntax that is only inspected once the
    /// scrutinee's structure lines up with the pattern (e.g. generic-argument lifetimes).
    #[track_caller]
    fn assert_unsupported_with(scrutinee: syn::Type, pattern: syn::Type) {
        let result = scrutinee.matches(&pattern);
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("is not supported in type patterns yet"),
            "Error: {err}"
        );
    }

    #[test]
    fn array() {
        assert_match_types(pq!([u8; 4]), pq!([u8; 4]));
        assert_match_types(pq!([u8; 4]), pq!([_; 4]));
        assert_match_types(pq!([u8; 4]), pq!([_; _]));
        assert_match_types(pq!([u8; 4]), pq!([u8; _]));

        assert_not_match_types(pq!([u8; 4]), pq!([u8; 5]));
        assert_not_match_types(pq!([u8; 4]), pq!([_; 5]));

        assert_not_match_types(pq!([u8; 4]), pq!([u16; 4]));
        assert_not_match_types(pq!([u8; 4]), pq!([u16; _]));

        assert_not_match_types(pq!([u8; 4]), pq!([_]));
        assert_not_match_types(pq!([u8; 4]), pq!([u8]));
    }

    #[test]
    fn path() {
        assert_match_types(pq!(bool), pq!(bool));
        assert_match_types(pq!(foo::Bar), pq!(foo::Bar));
        assert_match_types(pq!(crate::foo::Bar), pq!(crate::foo::Bar));
        assert_match_types(pq!(super::foo::Bar), pq!(super::foo::Bar));

        assert_not_match_types(pq!(::Bar), pq!(Bar));
        assert_not_match_types(pq!(Bar), pq!(::Bar));
        assert_not_match_types(pq!(super::foo::Bar), pq!(crate::foo::Bar));

        assert_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<_>));
        assert_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<u32>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<u32, _>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<_, bool>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<u32, bool>));
        assert_match_types(pq!(foo::Bar<u32, bool>), pq!(foo::Bar<_, _>));

        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<bool>));
        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar));
        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<u32, _>));
        assert_not_match_types(pq!(foo::Bar<u32>), pq!(foo::Bar<_, _>));
        assert_not_match_types(pq!(foo::Foo<u32>), pq!(foo::Bar<u32>));
    }

    #[test]
    fn generic_lifetimes() {
        assert_match_types(pq!(Foo<'a>), pq!(Foo<'_>));
        assert_match_types(pq!(Foo<'static>), pq!(Foo<'_>));
        assert_match_types(pq!(Foo<'a, u32>), pq!(Foo<'_, u32>));
        assert_match_types(pq!(Foo<'a, u32>), pq!(Foo<'_, _>));

        assert_not_match_types(pq!(Foo<u32>), pq!(Foo<'_>));
        assert_not_match_types(pq!(Foo<'a>), pq!(Foo<_>));
        assert_not_match_types(pq!(Foo<'a, u32>), pq!(Foo<'_, bool>));

        // Named lifetimes are rejected in generic-argument patterns.
        assert_unsupported_with(pq!(Foo<'a>), pq!(Foo<'a>));
        assert_unsupported_with(pq!(Foo<'a>), pq!(Foo<'static>));
    }

    #[test]
    fn const_generics() {
        assert_match_types(pq!(Foo<3>), pq!(Foo<3>));
        assert_not_match_types(pq!(Foo<3>), pq!(Foo<4>));
    }

    #[test]
    fn assoc_types() {
        assert_match_types(pq!(Foo<Item = u32>), pq!(Foo<Item = u32>));
        assert_match_types(pq!(Foo<Item = u32>), pq!(Foo<Item = _>));

        assert_not_match_types(pq!(Foo<Item = u32>), pq!(Foo<Item = bool>));
        assert_not_match_types(pq!(Foo<Item = u32>), pq!(Foo<Other = u32>));
        assert_not_match_types(pq!(Foo<u32>), pq!(Foo<Item = u32>));
        assert_not_match_types(pq!(Foo<Item = u32>), pq!(Foo<u32>));
    }

    #[test]
    fn ptr() {
        assert_match_types(pq!(*const u8), pq!(*const u8));
        assert_match_types(pq!(*const u8), pq!(*const _));
        assert_match_types(pq!(*mut u8), pq!(*mut u8));
        assert_match_types(pq!(*mut u8), pq!(*mut _));

        assert_not_match_types(pq!(*const u8), pq!(*mut u8));
        assert_not_match_types(pq!(*const u8), pq!(*mut _));
        assert_not_match_types(pq!(*mut u8), pq!(*const u8));
        assert_not_match_types(pq!(*mut u8), pq!(*const _));
    }

    #[test]
    fn reference() {
        assert_match_types(pq!(&u8), pq!(&u8));
        assert_match_types(pq!(&u8), pq!(&_));
        assert_match_types(pq!(&mut u8), pq!(&mut u8));
        assert_match_types(pq!(&mut u8), pq!(&mut _));

        assert_match_types(pq!(&'a u8), pq!(&_));
        assert_match_types(pq!(&'_ u8), pq!(&_));
        assert_match_types(pq!(&'static u8), pq!(&_));
        assert_match_types(pq!(&'a mut u8), pq!(&mut _));
        assert_match_types(pq!(&'_ mut u8), pq!(&mut _));
        assert_match_types(pq!(&'static mut u8), pq!(&mut _));

        assert_match_types(pq!(&'a u8), pq!(&u8));
        assert_match_types(pq!(&'_ u8), pq!(&u8));
        assert_match_types(pq!(&'static u8), pq!(&u8));
        assert_match_types(pq!(&'a mut u8), pq!(&mut u8));
        assert_match_types(pq!(&'_ mut u8), pq!(&mut u8));
        assert_match_types(pq!(&'static mut u8), pq!(&mut u8));

        assert_not_match_types(pq!(&u8), pq!(&mut u8));
        assert_not_match_types(pq!(&u8), pq!(&mut _));
        assert_not_match_types(pq!(&mut u8), pq!(&u8));
        assert_not_match_types(pq!(&mut u8), pq!(&_));
    }

    #[test]
    fn slice() {
        assert_match_types(pq!([u8]), pq!([u8]));
        assert_match_types(pq!([u8]), pq!([_]));
        assert_match_types(pq!(&[u8]), pq!(&[u8]));
        assert_match_types(pq!(&[u8]), pq!(&[_]));
        assert_match_types(pq!(&[u8]), pq!(&_));

        assert_not_match_types(pq!([u8]), pq!([u16]));
        assert_not_match_types(pq!([u8]), pq!([u8; 4]));
    }

    #[test]
    fn tuple() {
        assert_match_types(pq!((u8, bool)), pq!((u8, bool)));
        assert_match_types(pq!((u8, bool)), pq!((_, _)));
        assert_match_types(pq!((u8, bool)), pq!((u8, _)));
        assert_match_types(pq!((u8, bool)), pq!((_, bool)));

        assert_match_types(pq!(()), pq!(()));
        assert_match_types(pq!((u8,)), pq!((u8,)));
        assert_match_types(pq!((u8,)), pq!((_,)));

        assert_not_match_types(pq!((u8, bool)), pq!((bool, u8)));
        assert_not_match_types(pq!((u8, bool)), pq!((u8, bool, u8)));

        assert_not_match_types(pq!((u8,)), pq!(()));
        assert_not_match_types(pq!(()), pq!((u8,)));
    }

    #[test]
    fn never() {
        assert_match_types(pq!(!), pq!(!));
        assert_not_match_types(pq!(!), pq!(bool));
    }

    #[test]
    fn return_types() {
        use syn::ReturnType;

        let check = |scrutinee: ReturnType, pattern: ReturnType| {
            match_return_types(&scrutinee, &pattern).unwrap()
        };

        // An omitted return type behaves like `-> ()` on either side.
        assert!(check(ReturnType::Default, ReturnType::Default));
        assert!(check(ReturnType::Default, pq!(-> ())));
        assert!(check(ReturnType::Default, pq!(-> _)));
        assert!(check(pq!(-> ()), ReturnType::Default));

        assert!(check(pq!(-> u32), pq!(-> u32)));
        assert!(check(pq!(-> u32), pq!(-> _)));

        assert!(!check(pq!(-> u32), ReturnType::Default));
        assert!(!check(ReturnType::Default, pq!(-> u32)));
        assert!(!check(pq!(-> u32), pq!(-> bool)));
    }

    #[test]
    fn unsupported() {
        assert_unsupported(pq!(dyn Trait));
        assert_unsupported(pq!(dyn FnOnce()));

        assert_unsupported(pq!(impl Trait));
        assert_unsupported(pq!(impl FnOnce()));

        assert_unsupported(pq!(fn()));

        assert_unsupported(pq!(&'a _));
        assert_unsupported(pq!(&'_ _));
        assert_unsupported(pq!(&'static _));

        assert_unsupported(pq!(for<'a> Trait<'a>));
        assert_unsupported(pq!(for<'a> fn()));

        assert_unsupported(pq!(<T as Trait>::Foo));
    }
}