bitcode/derive/array.rs

1use crate::coder::{Buffer, Decoder, Encoder, Result, View};
2use crate::consume::mul_length;
3use crate::derive::{Decode, Encode};
4use crate::fast::{FastSlice, FastVec, Unaligned};
5use alloc::vec::Vec;
6use core::mem::MaybeUninit;
7use core::num::NonZeroUsize;
8
9pub struct ArrayEncoder<T: Encode, const N: usize>(T::Encoder);
10
11// Can't derive since it would bound T: Default.
12impl<T: Encode, const N: usize> Default for ArrayEncoder<T, N> {
13    fn default() -> Self {
14        Self(Default::default())
15    }
16}
17
18impl<T: Encode, const N: usize> Encoder<[T; N]> for ArrayEncoder<T, N> {
19    fn as_primitive(&mut self) -> Option<&mut FastVec<[T; N]>> {
20        // FastVec doesn't work on ZST.
21        if N == 0 {
22            return None;
23        }
24        self.0.as_primitive().map(|v| {
25            debug_assert!(v.len() % N == 0);
26            // Safety: FastVec uses pointers for len/cap unlike Vec, so casting to FastVec<[T; N]>
27            // is safe as long as `v.len() % N == 0`. This will always be the case since we only
28            // encode in chunks of N.
29            // NOTE: If panics occurs during ArrayEncoder::encode and Buffer is reused, this
30            // invariant can be violated. Luckily primitive encoders never panic.
31            // TODO std::mem::take Buffer while encoding to avoid corrupted buffers.
32            unsafe { core::mem::transmute(v) }
33        })
34    }
35
36    #[inline(always)]
37    fn encode(&mut self, array: &[T; N]) {
38        // TODO use encode_vectored if N is large enough.
39        for v in array {
40            self.0.encode(v);
41        }
42    }
43}
44
45impl<T: Encode, const N: usize> Buffer for ArrayEncoder<T, N> {
46    fn collect_into(&mut self, out: &mut Vec<u8>) {
47        self.0.collect_into(out);
48    }
49
50    fn reserve(&mut self, additional: NonZeroUsize) {
51        if N == 0 {
52            return; // self.0.reserve takes NonZeroUsize and `additional * N == 0`.
53        }
54        self.0.reserve(
55            additional
56                .checked_mul(NonZeroUsize::new(N).unwrap())
57                .unwrap(),
58        );
59    }
60}
61
62pub struct ArrayDecoder<'a, T: Decode<'a>, const N: usize>(T::Decoder);
63
64// Can't derive since it would bound T: Default.
65impl<'a, T: Decode<'a>, const N: usize> Default for ArrayDecoder<'a, T, N> {
66    fn default() -> Self {
67        Self(Default::default())
68    }
69}
70
71impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayDecoder<'a, T, N> {
72    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
73        let length = mul_length(length, N)?;
74        self.0.populate(input, length)
75    }
76}
77
78impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, [T; N]> for ArrayDecoder<'a, T, N> {
79    fn as_primitive(&mut self) -> Option<&mut FastSlice<Unaligned<[T; N]>>> {
80        self.0.as_primitive().map(|s| {
81            // Safety: FastSlice doesn't have a length unlike slice, so casting to FastSlice<[T; N]>
82            // is safe. N == 0 case is also safe for the same reason.
83            unsafe { core::mem::transmute(s) }
84        })
85    }
86
87    #[inline(always)]
88    fn decode_in_place(&mut self, out: &mut MaybeUninit<[T; N]>) {
89        // Safety: Equivalent to nightly MaybeUninit::transpose.
90        let out = unsafe { &mut *(out.as_mut_ptr() as *mut [MaybeUninit<T>; N]) };
91        for out in out {
92            self.0.decode_in_place(out);
93        }
94    }
95}
96
#[cfg(test)]
mod tests {
    use crate::coder::{Buffer, Encoder};
    use crate::error::err;
    use crate::length::LengthEncoder;
    use crate::{decode, encode};
    use alloc::vec::Vec;
    use core::num::NonZeroUsize;

    /// Zero-length arrays must round-trip both on their own and inside a `Vec`.
    #[test]
    fn test_empty_array() {
        type T = [u8; 0];
        let empty: T = [];
        let many: Vec<T> = vec![empty; 100];
        decode::<T>(&encode(&empty)).unwrap();
        decode::<Vec<T>>(&encode(&many)).unwrap();
    }

    /// A `Vec<[u8; N]>` whose array count times `N` exceeds `usize::MAX` must be rejected
    /// with a "length overflow" error.
    #[test]
    fn test_length_overflow() {
        const N: usize = 16384;
        let too_many_arrays = usize::MAX / N + 1;
        let mut lengths = LengthEncoder::default();
        lengths.reserve(NonZeroUsize::MIN);
        lengths.encode(&too_many_arrays);
        let bytes = lengths.collect();
        assert_eq!(decode::<Vec<[u8; N]>>(&bytes), err("length overflow"));
    }

    /// Benchmark input: 125 vectors, each holding `random / 16` copies of `[0, 0, 255]`.
    fn bench_data() -> Vec<Vec<[u8; 3]>> {
        let sizes = crate::random_data::<u8>(125);
        sizes
            .into_iter()
            .map(|size| vec![[0, 0, 255]; (size / 16) as usize])
            .collect()
    }
    crate::bench_encode_decode!(u8_array_vecs: Vec<Vec<[u8; 3]>>);
}