bitcode/derive/variant.rs

1use crate::coder::{Buffer, Decoder, Encoder, Result, View};
2use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl};
3use crate::pack::{pack_bytes_less_than, unpack_bytes_less_than};
4use alloc::vec::Vec;
5use core::num::NonZeroUsize;
6
7#[derive(Default)]
8pub struct VariantEncoder<const N: usize>(VecImpl<u8>);
9
10impl<const N: usize> Encoder<u8> for VariantEncoder<N> {
11    #[inline(always)]
12    fn encode(&mut self, v: &u8) {
13        unsafe { self.0.push_unchecked(*v) };
14    }
15}
16
17impl<const N: usize> Buffer for VariantEncoder<N> {
18    fn collect_into(&mut self, out: &mut Vec<u8>) {
19        assert!(N >= 2);
20        pack_bytes_less_than::<N>(self.0.as_slice(), out);
21        self.0.clear();
22    }
23
24    fn reserve(&mut self, additional: NonZeroUsize) {
25        self.0.reserve(additional.get());
26    }
27}
28
29pub struct VariantDecoder<'a, const N: usize, const C_STYLE: bool> {
30    variants: CowSlice<'a, u8>,
31    histogram: [usize; N], // Not required if C_STYLE. TODO don't reserve space for it.
32}
33
34// [(); N] doesn't implement Default.
35impl<const N: usize, const C_STYLE: bool> Default for VariantDecoder<'_, N, C_STYLE> {
36    fn default() -> Self {
37        Self {
38            variants: Default::default(),
39            histogram: core::array::from_fn(|_| 0),
40        }
41    }
42}
43
44// C style enums don't require length, so we can skip making a histogram for them.
45impl<'a, const N: usize> VariantDecoder<'a, N, false> {
46    pub fn length(&self, variant_index: u8) -> usize {
47        self.histogram[variant_index as usize]
48    }
49}
50
51impl<'a, const N: usize, const C_STYLE: bool> View<'a> for VariantDecoder<'a, N, C_STYLE> {
52    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
53        assert!(N >= 2);
54        if C_STYLE {
55            unpack_bytes_less_than::<N, 0>(input, length, &mut self.variants)?;
56        } else {
57            self.histogram = unpack_bytes_less_than::<N, N>(input, length, &mut self.variants)?;
58        }
59        Ok(())
60    }
61}
62
63impl<'a, const N: usize, const C_STYLE: bool> Decoder<'a, u8> for VariantDecoder<'a, N, C_STYLE> {
64    // Guaranteed to output numbers less than N.
65    #[inline(always)]
66    fn decode(&mut self) -> u8 {
67        unsafe { self.variants.mut_slice().next_unchecked() }
68    }
69}
70
#[cfg(test)]
mod tests {
    use crate::{decode, encode, Decode, Encode};
    use alloc::vec::Vec;

    // Several variants below are never constructed; they exist only to fix the variant
    // count. The allow silences the resulting dead-code warnings.
    #[allow(unused)]
    #[test]
    fn test_c_style_enum() {
        #[derive(Encode, Decode)]
        enum Enum1 {
            A,
            B,
            C,
            D,
            E,
            F,
        }
        #[derive(Decode)]
        enum Enum2 {
            A,
            B,
            C,
            D,
            E,
        }

        // 5- and 6-variant enums serialize the same, so decoding a 6-variant value as a
        // 5-variant enum exercises variant bounds checking.
        let first = encode(&Enum1::A);
        assert!(matches!(decode::<Enum2>(&first), Ok(Enum2::A)));

        let last = encode(&Enum1::F);
        assert!(decode::<Enum2>(&last).is_err());
        assert!(matches!(decode::<Enum1>(&last), Ok(Enum1::F)));
    }

    // Same reason as above: most variants are never constructed.
    #[allow(unused)]
    #[test]
    fn test_rust_style_enum() {
        #[derive(Encode, Decode)]
        enum Enum1 {
            A(u8),
            B,
            C,
            D,
            E,
            F,
        }
        #[derive(Decode)]
        enum Enum2 {
            A(u8),
            B,
            C,
            D,
            E,
        }

        // 5- and 6-variant enums serialize the same, so decoding a 6-variant value as a
        // 5-variant enum exercises variant bounds checking.
        let with_payload = encode(&Enum1::A(1));
        assert!(matches!(decode::<Enum2>(&with_payload), Ok(Enum2::A(1))));

        let last = encode(&Enum1::F);
        assert!(decode::<Enum2>(&last).is_err());
        assert!(matches!(decode::<Enum1>(&last), Ok(Enum1::F)));
    }

    #[derive(Debug, PartialEq, Encode, Decode)]
    enum BoolEnum {
        True,
        False,
    }

    /// 1000 pseudo-random `BoolEnum` values for the encode/decode benchmark.
    fn bench_data() -> Vec<BoolEnum> {
        let to_enum = |bit: bool| {
            if bit {
                BoolEnum::True
            } else {
                BoolEnum::False
            }
        };
        crate::random_data(1000).into_iter().map(to_enum).collect()
    }

    crate::bench_encode_decode!(bool_enum_vec: Vec<_>);
}