memchr/arch/aarch64/neon/
memchr.rs

/*!
This module defines 128-bit vector implementations of `memchr` and friends.

The main types in this module are [`One`], [`Two`] and [`Three`]. They are for
searching for one, two or three distinct bytes, respectively, in a haystack.
Each type also has corresponding double ended iterators. These searchers are
typically much faster than scalar routines accomplishing the same task.

The `One` searcher also provides a [`One::count`] routine for efficiently
counting the number of times a single byte occurs in a haystack. This is
useful, for example, for counting the number of lines in a haystack. This
routine exists because it is usually faster, especially with a high match
count, than using [`One::find`] repeatedly. ([`OneIter`] specializes its
`Iterator::count` implementation to use this routine.)

Only one, two and three bytes are supported because three bytes is about
the point where one sees diminishing returns. Beyond this point, it's
probably (but not necessarily) better to just use a simple `[bool; 256]` array
or similar. However, it depends mightily on the specific workload and the
expected match frequency.
*/
22
23use core::arch::aarch64::uint8x16_t;
24
25use crate::{arch::generic::memchr as generic, ext::Pointer, vector::Vector};
26
27/// Finds all occurrences of a single byte in a haystack.
28#[derive(Clone, Copy, Debug)]
29pub struct One(generic::One<uint8x16_t>);
30
31impl One {
32    /// Create a new searcher that finds occurrences of the needle byte given.
33    ///
34    /// This particular searcher is specialized to use neon vector instructions
35    /// that typically make it quite fast.
36    ///
37    /// If neon is unavailable in the current environment, then `None` is
38    /// returned.
39    #[inline]
40    pub fn new(needle: u8) -> Option<One> {
41        if One::is_available() {
42            // SAFETY: we check that neon is available above.
43            unsafe { Some(One::new_unchecked(needle)) }
44        } else {
45            None
46        }
47    }
48
49    /// Create a new finder specific to neon vectors and routines without
50    /// checking that neon is available.
51    ///
52    /// # Safety
53    ///
54    /// Callers must guarantee that it is safe to execute `neon` instructions
55    /// in the current environment.
56    ///
57    /// Note that it is a common misconception that if one compiles for an
58    /// `x86_64` target, then they therefore automatically have access to neon
59    /// instructions. While this is almost always the case, it isn't true in
60    /// 100% of cases.
61    #[target_feature(enable = "neon")]
62    #[inline]
63    pub unsafe fn new_unchecked(needle: u8) -> One {
64        One(generic::One::new(needle))
65    }
66
67    /// Returns true when this implementation is available in the current
68    /// environment.
69    ///
70    /// When this is true, it is guaranteed that [`One::new`] will return
71    /// a `Some` value. Similarly, when it is false, it is guaranteed that
72    /// `One::new` will return a `None` value.
73    ///
74    /// Note also that for the lifetime of a single program, if this returns
75    /// true then it will always return true.
76    #[inline]
77    pub fn is_available() -> bool {
78        #[cfg(target_feature = "neon")]
79        {
80            true
81        }
82        #[cfg(not(target_feature = "neon"))]
83        {
84            false
85        }
86    }
87
88    /// Return the first occurrence of one of the needle bytes in the given
89    /// haystack. If no such occurrence exists, then `None` is returned.
90    ///
91    /// The occurrence is reported as an offset into `haystack`. Its maximum
92    /// value is `haystack.len() - 1`.
93    #[inline]
94    pub fn find(&self, haystack: &[u8]) -> Option<usize> {
95        // SAFETY: `find_raw` guarantees that if a pointer is returned, it
96        // falls within the bounds of the start and end pointers.
97        unsafe {
98            generic::search_slice_with_raw(haystack, |s, e| {
99                self.find_raw(s, e)
100            })
101        }
102    }
103
104    /// Return the last occurrence of one of the needle bytes in the given
105    /// haystack. If no such occurrence exists, then `None` is returned.
106    ///
107    /// The occurrence is reported as an offset into `haystack`. Its maximum
108    /// value is `haystack.len() - 1`.
109    #[inline]
110    pub fn rfind(&self, haystack: &[u8]) -> Option<usize> {
111        // SAFETY: `rfind_raw` guarantees that if a pointer is returned, it
112        // falls within the bounds of the start and end pointers.
113        unsafe {
114            generic::search_slice_with_raw(haystack, |s, e| {
115                self.rfind_raw(s, e)
116            })
117        }
118    }
119
120    /// Counts all occurrences of this byte in the given haystack.
121    #[inline]
122    pub fn count(&self, haystack: &[u8]) -> usize {
123        // SAFETY: All of our pointers are derived directly from a borrowed
124        // slice, which is guaranteed to be valid.
125        unsafe {
126            let start = haystack.as_ptr();
127            let end = start.add(haystack.len());
128            self.count_raw(start, end)
129        }
130    }
131
132    /// Like `find`, but accepts and returns raw pointers.
133    ///
134    /// When a match is found, the pointer returned is guaranteed to be
135    /// `>= start` and `< end`.
136    ///
137    /// This routine is useful if you're already using raw pointers and would
138    /// like to avoid converting back to a slice before executing a search.
139    ///
140    /// # Safety
141    ///
142    /// * Both `start` and `end` must be valid for reads.
143    /// * Both `start` and `end` must point to an initialized value.
144    /// * Both `start` and `end` must point to the same allocated object and
145    /// must either be in bounds or at most one byte past the end of the
146    /// allocated object.
147    /// * Both `start` and `end` must be _derived from_ a pointer to the same
148    /// object.
149    /// * The distance between `start` and `end` must not overflow `isize`.
150    /// * The distance being in bounds must not rely on "wrapping around" the
151    /// address space.
152    ///
153    /// Note that callers may pass a pair of pointers such that `start >= end`.
154    /// In that case, `None` will always be returned.
155    #[inline]
156    pub unsafe fn find_raw(
157        &self,
158        start: *const u8,
159        end: *const u8,
160    ) -> Option<*const u8> {
161        if start >= end {
162            return None;
163        }
164        if end.distance(start) < uint8x16_t::BYTES {
165            // SAFETY: We require the caller to pass valid start/end pointers.
166            return generic::fwd_byte_by_byte(start, end, |b| {
167                b == self.0.needle1()
168            });
169        }
170        // SAFETY: Building a `One` means it's safe to call 'neon' routines.
171        // Also, we've checked that our haystack is big enough to run on the
172        // vector routine. Pointer validity is caller's responsibility.
173        self.find_raw_impl(start, end)
174    }
175
176    /// Like `rfind`, but accepts and returns raw pointers.
177    ///
178    /// When a match is found, the pointer returned is guaranteed to be
179    /// `>= start` and `< end`.
180    ///
181    /// This routine is useful if you're already using raw pointers and would
182    /// like to avoid converting back to a slice before executing a search.
183    ///
184    /// # Safety
185    ///
186    /// * Both `start` and `end` must be valid for reads.
187    /// * Both `start` and `end` must point to an initialized value.
188    /// * Both `start` and `end` must point to the same allocated object and
189    /// must either be in bounds or at most one byte past the end of the
190    /// allocated object.
191    /// * Both `start` and `end` must be _derived from_ a pointer to the same
192    /// object.
193    /// * The distance between `start` and `end` must not overflow `isize`.
194    /// * The distance being in bounds must not rely on "wrapping around" the
195    /// address space.
196    ///
197    /// Note that callers may pass a pair of pointers such that `start >= end`.
198    /// In that case, `None` will always be returned.
199    #[inline]
200    pub unsafe fn rfind_raw(
201        &self,
202        start: *const u8,
203        end: *const u8,
204    ) -> Option<*const u8> {
205        if start >= end {
206            return None;
207        }
208        if end.distance(start) < uint8x16_t::BYTES {
209            // SAFETY: We require the caller to pass valid start/end pointers.
210            return generic::rev_byte_by_byte(start, end, |b| {
211                b == self.0.needle1()
212            });
213        }
214        // SAFETY: Building a `One` means it's safe to call 'neon' routines.
215        // Also, we've checked that our haystack is big enough to run on the
216        // vector routine. Pointer validity is caller's responsibility.
217        self.rfind_raw_impl(start, end)
218    }
219
220    /// Like `count`, but accepts and returns raw pointers.
221    ///
222    /// This routine is useful if you're already using raw pointers and would
223    /// like to avoid converting back to a slice before executing a search.
224    ///
225    /// # Safety
226    ///
227    /// * Both `start` and `end` must be valid for reads.
228    /// * Both `start` and `end` must point to an initialized value.
229    /// * Both `start` and `end` must point to the same allocated object and
230    /// must either be in bounds or at most one byte past the end of the
231    /// allocated object.
232    /// * Both `start` and `end` must be _derived from_ a pointer to the same
233    /// object.
234    /// * The distance between `start` and `end` must not overflow `isize`.
235    /// * The distance being in bounds must not rely on "wrapping around" the
236    /// address space.
237    ///
238    /// Note that callers may pass a pair of pointers such that `start >= end`.
239    /// In that case, `None` will always be returned.
240    #[inline]
241    pub unsafe fn count_raw(&self, start: *const u8, end: *const u8) -> usize {
242        if start >= end {
243            return 0;
244        }
245        if end.distance(start) < uint8x16_t::BYTES {
246            // SAFETY: We require the caller to pass valid start/end pointers.
247            return generic::count_byte_by_byte(start, end, |b| {
248                b == self.0.needle1()
249            });
250        }
251        // SAFETY: Building a `One` means it's safe to call 'neon' routines.
252        // Also, we've checked that our haystack is big enough to run on the
253        // vector routine. Pointer validity is caller's responsibility.
254        self.count_raw_impl(start, end)
255    }
256
257    /// Execute a search using neon vectors and routines.
258    ///
259    /// # Safety
260    ///
261    /// Same as [`One::find_raw`], except the distance between `start` and
262    /// `end` must be at least the size of a neon vector (in bytes).
263    ///
264    /// (The target feature safety obligation is automatically fulfilled by
265    /// virtue of being a method on `One`, which can only be constructed
266    /// when it is safe to call `neon` routines.)
267    #[target_feature(enable = "neon")]
268    #[inline]
269    unsafe fn find_raw_impl(
270        &self,
271        start: *const u8,
272        end: *const u8,
273    ) -> Option<*const u8> {
274        self.0.find_raw(start, end)
275    }
276
277    /// Execute a search using neon vectors and routines.
278    ///
279    /// # Safety
280    ///
281    /// Same as [`One::rfind_raw`], except the distance between `start` and
282    /// `end` must be at least the size of a neon vector (in bytes).
283    ///
284    /// (The target feature safety obligation is automatically fulfilled by
285    /// virtue of being a method on `One`, which can only be constructed
286    /// when it is safe to call `neon` routines.)
287    #[target_feature(enable = "neon")]
288    #[inline]
289    unsafe fn rfind_raw_impl(
290        &self,
291        start: *const u8,
292        end: *const u8,
293    ) -> Option<*const u8> {
294        self.0.rfind_raw(start, end)
295    }
296
297    /// Execute a count using neon vectors and routines.
298    ///
299    /// # Safety
300    ///
301    /// Same as [`One::count_raw`], except the distance between `start` and
302    /// `end` must be at least the size of a neon vector (in bytes).
303    ///
304    /// (The target feature safety obligation is automatically fulfilled by
305    /// virtue of being a method on `One`, which can only be constructed
306    /// when it is safe to call `neon` routines.)
307    #[target_feature(enable = "neon")]
308    #[inline]
309    unsafe fn count_raw_impl(
310        &self,
311        start: *const u8,
312        end: *const u8,
313    ) -> usize {
314        self.0.count_raw(start, end)
315    }
316
317    /// Returns an iterator over all occurrences of the needle byte in the
318    /// given haystack.
319    ///
320    /// The iterator returned implements `DoubleEndedIterator`. This means it
321    /// can also be used to find occurrences in reverse order.
322    #[inline]
323    pub fn iter<'a, 'h>(&'a self, haystack: &'h [u8]) -> OneIter<'a, 'h> {
324        OneIter { searcher: self, it: generic::Iter::new(haystack) }
325    }
326}
327
328/// An iterator over all occurrences of a single byte in a haystack.
329///
330/// This iterator implements `DoubleEndedIterator`, which means it can also be
331/// used to find occurrences in reverse order.
332///
333/// This iterator is created by the [`One::iter`] method.
334///
335/// The lifetime parameters are as follows:
336///
337/// * `'a` refers to the lifetime of the underlying [`One`] searcher.
338/// * `'h` refers to the lifetime of the haystack being searched.
339#[derive(Clone, Debug)]
340pub struct OneIter<'a, 'h> {
341    searcher: &'a One,
342    it: generic::Iter<'h>,
343}
344
345impl<'a, 'h> Iterator for OneIter<'a, 'h> {
346    type Item = usize;
347
348    #[inline]
349    fn next(&mut self) -> Option<usize> {
350        // SAFETY: We rely on the generic iterator to provide valid start
351        // and end pointers, but we guarantee that any pointer returned by
352        // 'find_raw' falls within the bounds of the start and end pointer.
353        unsafe { self.it.next(|s, e| self.searcher.find_raw(s, e)) }
354    }
355
356    #[inline]
357    fn count(self) -> usize {
358        self.it.count(|s, e| {
359            // SAFETY: We rely on our generic iterator to return valid start
360            // and end pointers.
361            unsafe { self.searcher.count_raw(s, e) }
362        })
363    }
364
365    #[inline]
366    fn size_hint(&self) -> (usize, Option<usize>) {
367        self.it.size_hint()
368    }
369}
370
371impl<'a, 'h> DoubleEndedIterator for OneIter<'a, 'h> {
372    #[inline]
373    fn next_back(&mut self) -> Option<usize> {
374        // SAFETY: We rely on the generic iterator to provide valid start
375        // and end pointers, but we guarantee that any pointer returned by
376        // 'rfind_raw' falls within the bounds of the start and end pointer.
377        unsafe { self.it.next_back(|s, e| self.searcher.rfind_raw(s, e)) }
378    }
379}
380
381impl<'a, 'h> core::iter::FusedIterator for OneIter<'a, 'h> {}
382
383/// Finds all occurrences of two bytes in a haystack.
384///
385/// That is, this reports matches of one of two possible bytes. For example,
386/// searching for `a` or `b` in `afoobar` would report matches at offsets `0`,
387/// `4` and `5`.
388#[derive(Clone, Copy, Debug)]
389pub struct Two(generic::Two<uint8x16_t>);
390
391impl Two {
392    /// Create a new searcher that finds occurrences of the needle bytes given.
393    ///
394    /// This particular searcher is specialized to use neon vector instructions
395    /// that typically make it quite fast.
396    ///
397    /// If neon is unavailable in the current environment, then `None` is
398    /// returned.
399    #[inline]
400    pub fn new(needle1: u8, needle2: u8) -> Option<Two> {
401        if Two::is_available() {
402            // SAFETY: we check that neon is available above.
403            unsafe { Some(Two::new_unchecked(needle1, needle2)) }
404        } else {
405            None
406        }
407    }
408
409    /// Create a new finder specific to neon vectors and routines without
410    /// checking that neon is available.
411    ///
412    /// # Safety
413    ///
414    /// Callers must guarantee that it is safe to execute `neon` instructions
415    /// in the current environment.
416    ///
417    /// Note that it is a common misconception that if one compiles for an
418    /// `x86_64` target, then they therefore automatically have access to neon
419    /// instructions. While this is almost always the case, it isn't true in
420    /// 100% of cases.
421    #[target_feature(enable = "neon")]
422    #[inline]
423    pub unsafe fn new_unchecked(needle1: u8, needle2: u8) -> Two {
424        Two(generic::Two::new(needle1, needle2))
425    }
426
427    /// Returns true when this implementation is available in the current
428    /// environment.
429    ///
430    /// When this is true, it is guaranteed that [`Two::new`] will return
431    /// a `Some` value. Similarly, when it is false, it is guaranteed that
432    /// `Two::new` will return a `None` value.
433    ///
434    /// Note also that for the lifetime of a single program, if this returns
435    /// true then it will always return true.
436    #[inline]
437    pub fn is_available() -> bool {
438        #[cfg(target_feature = "neon")]
439        {
440            true
441        }
442        #[cfg(not(target_feature = "neon"))]
443        {
444            false
445        }
446    }
447
448    /// Return the first occurrence of one of the needle bytes in the given
449    /// haystack. If no such occurrence exists, then `None` is returned.
450    ///
451    /// The occurrence is reported as an offset into `haystack`. Its maximum
452    /// value is `haystack.len() - 1`.
453    #[inline]
454    pub fn find(&self, haystack: &[u8]) -> Option<usize> {
455        // SAFETY: `find_raw` guarantees that if a pointer is returned, it
456        // falls within the bounds of the start and end pointers.
457        unsafe {
458            generic::search_slice_with_raw(haystack, |s, e| {
459                self.find_raw(s, e)
460            })
461        }
462    }
463
464    /// Return the last occurrence of one of the needle bytes in the given
465    /// haystack. If no such occurrence exists, then `None` is returned.
466    ///
467    /// The occurrence is reported as an offset into `haystack`. Its maximum
468    /// value is `haystack.len() - 1`.
469    #[inline]
470    pub fn rfind(&self, haystack: &[u8]) -> Option<usize> {
471        // SAFETY: `rfind_raw` guarantees that if a pointer is returned, it
472        // falls within the bounds of the start and end pointers.
473        unsafe {
474            generic::search_slice_with_raw(haystack, |s, e| {
475                self.rfind_raw(s, e)
476            })
477        }
478    }
479
480    /// Like `find`, but accepts and returns raw pointers.
481    ///
482    /// When a match is found, the pointer returned is guaranteed to be
483    /// `>= start` and `< end`.
484    ///
485    /// This routine is useful if you're already using raw pointers and would
486    /// like to avoid converting back to a slice before executing a search.
487    ///
488    /// # Safety
489    ///
490    /// * Both `start` and `end` must be valid for reads.
491    /// * Both `start` and `end` must point to an initialized value.
492    /// * Both `start` and `end` must point to the same allocated object and
493    /// must either be in bounds or at most one byte past the end of the
494    /// allocated object.
495    /// * Both `start` and `end` must be _derived from_ a pointer to the same
496    /// object.
497    /// * The distance between `start` and `end` must not overflow `isize`.
498    /// * The distance being in bounds must not rely on "wrapping around" the
499    /// address space.
500    ///
501    /// Note that callers may pass a pair of pointers such that `start >= end`.
502    /// In that case, `None` will always be returned.
503    #[inline]
504    pub unsafe fn find_raw(
505        &self,
506        start: *const u8,
507        end: *const u8,
508    ) -> Option<*const u8> {
509        if start >= end {
510            return None;
511        }
512        if end.distance(start) < uint8x16_t::BYTES {
513            // SAFETY: We require the caller to pass valid start/end pointers.
514            return generic::fwd_byte_by_byte(start, end, |b| {
515                b == self.0.needle1() || b == self.0.needle2()
516            });
517        }
518        // SAFETY: Building a `Two` means it's safe to call 'neon' routines.
519        // Also, we've checked that our haystack is big enough to run on the
520        // vector routine. Pointer validity is caller's responsibility.
521        self.find_raw_impl(start, end)
522    }
523
524    /// Like `rfind`, but accepts and returns raw pointers.
525    ///
526    /// When a match is found, the pointer returned is guaranteed to be
527    /// `>= start` and `< end`.
528    ///
529    /// This routine is useful if you're already using raw pointers and would
530    /// like to avoid converting back to a slice before executing a search.
531    ///
532    /// # Safety
533    ///
534    /// * Both `start` and `end` must be valid for reads.
535    /// * Both `start` and `end` must point to an initialized value.
536    /// * Both `start` and `end` must point to the same allocated object and
537    /// must either be in bounds or at most one byte past the end of the
538    /// allocated object.
539    /// * Both `start` and `end` must be _derived from_ a pointer to the same
540    /// object.
541    /// * The distance between `start` and `end` must not overflow `isize`.
542    /// * The distance being in bounds must not rely on "wrapping around" the
543    /// address space.
544    ///
545    /// Note that callers may pass a pair of pointers such that `start >= end`.
546    /// In that case, `None` will always be returned.
547    #[inline]
548    pub unsafe fn rfind_raw(
549        &self,
550        start: *const u8,
551        end: *const u8,
552    ) -> Option<*const u8> {
553        if start >= end {
554            return None;
555        }
556        if end.distance(start) < uint8x16_t::BYTES {
557            // SAFETY: We require the caller to pass valid start/end pointers.
558            return generic::rev_byte_by_byte(start, end, |b| {
559                b == self.0.needle1() || b == self.0.needle2()
560            });
561        }
562        // SAFETY: Building a `Two` means it's safe to call 'neon' routines.
563        // Also, we've checked that our haystack is big enough to run on the
564        // vector routine. Pointer validity is caller's responsibility.
565        self.rfind_raw_impl(start, end)
566    }
567
568    /// Execute a search using neon vectors and routines.
569    ///
570    /// # Safety
571    ///
572    /// Same as [`Two::find_raw`], except the distance between `start` and
573    /// `end` must be at least the size of a neon vector (in bytes).
574    ///
575    /// (The target feature safety obligation is automatically fulfilled by
576    /// virtue of being a method on `Two`, which can only be constructed
577    /// when it is safe to call `neon` routines.)
578    #[target_feature(enable = "neon")]
579    #[inline]
580    unsafe fn find_raw_impl(
581        &self,
582        start: *const u8,
583        end: *const u8,
584    ) -> Option<*const u8> {
585        self.0.find_raw(start, end)
586    }
587
588    /// Execute a search using neon vectors and routines.
589    ///
590    /// # Safety
591    ///
592    /// Same as [`Two::rfind_raw`], except the distance between `start` and
593    /// `end` must be at least the size of a neon vector (in bytes).
594    ///
595    /// (The target feature safety obligation is automatically fulfilled by
596    /// virtue of being a method on `Two`, which can only be constructed
597    /// when it is safe to call `neon` routines.)
598    #[target_feature(enable = "neon")]
599    #[inline]
600    unsafe fn rfind_raw_impl(
601        &self,
602        start: *const u8,
603        end: *const u8,
604    ) -> Option<*const u8> {
605        self.0.rfind_raw(start, end)
606    }
607
608    /// Returns an iterator over all occurrences of the needle bytes in the
609    /// given haystack.
610    ///
611    /// The iterator returned implements `DoubleEndedIterator`. This means it
612    /// can also be used to find occurrences in reverse order.
613    #[inline]
614    pub fn iter<'a, 'h>(&'a self, haystack: &'h [u8]) -> TwoIter<'a, 'h> {
615        TwoIter { searcher: self, it: generic::Iter::new(haystack) }
616    }
617}
618
619/// An iterator over all occurrences of two possible bytes in a haystack.
620///
621/// This iterator implements `DoubleEndedIterator`, which means it can also be
622/// used to find occurrences in reverse order.
623///
624/// This iterator is created by the [`Two::iter`] method.
625///
626/// The lifetime parameters are as follows:
627///
628/// * `'a` refers to the lifetime of the underlying [`Two`] searcher.
629/// * `'h` refers to the lifetime of the haystack being searched.
630#[derive(Clone, Debug)]
631pub struct TwoIter<'a, 'h> {
632    searcher: &'a Two,
633    it: generic::Iter<'h>,
634}
635
636impl<'a, 'h> Iterator for TwoIter<'a, 'h> {
637    type Item = usize;
638
639    #[inline]
640    fn next(&mut self) -> Option<usize> {
641        // SAFETY: We rely on the generic iterator to provide valid start
642        // and end pointers, but we guarantee that any pointer returned by
643        // 'find_raw' falls within the bounds of the start and end pointer.
644        unsafe { self.it.next(|s, e| self.searcher.find_raw(s, e)) }
645    }
646
647    #[inline]
648    fn size_hint(&self) -> (usize, Option<usize>) {
649        self.it.size_hint()
650    }
651}
652
653impl<'a, 'h> DoubleEndedIterator for TwoIter<'a, 'h> {
654    #[inline]
655    fn next_back(&mut self) -> Option<usize> {
656        // SAFETY: We rely on the generic iterator to provide valid start
657        // and end pointers, but we guarantee that any pointer returned by
658        // 'rfind_raw' falls within the bounds of the start and end pointer.
659        unsafe { self.it.next_back(|s, e| self.searcher.rfind_raw(s, e)) }
660    }
661}
662
663impl<'a, 'h> core::iter::FusedIterator for TwoIter<'a, 'h> {}
664
665/// Finds all occurrences of three bytes in a haystack.
666///
667/// That is, this reports matches of one of three possible bytes. For example,
668/// searching for `a`, `b` or `o` in `afoobar` would report matches at offsets
669/// `0`, `2`, `3`, `4` and `5`.
670#[derive(Clone, Copy, Debug)]
671pub struct Three(generic::Three<uint8x16_t>);
672
673impl Three {
674    /// Create a new searcher that finds occurrences of the needle bytes given.
675    ///
676    /// This particular searcher is specialized to use neon vector instructions
677    /// that typically make it quite fast.
678    ///
679    /// If neon is unavailable in the current environment, then `None` is
680    /// returned.
681    #[inline]
682    pub fn new(needle1: u8, needle2: u8, needle3: u8) -> Option<Three> {
683        if Three::is_available() {
684            // SAFETY: we check that neon is available above.
685            unsafe { Some(Three::new_unchecked(needle1, needle2, needle3)) }
686        } else {
687            None
688        }
689    }
690
691    /// Create a new finder specific to neon vectors and routines without
692    /// checking that neon is available.
693    ///
694    /// # Safety
695    ///
696    /// Callers must guarantee that it is safe to execute `neon` instructions
697    /// in the current environment.
698    ///
699    /// Note that it is a common misconception that if one compiles for an
700    /// `x86_64` target, then they therefore automatically have access to neon
701    /// instructions. While this is almost always the case, it isn't true in
702    /// 100% of cases.
703    #[target_feature(enable = "neon")]
704    #[inline]
705    pub unsafe fn new_unchecked(
706        needle1: u8,
707        needle2: u8,
708        needle3: u8,
709    ) -> Three {
710        Three(generic::Three::new(needle1, needle2, needle3))
711    }
712
713    /// Returns true when this implementation is available in the current
714    /// environment.
715    ///
716    /// When this is true, it is guaranteed that [`Three::new`] will return
717    /// a `Some` value. Similarly, when it is false, it is guaranteed that
718    /// `Three::new` will return a `None` value.
719    ///
720    /// Note also that for the lifetime of a single program, if this returns
721    /// true then it will always return true.
722    #[inline]
723    pub fn is_available() -> bool {
724        #[cfg(target_feature = "neon")]
725        {
726            true
727        }
728        #[cfg(not(target_feature = "neon"))]
729        {
730            false
731        }
732    }
733
734    /// Return the first occurrence of one of the needle bytes in the given
735    /// haystack. If no such occurrence exists, then `None` is returned.
736    ///
737    /// The occurrence is reported as an offset into `haystack`. Its maximum
738    /// value is `haystack.len() - 1`.
739    #[inline]
740    pub fn find(&self, haystack: &[u8]) -> Option<usize> {
741        // SAFETY: `find_raw` guarantees that if a pointer is returned, it
742        // falls within the bounds of the start and end pointers.
743        unsafe {
744            generic::search_slice_with_raw(haystack, |s, e| {
745                self.find_raw(s, e)
746            })
747        }
748    }
749
750    /// Return the last occurrence of one of the needle bytes in the given
751    /// haystack. If no such occurrence exists, then `None` is returned.
752    ///
753    /// The occurrence is reported as an offset into `haystack`. Its maximum
754    /// value is `haystack.len() - 1`.
755    #[inline]
756    pub fn rfind(&self, haystack: &[u8]) -> Option<usize> {
757        // SAFETY: `rfind_raw` guarantees that if a pointer is returned, it
758        // falls within the bounds of the start and end pointers.
759        unsafe {
760            generic::search_slice_with_raw(haystack, |s, e| {
761                self.rfind_raw(s, e)
762            })
763        }
764    }
765
766    /// Like `find`, but accepts and returns raw pointers.
767    ///
768    /// When a match is found, the pointer returned is guaranteed to be
769    /// `>= start` and `< end`.
770    ///
771    /// This routine is useful if you're already using raw pointers and would
772    /// like to avoid converting back to a slice before executing a search.
773    ///
774    /// # Safety
775    ///
776    /// * Both `start` and `end` must be valid for reads.
777    /// * Both `start` and `end` must point to an initialized value.
778    /// * Both `start` and `end` must point to the same allocated object and
779    /// must either be in bounds or at most one byte past the end of the
780    /// allocated object.
781    /// * Both `start` and `end` must be _derived from_ a pointer to the same
782    /// object.
783    /// * The distance between `start` and `end` must not overflow `isize`.
784    /// * The distance being in bounds must not rely on "wrapping around" the
785    /// address space.
786    ///
787    /// Note that callers may pass a pair of pointers such that `start >= end`.
788    /// In that case, `None` will always be returned.
789    #[inline]
790    pub unsafe fn find_raw(
791        &self,
792        start: *const u8,
793        end: *const u8,
794    ) -> Option<*const u8> {
795        if start >= end {
796            return None;
797        }
798        if end.distance(start) < uint8x16_t::BYTES {
799            // SAFETY: We require the caller to pass valid start/end pointers.
800            return generic::fwd_byte_by_byte(start, end, |b| {
801                b == self.0.needle1()
802                    || b == self.0.needle2()
803                    || b == self.0.needle3()
804            });
805        }
806        // SAFETY: Building a `Three` means it's safe to call 'neon' routines.
807        // Also, we've checked that our haystack is big enough to run on the
808        // vector routine. Pointer validity is caller's responsibility.
809        self.find_raw_impl(start, end)
810    }
811
812    /// Like `rfind`, but accepts and returns raw pointers.
813    ///
814    /// When a match is found, the pointer returned is guaranteed to be
815    /// `>= start` and `< end`.
816    ///
817    /// This routine is useful if you're already using raw pointers and would
818    /// like to avoid converting back to a slice before executing a search.
819    ///
820    /// # Safety
821    ///
822    /// * Both `start` and `end` must be valid for reads.
823    /// * Both `start` and `end` must point to an initialized value.
824    /// * Both `start` and `end` must point to the same allocated object and
825    /// must either be in bounds or at most one byte past the end of the
826    /// allocated object.
827    /// * Both `start` and `end` must be _derived from_ a pointer to the same
828    /// object.
829    /// * The distance between `start` and `end` must not overflow `isize`.
830    /// * The distance being in bounds must not rely on "wrapping around" the
831    /// address space.
832    ///
833    /// Note that callers may pass a pair of pointers such that `start >= end`.
834    /// In that case, `None` will always be returned.
835    #[inline]
836    pub unsafe fn rfind_raw(
837        &self,
838        start: *const u8,
839        end: *const u8,
840    ) -> Option<*const u8> {
841        if start >= end {
842            return None;
843        }
844        if end.distance(start) < uint8x16_t::BYTES {
845            // SAFETY: We require the caller to pass valid start/end pointers.
846            return generic::rev_byte_by_byte(start, end, |b| {
847                b == self.0.needle1()
848                    || b == self.0.needle2()
849                    || b == self.0.needle3()
850            });
851        }
852        // SAFETY: Building a `Three` means it's safe to call 'neon' routines.
853        // Also, we've checked that our haystack is big enough to run on the
854        // vector routine. Pointer validity is caller's responsibility.
855        self.rfind_raw_impl(start, end)
856    }
857
858    /// Execute a search using neon vectors and routines.
859    ///
860    /// # Safety
861    ///
862    /// Same as [`Three::find_raw`], except the distance between `start` and
863    /// `end` must be at least the size of a neon vector (in bytes).
864    ///
865    /// (The target feature safety obligation is automatically fulfilled by
866    /// virtue of being a method on `Three`, which can only be constructed
867    /// when it is safe to call `neon` routines.)
868    #[target_feature(enable = "neon")]
869    #[inline]
870    unsafe fn find_raw_impl(
871        &self,
872        start: *const u8,
873        end: *const u8,
874    ) -> Option<*const u8> {
875        self.0.find_raw(start, end)
876    }
877
878    /// Execute a search using neon vectors and routines.
879    ///
880    /// # Safety
881    ///
882    /// Same as [`Three::rfind_raw`], except the distance between `start` and
883    /// `end` must be at least the size of a neon vector (in bytes).
884    ///
885    /// (The target feature safety obligation is automatically fulfilled by
886    /// virtue of being a method on `Three`, which can only be constructed
887    /// when it is safe to call `neon` routines.)
888    #[target_feature(enable = "neon")]
889    #[inline]
890    unsafe fn rfind_raw_impl(
891        &self,
892        start: *const u8,
893        end: *const u8,
894    ) -> Option<*const u8> {
895        self.0.rfind_raw(start, end)
896    }
897
898    /// Returns an iterator over all occurrences of the needle byte in the
899    /// given haystack.
900    ///
901    /// The iterator returned implements `DoubleEndedIterator`. This means it
902    /// can also be used to find occurrences in reverse order.
903    #[inline]
904    pub fn iter<'a, 'h>(&'a self, haystack: &'h [u8]) -> ThreeIter<'a, 'h> {
905        ThreeIter { searcher: self, it: generic::Iter::new(haystack) }
906    }
907}
908
/// An iterator over all occurrences of three possible bytes in a haystack.
///
/// This iterator implements `DoubleEndedIterator`, which means it can also be
/// used to find occurrences in reverse order.
///
/// This iterator is created by the [`Three::iter`] method.
///
/// The lifetime parameters are as follows:
///
/// * `'a` refers to the lifetime of the underlying [`Three`] searcher.
/// * `'h` refers to the lifetime of the haystack being searched.
#[derive(Clone, Debug)]
pub struct ThreeIter<'a, 'h> {
    /// The searcher whose `find_raw` / `rfind_raw` routines are invoked on
    /// every forward / backward step.
    searcher: &'a Three,
    /// Generic iteration state over the haystack. It supplies the start/end
    /// pointers passed to the searcher and yields each match as a `usize`.
    it: generic::Iter<'h>,
}
925
926impl<'a, 'h> Iterator for ThreeIter<'a, 'h> {
927    type Item = usize;
928
929    #[inline]
930    fn next(&mut self) -> Option<usize> {
931        // SAFETY: We rely on the generic iterator to provide valid start
932        // and end pointers, but we guarantee that any pointer returned by
933        // 'find_raw' falls within the bounds of the start and end pointer.
934        unsafe { self.it.next(|s, e| self.searcher.find_raw(s, e)) }
935    }
936
937    #[inline]
938    fn size_hint(&self) -> (usize, Option<usize>) {
939        self.it.size_hint()
940    }
941}
942
943impl<'a, 'h> DoubleEndedIterator for ThreeIter<'a, 'h> {
944    #[inline]
945    fn next_back(&mut self) -> Option<usize> {
946        // SAFETY: We rely on the generic iterator to provide valid start
947        // and end pointers, but we guarantee that any pointer returned by
948        // 'rfind_raw' falls within the bounds of the start and end pointer.
949        unsafe { self.it.next_back(|s, e| self.searcher.rfind_raw(s, e)) }
950    }
951}
952
953impl<'a, 'h> core::iter::FusedIterator for ThreeIter<'a, 'h> {}
954
#[cfg(test)]
mod tests {
    use super::*;

    define_memchr_quickcheck!(super);

    // Each test hands a closure to the shared `Runner`. A closure returns
    // `None` when a searcher cannot be built (e.g. neon is unavailable) or
    // when fewer needles are supplied than the searcher requires.
    // NOTE(review): presumably the runner treats `None` as a skipped case;
    // confirm in `crate::tests::memchr::Runner`.

    #[test]
    fn forward_one() {
        crate::tests::memchr::Runner::new(1).forward_iter(
            |haystack, needles| {
                let searcher = One::new(needles[0])?;
                Some(searcher.iter(haystack).collect())
            },
        )
    }

    #[test]
    fn reverse_one() {
        crate::tests::memchr::Runner::new(1).reverse_iter(
            |haystack, needles| {
                let searcher = One::new(needles[0])?;
                Some(searcher.iter(haystack).rev().collect())
            },
        )
    }

    #[test]
    fn count_one() {
        crate::tests::memchr::Runner::new(1).count_iter(|haystack, needles| {
            let searcher = One::new(needles[0])?;
            Some(searcher.iter(haystack).count())
        })
    }

    #[test]
    fn forward_two() {
        crate::tests::memchr::Runner::new(2).forward_iter(
            |haystack, needles| {
                let searcher = Two::new(*needles.get(0)?, *needles.get(1)?)?;
                Some(searcher.iter(haystack).collect())
            },
        )
    }

    #[test]
    fn reverse_two() {
        crate::tests::memchr::Runner::new(2).reverse_iter(
            |haystack, needles| {
                let searcher = Two::new(*needles.get(0)?, *needles.get(1)?)?;
                Some(searcher.iter(haystack).rev().collect())
            },
        )
    }

    #[test]
    fn forward_three() {
        crate::tests::memchr::Runner::new(3).forward_iter(
            |haystack, needles| {
                let searcher = Three::new(
                    *needles.get(0)?,
                    *needles.get(1)?,
                    *needles.get(2)?,
                )?;
                Some(searcher.iter(haystack).collect())
            },
        )
    }

    #[test]
    fn reverse_three() {
        crate::tests::memchr::Runner::new(3).reverse_iter(
            |haystack, needles| {
                let searcher = Three::new(
                    *needles.get(0)?,
                    *needles.get(1)?,
                    *needles.get(2)?,
                )?;
                Some(searcher.iter(haystack).rev().collect())
            },
        )
    }
}