// bitcode/derive/option.rs

1use crate::coder::{Buffer, Decoder, Encoder, Result, View, MAX_VECTORED_CHUNK};
2use crate::derive::variant::{VariantDecoder, VariantEncoder};
3use crate::derive::{Decode, Encode};
4use crate::fast::{FastArrayVec, PushUnchecked};
5use alloc::vec::Vec;
6use core::mem::MaybeUninit;
7use core::num::NonZeroUsize;
8
9pub struct OptionEncoder<T: Encode> {
10    variants: VariantEncoder<2>,
11    some: T::Encoder,
12}
13
14// Can't derive since it would bound T: Default.
15impl<T: Encode> Default for OptionEncoder<T> {
16    fn default() -> Self {
17        Self {
18            variants: Default::default(),
19            some: Default::default(),
20        }
21    }
22}
23
24impl<T: Encode> Encoder<Option<T>> for OptionEncoder<T> {
25    #[inline(always)]
26    fn encode(&mut self, t: &Option<T>) {
27        self.variants.encode(&(t.is_some() as u8));
28        if let Some(t) = t {
29            self.some.reserve(NonZeroUsize::new(1).unwrap());
30            self.some.encode(t);
31        }
32    }
33
34    fn encode_vectored<'a>(&mut self, i: impl Iterator<Item = &'a Option<T>> + Clone)
35    where
36        Option<T>: 'a,
37    {
38        // Types with many vectorized encoders benefit from a &[&T] since encode_vectorized is still
39        // faster even with the extra indirection. TODO vectored encoder count >= 8 instead of size_of.
40        if core::mem::size_of::<T>() >= 64 {
41            let mut uninit = MaybeUninit::uninit();
42            let mut refs = FastArrayVec::<_, MAX_VECTORED_CHUNK>::new(&mut uninit);
43
44            for t in i {
45                self.variants.encode(&(t.is_some() as u8));
46                if let Some(t) = t {
47                    // Safety: encode_vectored guarantees less than `MAX_VECTORED_CHUNK` items.
48                    unsafe { refs.push_unchecked(t) };
49                }
50            }
51
52            let refs = refs.as_slice();
53            let Some(some_count) = NonZeroUsize::new(refs.len()) else {
54                return;
55            };
56            self.some.reserve(some_count);
57            self.some.encode_vectored(refs.iter().copied());
58        } else {
59            // Safety: encode_vectored guarantees `i.size_hint().1.unwrap() != 0`.
60            let size_hint =
61                unsafe { NonZeroUsize::new(i.size_hint().1.unwrap()).unwrap_unchecked() };
62            // size_of::<T>() is small, so we can just assume all elements are Some.
63            // This will waste a maximum of `MAX_VECTORED_CHUNK * size_of::<T>()` bytes.
64            self.some.reserve(size_hint);
65
66            for option in i {
67                self.variants.encode(&(option.is_some() as u8));
68                if let Some(t) = option {
69                    self.some.encode(t);
70                }
71            }
72        }
73    }
74}
75
76impl<T: Encode> Buffer for OptionEncoder<T> {
77    fn collect_into(&mut self, out: &mut Vec<u8>) {
78        self.variants.collect_into(out);
79        self.some.collect_into(out);
80    }
81
82    fn reserve(&mut self, additional: NonZeroUsize) {
83        self.variants.reserve(additional);
84        // We don't know how many are Some, so we can't reserve more.
85    }
86}
87
88pub struct OptionDecoder<'a, T: Decode<'a>> {
89    variants: VariantDecoder<'a, 2, false>,
90    some: T::Decoder,
91}
92
93// Can't derive since it would bound T: Default.
94impl<'a, T: Decode<'a>> Default for OptionDecoder<'a, T> {
95    fn default() -> Self {
96        Self {
97            variants: Default::default(),
98            some: Default::default(),
99        }
100    }
101}
102
103impl<'a, T: Decode<'a>> View<'a> for OptionDecoder<'a, T> {
104    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
105        self.variants.populate(input, length)?;
106        self.some.populate(input, self.variants.length(1))
107    }
108}
109
110impl<'a, T: Decode<'a>> Decoder<'a, Option<T>> for OptionDecoder<'a, T> {
111    #[inline(always)]
112    fn decode_in_place(&mut self, out: &mut MaybeUninit<Option<T>>) {
113        if self.variants.decode() != 0 {
114            out.write(Some(self.some.decode()));
115        } else {
116            out.write(None);
117        }
118    }
119}
120
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    /// Wide tuple payload (with a nested tuple as its last field) wrapped in `Option`.
    type Row = (u64, u32, u8, i32, u64, u32, u8, i32, u64, (u32, u8, i32, u64, u32, u8, i32));

    /// Data set of 1000 random `Option<Row>` values, picked up by name by the macro below
    /// (presumably as the benchmark input).
    fn bench_data() -> Vec<Option<Row>> {
        crate::random_data(1000)
    }
    crate::bench_encode_decode!(option_vec: Vec<_>);
}
131
#[cfg(test)]
mod tests2 {
    use alloc::vec::Vec;

    /// Small payload wrapped in `Option`.
    type Item = Option<u16>;

    /// Data set of 1000 random `Option<u16>` values, picked up by name by the macro below
    /// (presumably as the benchmark input).
    fn bench_data() -> Vec<Item> {
        crate::random_data(1000)
    }
    crate::bench_encode_decode!(option_u16_vec: Vec<_>);
}