use crate::coder::{Buffer, Decoder, Encoder, Result, View};
use crate::consume::mul_length;
use crate::derive::{Decode, Encode};
use crate::fast::{FastSlice, FastVec, Unaligned};
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::num::NonZeroUsize;
pub struct ArrayEncoder<T: Encode, const N: usize>(T::Encoder);
impl<T: Encode, const N: usize> Default for ArrayEncoder<T, N> {
fn default() -> Self {
Self(Default::default())
}
}
impl<T: Encode, const N: usize> Encoder<[T; N]> for ArrayEncoder<T, N> {
fn as_primitive(&mut self) -> Option<&mut FastVec<[T; N]>> {
if N == 0 {
return None;
}
self.0.as_primitive().map(|v| {
debug_assert!(v.len() % N == 0);
unsafe { core::mem::transmute(v) }
})
}
#[inline(always)]
fn encode(&mut self, array: &[T; N]) {
for v in array {
self.0.encode(v);
}
}
}
impl<T: Encode, const N: usize> Buffer for ArrayEncoder<T, N> {
fn collect_into(&mut self, out: &mut Vec<u8>) {
self.0.collect_into(out);
}
fn reserve(&mut self, additional: NonZeroUsize) {
if N == 0 {
return; }
self.0.reserve(
additional
.checked_mul(NonZeroUsize::new(N).unwrap())
.unwrap(),
);
}
}
pub struct ArrayDecoder<'a, T: Decode<'a>, const N: usize>(T::Decoder);
impl<'a, T: Decode<'a>, const N: usize> Default for ArrayDecoder<'a, T, N> {
fn default() -> Self {
Self(Default::default())
}
}
impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayDecoder<'a, T, N> {
fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
let length = mul_length(length, N)?;
self.0.populate(input, length)
}
}
impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, [T; N]> for ArrayDecoder<'a, T, N> {
fn as_primitive(&mut self) -> Option<&mut FastSlice<Unaligned<[T; N]>>> {
self.0.as_primitive().map(|s| {
unsafe { core::mem::transmute(s) }
})
}
#[inline(always)]
fn decode_in_place(&mut self, out: &mut MaybeUninit<[T; N]>) {
let out = unsafe { &mut *(out.as_mut_ptr() as *mut [MaybeUninit<T>; N]) };
for out in out {
self.0.decode_in_place(out);
}
}
}
#[cfg(test)]
mod tests {
use crate::coder::{Buffer, Encoder};
use crate::error::err;
use crate::length::LengthEncoder;
use crate::{decode, encode};
use alloc::vec::Vec;
use core::num::NonZeroUsize;
#[test]
fn test_empty_array() {
type T = [u8; 0];
let empty_array = T::default();
decode::<T>(&encode(&empty_array)).unwrap();
decode::<Vec<T>>(&encode(&vec![empty_array; 100])).unwrap();
}
#[test]
fn test_length_overflow() {
const N: usize = 16384;
let mut encoder = LengthEncoder::default();
encoder.reserve(NonZeroUsize::MIN);
encoder.encode(&(usize::MAX / N + 1));
let bytes = encoder.collect();
assert_eq!(decode::<Vec<[u8; N]>>(&bytes), err("length overflow"));
}
fn bench_data() -> Vec<Vec<[u8; 3]>> {
crate::random_data::<u8>(125)
.into_iter()
.map(|n| (0..n / 16).map(|_| [0, 0, 255]).collect())
.collect()
}
crate::bench_encode_decode!(u8_array_vecs: Vec<Vec<[u8; 3]>>);
}