bitcode/derive/
array.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use crate::coder::{Buffer, Decoder, Encoder, Result, View};
use crate::consume::mul_length;
use crate::derive::{Decode, Encode};
use crate::fast::{FastSlice, FastVec, Unaligned};
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::num::NonZeroUsize;

/// Encoder for fixed-size arrays `[T; N]`.
///
/// Holds a single element encoder (`T::Encoder`) and hands it every element of each array in
/// order, so one array becomes N consecutive encoded `T`s.
pub struct ArrayEncoder<T: Encode, const N: usize>(T::Encoder);

// Written by hand: `#[derive(Default)]` would add an unnecessary `T: Default` bound.
impl<T: Encode, const N: usize> Default for ArrayEncoder<T, N> {
    fn default() -> Self {
        let inner = <T::Encoder as Default>::default();
        Self(inner)
    }
}

impl<T: Encode, const N: usize> Encoder<[T; N]> for ArrayEncoder<T, N> {
    /// Returns the element encoder's primitive buffer, reinterpreted as a buffer of whole arrays.
    ///
    /// Returns `None` when `N == 0`, because `FastVec` cannot hold zero-sized element types. Also
    /// returns `None` when the element encoder itself has no primitive buffer.
    fn as_primitive(&mut self) -> Option<&mut FastVec<[T; N]>> {
        // FastVec doesn't work on ZST.
        if N == 0 {
            return None;
        }
        let elements = self.0.as_primitive()?;
        debug_assert!(elements.len() % N == 0);
        // SAFETY: FastVec stores its len/cap as pointers rather than counts (unlike Vec). That
        // makes reinterpreting FastVec<T> as FastVec<[T; N]> sound as long as
        // `elements.len() % N == 0`. This always holds, because `encode` below pushes exactly N
        // elements per array.
        // NOTE: If a panic occurs during ArrayEncoder::encode and the Buffer is reused, this
        // invariant can be violated. Luckily primitive encoders never panic.
        // TODO std::mem::take Buffer while encoding to avoid corrupted buffers.
        let arrays: &mut FastVec<[T; N]> = unsafe { core::mem::transmute(elements) };
        Some(arrays)
    }

    /// Encodes `array` by passing each of its N elements, in order, to the element encoder.
    #[inline(always)]
    fn encode(&mut self, array: &[T; N]) {
        // TODO use encode_vectored if N is large enough.
        array.iter().for_each(|element| self.0.encode(element));
    }
}

impl<T: Encode, const N: usize> Buffer for ArrayEncoder<T, N> {
    /// Forwards to the element encoder, which writes its collected output into `out`.
    fn collect_into(&mut self, out: &mut Vec<u8>) {
        self.0.collect_into(out);
    }

    /// Reserves space for `additional` more arrays by reserving `additional * N` elements in the
    /// element encoder.
    ///
    /// # Panics
    ///
    /// Panics if `additional * N` overflows `usize`.
    fn reserve(&mut self, additional: NonZeroUsize) {
        // `[T; 0]` contributes no elements. The element encoder's `reserve` only accepts a
        // non-zero count, so there is nothing to reserve.
        let Some(n) = NonZeroUsize::new(N) else {
            return;
        };
        let elements = additional
            .checked_mul(n)
            .expect("ArrayEncoder::reserve: additional * N overflows usize");
        self.0.reserve(elements);
    }
}

/// Decoder for fixed-size arrays `[T; N]`.
///
/// Holds a single element decoder (`T::Decoder`). Each decoded array is assembled from N
/// consecutive elements produced by that decoder.
pub struct ArrayDecoder<'a, T: Decode<'a>, const N: usize>(T::Decoder);

// Written by hand: `#[derive(Default)]` would add an unnecessary `T: Default` bound.
impl<'a, T: Decode<'a>, const N: usize> Default for ArrayDecoder<'a, T, N> {
    fn default() -> Self {
        let inner = <T::Decoder as Default>::default();
        Self(inner)
    }
}

impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayDecoder<'a, T, N> {
    /// Prepares the element decoder to read `length` arrays, which means `length * N` elements.
    ///
    /// Returns an error when `mul_length` rejects `length * N`, for example on overflow.
    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
        self.0.populate(input, mul_length(length, N)?)
    }
}

impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, [T; N]> for ArrayDecoder<'a, T, N> {
    /// Returns the element decoder's primitive slice, reinterpreted as a slice of arrays.
    ///
    /// Returns `None` when the element decoder has no primitive slice.
    fn as_primitive(&mut self) -> Option<&mut FastSlice<Unaligned<[T; N]>>> {
        let elements = self.0.as_primitive()?;
        // SAFETY: FastSlice doesn't store a length (unlike a slice), so reinterpreting it as
        // FastSlice<Unaligned<[T; N]>> is sound. The N == 0 case is sound for the same reason.
        let arrays: &mut FastSlice<Unaligned<[T; N]>> = unsafe { core::mem::transmute(elements) };
        Some(arrays)
    }

    /// Decodes one array by decoding each of its N elements, in order, directly into `out`.
    #[inline(always)]
    fn decode_in_place(&mut self, out: &mut MaybeUninit<[T; N]>) {
        // SAFETY: Views the uninitialized array as an array of uninitialized elements, which is
        // equivalent to nightly `MaybeUninit::transpose`.
        let slots = unsafe { &mut *out.as_mut_ptr().cast::<[MaybeUninit<T>; N]>() };
        slots
            .iter_mut()
            .for_each(|slot| self.0.decode_in_place(slot));
    }
}

#[cfg(test)]
mod tests {
    use crate::coder::{Buffer, Encoder};
    use crate::error::err;
    use crate::length::LengthEncoder;
    use crate::{decode, encode};
    use alloc::vec::Vec;
    use core::num::NonZeroUsize;

    /// Zero-length arrays must round-trip, both on their own and inside a `Vec`. Inside a `Vec`
    /// the element count must survive even though each element occupies no bytes.
    #[test]
    fn test_empty_array() {
        type T = [u8; 0];
        let empty_array = T::default();
        assert_eq!(decode::<T>(&encode(&empty_array)).unwrap(), empty_array);

        let empty_arrays = vec![empty_array; 100];
        assert_eq!(
            decode::<Vec<T>>(&encode(&empty_arrays)).unwrap(),
            empty_arrays
        );
    }

    /// A `Vec` length whose product with the array size N overflows `usize` must be rejected
    /// with a "length overflow" error.
    #[test]
    fn test_length_overflow() {
        const N: usize = 16384;
        let mut encoder = LengthEncoder::default();
        encoder.reserve(NonZeroUsize::MIN);
        encoder.encode(&(usize::MAX / N + 1));
        let bytes = encoder.collect();
        assert_eq!(decode::<Vec<[u8; N]>>(&bytes), err("length overflow"));
    }

    /// Benchmark input: 125 vectors built from `crate::random_data::<u8>(125)`. For each value
    /// `n`, the vector holds `n / 16` copies of `[0, 0, 255]`.
    fn bench_data() -> Vec<Vec<[u8; 3]>> {
        crate::random_data::<u8>(125)
            .into_iter()
            .map(|n| (0..n / 16).map(|_| [0, 0, 255]).collect())
            .collect()
    }
    crate::bench_encode_decode!(u8_array_vecs: Vec<Vec<[u8; 3]>>);
}