1use crate::coder::{Buffer, Decoder, Encoder, Result, View, MAX_VECTORED_CHUNK};
2use crate::derive::{Decode, Encode};
3use crate::fast::Unaligned;
4use crate::length::{LengthDecoder, LengthEncoder};
5use alloc::collections::{BTreeSet, BinaryHeap, LinkedList, VecDeque};
6use alloc::vec::Vec;
7use core::mem::MaybeUninit;
8use core::num::NonZeroUsize;
9
10#[cfg(feature = "std")]
11use core::hash::{BuildHasher, Hash};
12#[cfg(feature = "std")]
13use std::collections::HashSet;
14
/// Encoder for slice-like collections: the element count of each slice goes
/// through `lengths`, the flattened element payloads through `elements`.
pub struct VecEncoder<T: Encode> {
    // Encodes the length of each encoded slice.
    pub(crate) lengths: LengthEncoder,
    // Encodes the flattened elements of all slices.
    pub(crate) elements: T::Encoder,
    // Cached monomorphized fast path for `encode_vectored`, stored type-erased.
    // The real signature is `fn(&mut Self, I)` for a concrete iterator type `I`;
    // it is transmuted back before being called (see `encode_vectored`).
    vectored_impl: Option<fn()>,
}
21
22impl<T: Encode> Default for VecEncoder<T> {
24 fn default() -> Self {
25 Self {
26 lengths: Default::default(),
27 elements: Default::default(),
28 vectored_impl: Default::default(),
29 }
30 }
31}
32
impl<T: Encode> Buffer for VecEncoder<T> {
    fn collect_into(&mut self, out: &mut Vec<u8>) {
        // Lengths are emitted before elements; the decoder relies on this
        // order (see `VecDecoder::populate`).
        self.lengths.collect_into(out);
        self.elements.collect_into(out);
    }

    fn reserve(&mut self, additional: NonZeroUsize) {
        self.lengths.reserve(additional);
        // NOTE(review): `elements` is deliberately not reserved here — the
        // total element count isn't known until each slice's length is seen;
        // element space is reserved in the individual encode paths instead.
    }
}
44
/// Copies `n` elements from `$src` to `$dst` where `1 <= n <= N`.
///
/// Fast path: performs one fixed-size `[T; N]` read/write, which may read up
/// to `N - n` elements *past* the end of `$src` (a "wild" read). The caller
/// must guarantee `$dst` has room for `N` elements; the over-read is kept from
/// faulting by only taking the fast path when the whole `[T; N]` read stays
/// within the 4 KiB page containing `$src`.
macro_rules! unsafe_wild_copy {
    ([$T:ident; $N:ident], $src:ident, $dst:ident, $n:ident) => {
        debug_assert!($n != 0 && $n <= $N);

        let page_size = 4096;
        let read_size = core::mem::size_of::<[$T; $N]>();
        // Fast path only when the wild read cannot cross the page boundary,
        // and only on configs where this trick is intended to run: never under
        // Miri or debug builds, and only on the listed architectures.
        let within_page = $src as usize & (page_size - 1) < (page_size - read_size) && cfg!(all(
            not(miri),
            not(debug_assertions),
            any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")
        ));

        if within_page {
            // Single fixed-size copy; the trailing `N - n` copied elements are
            // garbage that the caller's length bookkeeping ignores.
            // `MaybeUninit` is used so the possibly-uninitialized tail is not
            // read as a typed `[$T; $N]` value.
            *($dst as *mut core::mem::MaybeUninit<[$T; $N]>) = core::ptr::read($src as *const core::mem::MaybeUninit<[$T; $N]>);
        } else {
            // Slow path: exact-length copy, also used near page boundaries.
            // `#[cold]` keeps it out of the hot code layout.
            #[cold]
            unsafe fn cold<T>(src: *const T, dst: *mut T, n: usize) {
                src.copy_to_nonoverlapping(dst, n);
            }
            cold($src, $dst, $n);
        }
    }
}
pub(crate) use unsafe_wild_copy;
81
impl<T: Encode> VecEncoder<T> {
    /// Fast path for encoding many slices that are all of length `<= N`: each
    /// slice is copied with a single fixed-size wild copy of `N` elements.
    /// When a slice longer than `N` is encountered, escalates to a larger `N`
    /// (or to [`Self::encode_vectored_fallback`]) and re-encodes.
    #[inline(never)]
    fn encode_vectored_max_len<'a, I: Iterator<Item = &'a [T]> + Clone, const N: usize>(
        &mut self,
        i: I,
    ) where
        T: 'a,
    {
        unsafe {
            // Caller guarantees elements are primitive (checked in `encode_vectored`).
            let primitives = self.elements.as_primitive().unwrap();
            // Worst case: every slice at max length N.
            // NOTE(review): requires `size_hint().1` to be `Some` — assumes
            // callers only pass iterators with a known upper bound (e.g.
            // bounded chunks); TODO confirm at call sites.
            primitives.reserve(i.size_hint().1.unwrap() * N);

            let mut dst = primitives.end_ptr();
            if self.lengths.encode_vectored_max_len::<_, N>(
                i.clone(),
                #[inline(always)]
                |s| {
                    let src = s.as_ptr();
                    let n = s.len();
                    // SAFETY: the reserve above leaves room for N elements per
                    // slice, and this callback is only reached while n <= N.
                    unsafe_wild_copy!([T; N], src, dst, n);
                    dst = dst.add(n);
                },
            ) {
                // `true` => a slice with len > N was found. Pick a bigger N
                // (doubling while the N-element copy stays <= 64 bytes) or the
                // exact-copy fallback, cache it, and re-encode the iterator.
                // NOTE(review): assumes the aborted length pass left no
                // partial output behind — verify in `LengthEncoder`.
                let size = core::mem::size_of::<T>();
                self.vectored_impl = core::mem::transmute(match N {
                    1 if size <= 32 => Self::encode_vectored_max_len::<I, 2>,
                    2 if size <= 16 => Self::encode_vectored_max_len::<I, 4>,
                    4 if size <= 8 => Self::encode_vectored_max_len::<I, 8>,
                    8 if size <= 4 => Self::encode_vectored_max_len::<I, 16>,
                    16 if size <= 2 => Self::encode_vectored_max_len::<I, 32>,
                    32 if size <= 1 => Self::encode_vectored_max_len::<I, 64>,
                    _ => Self::encode_vectored_fallback::<I>,
                } as fn(&mut Self, I));
                // Round-trip through the type-erased field so future calls
                // reuse the escalated implementation, then call it now.
                let f: fn(&mut Self, I) = core::mem::transmute(self.vectored_impl);
                f(self, i);
                return;
            }
            // Every slice fit in N: commit the advanced write cursor.
            primitives.set_end_ptr(dst);
        }
    }

    /// Slow path used once slices exceed the largest supported max length:
    /// per-slice reserve + exact-length copy (no wild reads).
    #[inline(never)]
    fn encode_vectored_fallback<'a, I: Iterator<Item = &'a [T]>>(&mut self, i: I)
    where
        T: 'a,
    {
        let primitives = self.elements.as_primitive().unwrap();
        self.lengths.encode_vectored_fallback(i, |s| unsafe {
            let n = s.len();
            primitives.reserve(n);
            let ptr = primitives.end_ptr();
            // SAFETY: `reserve(n)` above guarantees room for `n` elements.
            s.as_ptr().copy_to_nonoverlapping(ptr, n);
            primitives.set_end_ptr(ptr.add(n));
        });
    }
}
143
impl<T: Encode> Encoder<[T]> for VecEncoder<T> {
    #[inline(always)]
    fn encode(&mut self, v: &[T]) {
        let n = v.len();
        self.lengths.encode(&n);

        if let Some(primitive) = self.elements.as_primitive() {
            // Primitive elements: single bulk copy into the element buffer.
            primitive.reserve(n);
            unsafe {
                let ptr = primitive.end_ptr();
                // SAFETY: `reserve(n)` above guarantees space for `n` elements.
                v.as_ptr().copy_to_nonoverlapping(ptr, n);
                primitive.set_end_ptr(ptr.add(n));
            }
        } else if let Some(n) = NonZeroUsize::new(n) {
            // Non-primitive elements: encode in bounded chunks through the
            // element encoder's own vectored path.
            self.elements.reserve(n);
            for chunk in v.chunks(MAX_VECTORED_CHUNK) {
                self.elements.encode_vectored(chunk.iter());
            }
        }
        // Empty non-primitive slice: only the length is written.
    }

    #[inline(always)]
    fn encode_vectored<'a>(&mut self, i: impl Iterator<Item = &'a [T]> + Clone)
    where
        [T]: 'a,
    {
        if self.elements.as_primitive().is_some() {
            // Inner fn gives the iterator a nameable type parameter `I` so a
            // concrete `encode_vectored_max_len::<I, N>` instantiation can be
            // cached as a function pointer.
            #[inline(always)]
            fn inner<'a, T: Encode + 'a, I: Iterator<Item = &'a [T]> + Clone>(
                me: &mut VecEncoder<T>,
                i: I,
            ) {
                unsafe {
                    if me.vectored_impl.is_none() {
                        // Initial max slice length: as many elements as fit in
                        // ~8 bytes, at least 1.
                        me.vectored_impl =
                            core::mem::transmute(match (8 / core::mem::size_of::<T>()).max(1) {
                                1 => VecEncoder::encode_vectored_max_len::<I, 1>,
                                2 => VecEncoder::encode_vectored_max_len::<I, 2>,
                                4 => VecEncoder::encode_vectored_max_len::<I, 4>,
                                8 => VecEncoder::encode_vectored_max_len::<I, 8>,
                                _ => unreachable!(),
                            }
                            as fn(&mut VecEncoder<T>, I));
                    }
                    // SAFETY(review): the stored pointer was transmuted from
                    // `fn(&mut VecEncoder<T>, I)` for this same `I` — this
                    // assumes each encoder instance is only ever used with one
                    // iterator type; TODO confirm at call sites.
                    let f: fn(&mut VecEncoder<T>, I) = core::mem::transmute(me.vectored_impl);
                    f(me, i);
                }
            }
            inner(self, i);
        } else {
            // Non-primitive elements: no bulk-copy fast path; encode each
            // slice individually.
            for v in i {
                self.encode(v);
            }
        }
    }
}
205
/// Decoder counterpart of [`VecEncoder`]: `lengths` yields the element count
/// of each collection, `elements` decodes the flattened elements.
pub struct VecDecoder<'a, T: Decode<'a>> {
    // Decodes the length of each collection.
    pub(crate) lengths: LengthDecoder<'a>,
    // Decodes the flattened elements of all collections.
    pub(crate) elements: T::Decoder,
}
211
212impl<'a, T: Decode<'a>> Default for VecDecoder<'a, T> {
214 fn default() -> Self {
215 Self {
216 lengths: Default::default(),
217 elements: Default::default(),
218 }
219 }
220}
221
222impl<'a, T: Decode<'a>> View<'a> for VecDecoder<'a, T> {
223 fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
224 self.lengths.populate(input, length)?;
225 self.elements.populate(input, self.lengths.length())
226 }
227}
228
/// Generates `Encoder::encode` for a length-prefixed collection: writes the
/// element count, then encodes each element via external iteration (`for`).
macro_rules! encode_body {
    ($t:ty) => {
        #[inline(always)]
        fn encode(&mut self, v: &$t) {
            let n = v.len();
            self.lengths.encode(&n);
            // Skip the reserve entirely for empty collections.
            if let Some(n) = NonZeroUsize::new(n) {
                self.elements.reserve(n);
                for v in v {
                    self.elements.encode(v);
                }
            }
        }
    };
}
/// Like `encode_body!` but drives the elements with `for_each` (internal
/// iteration) instead of a `for` loop — presumably because internal iteration
/// optimizes better for the collections this is applied to (`VecDeque`,
/// `HashSet`); TODO confirm with benchmarks.
macro_rules! encode_body_internal_iteration {
    ($t:ty) => {
        #[inline(always)]
        fn encode(&mut self, v: &$t) {
            let n = v.len();
            self.lengths.encode(&n);
            // Skip the reserve entirely for empty collections.
            if let Some(n) = NonZeroUsize::new(n) {
                self.elements.reserve(n);
                v.iter().for_each(|v| self.elements.encode(v));
            }
        }
    };
}
/// Generates `Decoder::decode` for a collection: reads the element count,
/// decodes that many elements, and collects them via `FromIterator`.
macro_rules! decode_body {
    ($t:ty) => {
        #[inline(always)]
        fn decode(&mut self) -> $t {
            (0..self.lengths.decode())
                .map(|_| self.elements.decode())
                .collect()
        }
    };
}
271
272impl<T: Encode> Encoder<Vec<T>> for VecEncoder<T> {
273 #[inline(always)]
274 fn encode(&mut self, v: &Vec<T>) {
275 self.encode(v.as_slice());
276 }
277
278 #[inline(always)]
279 fn encode_vectored<'a>(&mut self, i: impl Iterator<Item = &'a Vec<T>> + Clone)
280 where
281 Vec<T>: 'a,
282 {
283 self.encode_vectored(i.map(Vec::as_slice));
284 }
285}
impl<'a, T: Decode<'a>> Decoder<'a, Vec<T>> for VecDecoder<'a, T> {
    #[inline(always)]
    fn decode_in_place(&mut self, out: &mut MaybeUninit<Vec<T>>) {
        let length = self.lengths.decode();
        if length == 0 {
            // Avoid allocating for empty vectors.
            out.write(Vec::new());
            return;
        }

        let v = out.write(Vec::with_capacity(length));
        if let Some(primitive) = self.elements.as_primitive() {
            unsafe {
                // Primitive elements: bulk copy straight from the input.
                // The destination is viewed as `Unaligned<T>` since the source
                // bytes carry no alignment guarantee.
                // NOTE(review): presumably `populate` already validated that
                // `length` elements are available — confirm in the primitive
                // decoder before trusting this copy.
                primitive
                    .as_ptr()
                    .copy_to_nonoverlapping(v.as_mut_ptr() as *mut Unaligned<T>, length);
                primitive.advance(length);
            }
        } else {
            // Decode each element directly into the Vec's spare capacity,
            // avoiding a per-element move.
            let spare = v.spare_capacity_mut();
            for i in 0..length {
                // SAFETY: i < length <= capacity reserved above.
                let out = unsafe { spare.get_unchecked_mut(i) };
                self.elements.decode_in_place(out);
            }
        }
        // SAFETY: all `length` elements were initialized by one of the
        // branches above.
        unsafe { v.set_len(length) };
    }
}
314
impl<T: Encode> Encoder<BinaryHeap<T>> for VecEncoder<T> {
    // Elements are encoded in the heap's (arbitrary) iteration order.
    encode_body!(BinaryHeap<T>);
}
318impl<'a, T: Decode<'a> + Ord> Decoder<'a, BinaryHeap<T>> for VecDecoder<'a, T> {
319 #[inline(always)]
320 fn decode(&mut self) -> BinaryHeap<T> {
321 let v: Vec<T> = self.decode();
322 v.into()
323 }
324}
325
impl<T: Encode> Encoder<BTreeSet<T>> for VecEncoder<T> {
    // Elements are encoded in ascending order (BTreeSet iteration order).
    encode_body!(BTreeSet<T>);
}
impl<'a, T: Decode<'a> + Ord> Decoder<'a, BTreeSet<T>> for VecDecoder<'a, T> {
    // Rebuilds the set by collecting decoded elements; requires `Ord`.
    decode_body!(BTreeSet<T>);
}
332
#[cfg(feature = "std")]
impl<T: Encode, S> Encoder<HashSet<T, S>> for VecEncoder<T> {
    // `S` is unconstrained: encoding only iterates the set, it never hashes.
    encode_body_internal_iteration!(HashSet<T, S>);
}
#[cfg(feature = "std")]
impl<'a, T: Decode<'a> + Eq + Hash, S: BuildHasher + Default> Decoder<'a, HashSet<T, S>>
    for VecDecoder<'a, T>
{
    // Decoding inserts into a fresh set, so the full `Eq + Hash + BuildHasher`
    // bounds are needed here (unlike the encoder above).
    decode_body!(HashSet<T, S>);
}
345
impl<T: Encode> Encoder<LinkedList<T>> for VecEncoder<T> {
    // Elements are encoded in list order (front to back).
    encode_body!(LinkedList<T>);
}
impl<'a, T: Decode<'a>> Decoder<'a, LinkedList<T>> for VecDecoder<'a, T> {
    // Rebuilds the list by collecting decoded elements in order.
    decode_body!(LinkedList<T>);
}
352
impl<T: Encode> Encoder<VecDeque<T>> for VecEncoder<T> {
    // Internal iteration — presumably because the deque's ring buffer splits
    // into two slices and `for_each` handles that better; TODO confirm.
    encode_body_internal_iteration!(VecDeque<T>);
}
356impl<'a, T: Decode<'a>> Decoder<'a, VecDeque<T>> for VecDecoder<'a, T> {
357 #[inline(always)]
358 fn decode(&mut self) -> VecDeque<T> {
359 let v: Vec<T> = self.decode();
360 v.into()
361 }
362}
363
#[cfg(test)]
mod test {
    use alloc::collections::*;
    use alloc::vec::Vec;

    // 256 distinct byte values collected into the collection type under test.
    fn bench_data<T: FromIterator<u8>>() -> T {
        (0..=255).collect()
    }

    crate::bench_encode_decode!(
        btree_set: BTreeSet<_>,
        linked_list: LinkedList<_>,
        vec: Vec<_>,
        vec_deque: VecDeque<_>
    );
    #[cfg(feature = "std")]
    crate::bench_encode_decode!(hash_set: std::collections::HashSet<_>);

    // BinaryHeap gets a hand-written bench — presumably because it doesn't
    // fit the bench_encode_decode! macro (e.g. no equality comparison);
    // the round-trip is checked via iterator equality in debug builds only.
    #[bench]
    fn bench_binary_heap_decode(b: &mut test::Bencher) {
        type T = BinaryHeap<u8>;
        let data: T = bench_data();
        let encoded = crate::encode(&data);
        b.iter(|| {
            let decoded: T = crate::decode::<T>(&encoded).unwrap();
            // NOTE(review): heap iteration order is arbitrary — this eq check
            // presumably holds because decode preserves the encoded order;
            // confirm before promoting this to a real correctness test.
            debug_assert!(data.iter().eq(decoded.iter()));
            decoded
        })
    }
}