1use crate::coder::{Buffer, Decoder, Encoder, Result, View};
2use crate::consume::consume_byte_arrays;
3use crate::fast::{FastSlice, NextUnchecked, PushUnchecked, VecImpl};
4use alloc::vec::Vec;
5use core::mem::MaybeUninit;
6use core::num::NonZeroUsize;
7
8#[derive(Default)]
9pub struct F32Encoder(VecImpl<f32>);
10
11impl Encoder<f32> for F32Encoder {
12 #[inline(always)]
13 fn as_primitive(&mut self) -> Option<&mut VecImpl<f32>> {
14 Some(&mut self.0)
15 }
16
17 #[inline(always)]
18 fn encode(&mut self, t: &f32) {
19 unsafe { self.0.push_unchecked(*t) };
20 }
21}
22
/// Reinterprets a slice of uninitialized `A`s as a slice of uninitialized `B`s.
///
/// Every `B` covers `size_of::<B>() / size_of::<A>()` consecutive `A`s. Trailing `A`s that do
/// not make up a whole `B` are simply not part of the returned slice.
///
/// # Panics
///
/// Panics if `A` and `B` differ in alignment, if `size_of::<B>()` is not a multiple of
/// `size_of::<A>()`, or if either type is zero-sized (division/remainder by zero).
fn chunks_uninit<A, B>(m: &mut [MaybeUninit<A>]) -> &mut [MaybeUninit<B>] {
    let a_size = core::mem::size_of::<A>();
    let b_size = core::mem::size_of::<B>();
    assert_eq!(core::mem::align_of::<B>(), core::mem::align_of::<A>());
    assert_eq!(0, b_size % a_size);

    let per_chunk = b_size / a_size;
    let chunk_count = m.len() / per_chunk;
    let base = m.as_mut_ptr() as *mut MaybeUninit<B>;

    // SAFETY: `base` comes from a live `&mut [MaybeUninit<A>]` and `B` has the same alignment as
    // `A`, so it is suitably aligned. `chunk_count * b_size <= m.len() * a_size`, so the new
    // slice stays inside the original one. `MaybeUninit<B>` places no requirements on its bytes,
    // and the result reborrows `m` for the same lifetime, so nothing else can alias it.
    unsafe { core::slice::from_raw_parts_mut(base, chunk_count) }
}
35
36impl Buffer for F32Encoder {
37 fn collect_into(&mut self, out: &mut Vec<u8>) {
38 let floats = self.0.as_slice();
39 let byte_len = core::mem::size_of_val(floats);
40 out.reserve(byte_len);
41 let uninit = &mut out.spare_capacity_mut()[..byte_len];
42
43 let (mantissa, sign_exp) = uninit.split_at_mut(floats.len() * 3);
44 let mantissa: &mut [MaybeUninit<[u8; 3]>] = chunks_uninit(mantissa);
45
46 const CHUNK_SIZE: usize = 4;
48 let chunks_len = floats.len() / CHUNK_SIZE;
49 let chunks_floats = chunks_len * CHUNK_SIZE;
50 let chunks: &[[u32; CHUNK_SIZE]] = bytemuck::cast_slice(&floats[..chunks_floats]);
51 let mantissa_chunks: &mut [MaybeUninit<[[u8; 4]; 3]>] = chunks_uninit(mantissa);
52 let sign_exp_chunks: &mut [MaybeUninit<[u8; 4]>] = chunks_uninit(sign_exp);
53
54 for ci in 0..chunks_len {
55 let [a, b, c, d] = chunks[ci];
56
57 let m0 = a & 0xFF_FF_FF | (b << 24);
58 let m1 = ((b >> 8) & 0xFF_FF) | (c << 16);
59 let m2 = (c >> 16) & 0xFF | (d << 8);
60 let mantissa_chunk = &mut mantissa_chunks[ci];
61 mantissa_chunk.write([m0.to_le_bytes(), m1.to_le_bytes(), m2.to_le_bytes()]);
62
63 let se = (a >> 24) | ((b >> 24) << 8) | ((c >> 24) << 16) | ((d >> 24) << 24);
64 let sign_exp_chunk = &mut sign_exp_chunks[ci];
65 sign_exp_chunk.write(se.to_le_bytes());
66 }
67
68 for i in chunks_floats..floats.len() {
69 let [m @ .., se] = floats[i].to_le_bytes();
70 mantissa[i].write(m);
71 sign_exp[i].write(se);
72 }
73
74 unsafe { out.set_len(out.len() + byte_len) };
76 self.0.clear();
77 }
78
79 fn reserve(&mut self, additional: NonZeroUsize) {
80 self.0.reserve(additional.get());
81 }
82}
83
84#[derive(Default)]
85pub struct F32Decoder<'a> {
86 mantissa: FastSlice<'a, [u8; 3]>,
88 sign_exp: FastSlice<'a, u8>,
89}
90
91impl<'a> View<'a> for F32Decoder<'a> {
92 fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
93 let total: &[u8] = bytemuck::must_cast_slice(consume_byte_arrays::<4>(input, length)?);
94 let (mantissa, sign_exp) = total.split_at(length * 3);
95 let mantissa: &[[u8; 3]] = bytemuck::cast_slice(mantissa);
96 self.mantissa =
98 unsafe { FastSlice::from_raw_parts(total.as_ptr() as *const [u8; 3], mantissa.len()) };
99 self.sign_exp = sign_exp.into();
100 Ok(())
101 }
102}
103
104impl<'a> Decoder<'a, f32> for F32Decoder<'a> {
105 #[inline(always)]
106 fn decode(&mut self) -> f32 {
107 let mantissa_ptr = unsafe { self.mantissa.next_unchecked_as_ptr() };
108
109 let mantissa_extended = unsafe { *(mantissa_ptr as *const [u8; 4]) };
112 let mantissa = u32::from_le_bytes(mantissa_extended) & 0xFF_FF_FF;
113
114 let sign_exp = unsafe { self.sign_exp.next_unchecked() };
115 f32::from_bits(mantissa | ((sign_exp as u32) << 24))
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use rand_chacha::ChaCha20Rng;

    /// Round-trips arbitrary `f32` bit patterns (which may include NaNs) for every length from
    /// 1 to 15, covering both the 4-at-a-time encoder path and its scalar tail.
    #[test]
    fn test() {
        for len in 1..16 {
            // Reseeded each iteration, so every length starts from the same bit patterns.
            let mut rng = ChaCha20Rng::from_seed(Default::default());
            let floats: Vec<f32> = (0..len).map(|_| f32::from_bits(rng.gen())).collect();

            let bytes = {
                let mut encoder = F32Encoder::default();
                encoder.reserve(NonZeroUsize::new(floats.len()).unwrap());
                floats.iter().for_each(|f| encoder.encode(f));
                encoder.collect()
            };

            let mut remaining = bytes.as_slice();
            let mut decoder = F32Decoder::default();
            decoder.populate(&mut remaining, floats.len()).unwrap();
            assert!(remaining.is_empty());

            for expected in &floats {
                // Compare raw bits so NaN payloads must round-trip exactly as well.
                assert_eq!(expected.to_bits(), decoder.decode().to_bits());
            }
        }
    }

    /// Input presumably consumed by the benchmarks that `bench_encode_decode!` generates.
    fn bench_data() -> Vec<f32> {
        crate::random_data::<f32>(1500001)
    }
    crate::bench_encode_decode!(f32_vec: Vec<f32>);
}
153
#[cfg(test)]
mod tests2 {
    use alloc::vec::Vec;

    /// One all-zero vector per byte from `crate::random_data::<u8>(125)`, each of length
    /// `byte / 16`. This is presumably the input for the benchmarks that
    /// `bench_encode_decode!` generates.
    fn bench_data() -> Vec<Vec<f32>> {
        let lengths = crate::random_data::<u8>(125);
        lengths
            .into_iter()
            .map(|n| core::iter::repeat(0.0).take(usize::from(n / 16)).collect())
            .collect()
    }
    crate::bench_encode_decode!(f32_vecs: Vec<Vec<f32>>);
}