bitcode/derive/vec.rs

1use crate::coder::{Buffer, Decoder, Encoder, Result, View, MAX_VECTORED_CHUNK};
2use crate::derive::{Decode, Encode};
3use crate::fast::Unaligned;
4use crate::length::{LengthDecoder, LengthEncoder};
5use alloc::collections::{BTreeSet, BinaryHeap, LinkedList, VecDeque};
6use alloc::vec::Vec;
7use core::mem::MaybeUninit;
8use core::num::NonZeroUsize;
9
10#[cfg(feature = "std")]
11use core::hash::{BuildHasher, Hash};
12#[cfg(feature = "std")]
13use std::collections::HashSet;
14
15pub struct VecEncoder<T: Encode> {
16    // pub(crate) for arrayvec.rs
17    pub(crate) lengths: LengthEncoder,
18    pub(crate) elements: T::Encoder,
19    vectored_impl: Option<fn()>,
20}
21
22// Can't derive since it would bound T: Default.
23impl<T: Encode> Default for VecEncoder<T> {
24    fn default() -> Self {
25        Self {
26            lengths: Default::default(),
27            elements: Default::default(),
28            vectored_impl: Default::default(),
29        }
30    }
31}
32
33impl<T: Encode> Buffer for VecEncoder<T> {
34    fn collect_into(&mut self, out: &mut Vec<u8>) {
35        self.lengths.collect_into(out);
36        self.elements.collect_into(out);
37    }
38
39    fn reserve(&mut self, additional: NonZeroUsize) {
40        self.lengths.reserve(additional);
41        // We don't know the lengths of the vectors, so we can't reserve more.
42    }
43}
44
/// Copies `N` or `n` elements of type `T` from `src` to `dst`, depending on whether reading a
/// whole `[T; N]` starting at `src` stays within a single memory page.
///
/// When it does, a fixed-size `[T; N]` is copied even though only the first `n` elements are
/// meaningful; the over-read cannot touch another page, see
/// <https://stackoverflow.com/questions/37800739/is-it-safe-to-read-past-the-end-of-a-buffer-within-the-same-page-on-x86-and-x64>.
/// Otherwise exactly `n` elements are copied on a `#[cold]` path.
///
/// # Safety
/// Same as [`std::ptr::copy_nonoverlapping`] but with the additional requirements that
/// `n != 0 && n <= N` and `dst` has room for a `[T; N]`.
///
/// Is a macro instead of an `#[inline(always)] fn` because it optimizes better.
macro_rules! unsafe_wild_copy {
    // Conceptually: pub unsafe fn wild_copy<T, const N: usize>(src: *const T, dst: *mut T, n: usize)
    ([$T:ident; $N:ident], $src:ident, $dst:ident, $n:ident) => {
        debug_assert!($n != 0 && $n <= $N);

        let page_size: usize = 4096;
        let read_size = core::mem::size_of::<[$T; $N]>();
        let wild_read_allowed = cfg!(all(
            // Miri doesn't like this.
            not(miri),
            // cargo fuzz's memory sanitizer complains about buffer overrun.
            // Without nightly we can't detect memory sanitizers, so we check debug_assertions.
            not(debug_assertions),
            // x86/x86_64/aarch64 all have min page size of 4096, so reading past the end of a non-empty
            // buffer won't page fault.
            any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")
        ));
        // `read_size < page_size` must hold before `page_size - read_size` is computed: a read of
        // a page or more never fits inside one page, and the subtraction would otherwise
        // underflow (a panic in debug builds; in release a wrapped bound that lets an oversized
        // read cross into the next page).
        let within_page = wild_read_allowed
            && read_size < page_size
            && ($src as usize & (page_size - 1)) < page_size - read_size;

        if within_page {
            // Whole-array copy; `dst` has room for a `[T; N]` per the safety contract.
            *($dst as *mut core::mem::MaybeUninit<[$T; $N]>) =
                core::ptr::read($src as *const core::mem::MaybeUninit<[$T; $N]>);
        } else {
            #[cold]
            unsafe fn cold<T>(src: *const T, dst: *mut T, n: usize) {
                src.copy_to_nonoverlapping(dst, n);
            }
            cold($src, $dst, $n);
        }
    }
}
pub(crate) use unsafe_wild_copy;
81
82impl<T: Encode> VecEncoder<T> {
83    /// Copy fixed size slices. Much faster than memcpy.
84    #[inline(never)]
85    fn encode_vectored_max_len<'a, I: Iterator<Item = &'a [T]> + Clone, const N: usize>(
86        &mut self,
87        i: I,
88    ) where
89        T: 'a,
90    {
91        unsafe {
92            let primitives = self.elements.as_primitive().unwrap();
93            primitives.reserve(i.size_hint().1.unwrap() * N);
94
95            let mut dst = primitives.end_ptr();
96            if self.lengths.encode_vectored_max_len::<_, N>(
97                i.clone(),
98                #[inline(always)]
99                |s| {
100                    let src = s.as_ptr();
101                    let n = s.len();
102                    // Safety: encode_vectored_max_len skips len == 0 and ensures len <= N.
103                    // `dst` has enough space for `[T; N]` because we've reserved size_hint * N.
104                    unsafe_wild_copy!([T; N], src, dst, n);
105                    dst = dst.add(n);
106                },
107            ) {
108                // Use fallback for impls that copy more than 64 bytes.
109                let size = core::mem::size_of::<T>();
110                self.vectored_impl = core::mem::transmute(match N {
111                    1 if size <= 32 => Self::encode_vectored_max_len::<I, 2>,
112                    2 if size <= 16 => Self::encode_vectored_max_len::<I, 4>,
113                    4 if size <= 8 => Self::encode_vectored_max_len::<I, 8>,
114                    8 if size <= 4 => Self::encode_vectored_max_len::<I, 16>,
115                    16 if size <= 2 => Self::encode_vectored_max_len::<I, 32>,
116                    32 if size <= 1 => Self::encode_vectored_max_len::<I, 64>,
117                    _ => Self::encode_vectored_fallback::<I>,
118                } as fn(&mut Self, I));
119                let f: fn(&mut Self, I) = core::mem::transmute(self.vectored_impl);
120                f(self, i);
121                return;
122            }
123            primitives.set_end_ptr(dst);
124        }
125    }
126
127    /// Fallback for when length > [`Self::encode_vectored_max_len`]'s max_len.
128    #[inline(never)]
129    fn encode_vectored_fallback<'a, I: Iterator<Item = &'a [T]>>(&mut self, i: I)
130    where
131        T: 'a,
132    {
133        let primitives = self.elements.as_primitive().unwrap();
134        self.lengths.encode_vectored_fallback(i, |s| unsafe {
135            let n = s.len();
136            primitives.reserve(n);
137            let ptr = primitives.end_ptr();
138            s.as_ptr().copy_to_nonoverlapping(ptr, n);
139            primitives.set_end_ptr(ptr.add(n));
140        });
141    }
142}
143
144impl<T: Encode> Encoder<[T]> for VecEncoder<T> {
145    #[inline(always)]
146    fn encode(&mut self, v: &[T]) {
147        let n = v.len();
148        self.lengths.encode(&n);
149
150        if let Some(primitive) = self.elements.as_primitive() {
151            primitive.reserve(n);
152            unsafe {
153                let ptr = primitive.end_ptr();
154                v.as_ptr().copy_to_nonoverlapping(ptr, n);
155                primitive.set_end_ptr(ptr.add(n));
156            }
157        } else if let Some(n) = NonZeroUsize::new(n) {
158            self.elements.reserve(n);
159            // Uses chunks to keep everything in the CPU cache. TODO pick optimal chunk size.
160            for chunk in v.chunks(MAX_VECTORED_CHUNK) {
161                self.elements.encode_vectored(chunk.iter());
162            }
163        }
164    }
165
166    #[inline(always)]
167    fn encode_vectored<'a>(&mut self, i: impl Iterator<Item = &'a [T]> + Clone)
168    where
169        [T]: 'a,
170    {
171        if self.elements.as_primitive().is_some() {
172            /// Convert impl trait to named generic type.
173            #[inline(always)]
174            fn inner<'a, T: Encode + 'a, I: Iterator<Item = &'a [T]> + Clone>(
175                me: &mut VecEncoder<T>,
176                i: I,
177            ) {
178                unsafe {
179                    // We can't set this in the Default constructor because we don't have the type I.
180                    if me.vectored_impl.is_none() {
181                        // Use match to avoid "use of generic parameter from outer function".
182                        // Start at the pointer size (assumed to be 8 bytes) to not be wasteful.
183                        me.vectored_impl =
184                            core::mem::transmute(match (8 / core::mem::size_of::<T>()).max(1) {
185                                1 => VecEncoder::encode_vectored_max_len::<I, 1>,
186                                2 => VecEncoder::encode_vectored_max_len::<I, 2>,
187                                4 => VecEncoder::encode_vectored_max_len::<I, 4>,
188                                8 => VecEncoder::encode_vectored_max_len::<I, 8>,
189                                _ => unreachable!(),
190                            }
191                                as fn(&mut VecEncoder<T>, I));
192                    }
193                    let f: fn(&mut VecEncoder<T>, I) = core::mem::transmute(me.vectored_impl);
194                    f(me, i);
195                }
196            }
197            inner(self, i);
198        } else {
199            for v in i {
200                self.encode(v);
201            }
202        }
203    }
204}
205
206pub struct VecDecoder<'a, T: Decode<'a>> {
207    // pub(crate) for arrayvec::ArrayVec.
208    pub(crate) lengths: LengthDecoder<'a>,
209    pub(crate) elements: T::Decoder,
210}
211
212// Can't derive since it would bound T: Default.
213impl<'a, T: Decode<'a>> Default for VecDecoder<'a, T> {
214    fn default() -> Self {
215        Self {
216            lengths: Default::default(),
217            elements: Default::default(),
218        }
219    }
220}
221
222impl<'a, T: Decode<'a>> View<'a> for VecDecoder<'a, T> {
223    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
224        self.lengths.populate(input, length)?;
225        self.elements.populate(input, self.lengths.length())
226    }
227}
228
/// Generates `encode` for a collection type with `len()` and by-reference iteration:
/// writes the length, then encodes each element in iteration order.
macro_rules! encode_body {
    ($t:ty) => {
        #[inline(always)]
        fn encode(&mut self, collection: &$t) {
            let len = collection.len();
            self.lengths.encode(&len);
            let len = match NonZeroUsize::new(len) {
                Some(len) => len,
                None => return,
            };
            self.elements.reserve(len);
            for element in collection {
                self.elements.encode(element);
            }
        }
    };
}
// Faster on some collections.
/// Like `encode_body`, but visits elements with `iter().for_each(..)` (internal iteration).
macro_rules! encode_body_internal_iteration {
    ($t:ty) => {
        #[inline(always)]
        fn encode(&mut self, collection: &$t) {
            let len = collection.len();
            self.lengths.encode(&len);
            let len = match NonZeroUsize::new(len) {
                Some(len) => len,
                None => return,
            };
            self.elements.reserve(len);
            let elements = &mut self.elements;
            collection.iter().for_each(|element| elements.encode(element));
        }
    };
}
/// Generates `decode` by decoding the next length, then collecting that many decoded
/// elements with `FromIterator`.
macro_rules! decode_body {
    ($t:ty) => {
        #[inline(always)]
        fn decode(&mut self) -> $t {
            // - BTreeSet::from_iter is faster than BTreeSet::insert (see comment in map.rs).
            // - HashSet is about the same either way.
            // - Vec::from_iter is slower (so it doesn't use this).
            let count = self.lengths.decode();
            let elements = &mut self.elements;
            (0..count).map(|_| elements.decode()).collect()
        }
    };
}
271
272impl<T: Encode> Encoder<Vec<T>> for VecEncoder<T> {
273    #[inline(always)]
274    fn encode(&mut self, v: &Vec<T>) {
275        self.encode(v.as_slice());
276    }
277
278    #[inline(always)]
279    fn encode_vectored<'a>(&mut self, i: impl Iterator<Item = &'a Vec<T>> + Clone)
280    where
281        Vec<T>: 'a,
282    {
283        self.encode_vectored(i.map(Vec::as_slice));
284    }
285}
286impl<'a, T: Decode<'a>> Decoder<'a, Vec<T>> for VecDecoder<'a, T> {
287    #[inline(always)]
288    fn decode_in_place(&mut self, out: &mut MaybeUninit<Vec<T>>) {
289        let length = self.lengths.decode();
290        // Fast path, avoid memcpy and mutating len.
291        if length == 0 {
292            out.write(Vec::new());
293            return;
294        }
295
296        let v = out.write(Vec::with_capacity(length));
297        if let Some(primitive) = self.elements.as_primitive() {
298            unsafe {
299                primitive
300                    .as_ptr()
301                    .copy_to_nonoverlapping(v.as_mut_ptr() as *mut Unaligned<T>, length);
302                primitive.advance(length);
303            }
304        } else {
305            let spare = v.spare_capacity_mut();
306            for i in 0..length {
307                let out = unsafe { spare.get_unchecked_mut(i) };
308                self.elements.decode_in_place(out);
309            }
310        }
311        unsafe { v.set_len(length) };
312    }
313}
314
315impl<T: Encode> Encoder<BinaryHeap<T>> for VecEncoder<T> {
316    encode_body!(BinaryHeap<T>); // When BinaryHeap::as_slice is stable use [T] impl.
317}
318impl<'a, T: Decode<'a> + Ord> Decoder<'a, BinaryHeap<T>> for VecDecoder<'a, T> {
319    #[inline(always)]
320    fn decode(&mut self) -> BinaryHeap<T> {
321        let v: Vec<T> = self.decode();
322        v.into()
323    }
324}
325
326impl<T: Encode> Encoder<BTreeSet<T>> for VecEncoder<T> {
327    encode_body!(BTreeSet<T>);
328}
329impl<'a, T: Decode<'a> + Ord> Decoder<'a, BTreeSet<T>> for VecDecoder<'a, T> {
330    decode_body!(BTreeSet<T>);
331}
332
#[cfg(feature = "std")]
impl<T, S> Encoder<HashSet<T, S>> for VecEncoder<T>
where
    T: Encode,
{
    // Internal iteration is 1.6x faster. Interestingly this does not apply to HashMap<T, ()> which
    // I assume is due to HashSet::iter being implemented with HashMap::keys.
    encode_body_internal_iteration!(HashSet<T, S>);
}
#[cfg(feature = "std")]
impl<'a, T, S> Decoder<'a, HashSet<T, S>> for VecDecoder<'a, T>
where
    T: Decode<'a> + Eq + Hash,
    S: BuildHasher + Default,
{
    decode_body!(HashSet<T, S>);
}
345
346impl<T: Encode> Encoder<LinkedList<T>> for VecEncoder<T> {
347    encode_body!(LinkedList<T>);
348}
349impl<'a, T: Decode<'a>> Decoder<'a, LinkedList<T>> for VecDecoder<'a, T> {
350    decode_body!(LinkedList<T>);
351}
352
353impl<T: Encode> Encoder<VecDeque<T>> for VecEncoder<T> {
354    encode_body_internal_iteration!(VecDeque<T>); // Internal iteration is 10x faster.
355}
356impl<'a, T: Decode<'a>> Decoder<'a, VecDeque<T>> for VecDecoder<'a, T> {
357    #[inline(always)]
358    fn decode(&mut self) -> VecDeque<T> {
359        let v: Vec<T> = self.decode();
360        v.into()
361    }
362}
363
#[cfg(test)]
mod test {
    use alloc::collections::*;
    use alloc::vec::Vec;

    /// Every `u8` value exactly once, collected into the collection under test.
    fn bench_data<T: FromIterator<u8>>() -> T {
        (0..=u8::MAX).collect()
    }

    crate::bench_encode_decode!(
        btree_set: BTreeSet<_>,
        linked_list: LinkedList<_>,
        vec: Vec<_>,
        vec_deque: VecDeque<_>
    );
    #[cfg(feature = "std")]
    crate::bench_encode_decode!(hash_set: std::collections::HashSet<_>);

    // BinaryHeap can't use bench_encode_decode because it doesn't implement PartialEq.
    #[bench]
    fn bench_binary_heap_decode(b: &mut test::Bencher) {
        type T = BinaryHeap<u8>;
        let original: T = bench_data();
        let bytes = crate::encode(&original);
        b.iter(|| {
            let decoded = crate::decode::<T>(&bytes).unwrap();
            // Compared element-wise in iteration order.
            debug_assert!(original.iter().eq(decoded.iter()));
            decoded
        })
    }
}