bitcode/
f32.rs

1use crate::coder::{Buffer, Decoder, Encoder, Result, View};
2use crate::consume::consume_byte_arrays;
3use crate::fast::{FastSlice, NextUnchecked, PushUnchecked, VecImpl};
4use alloc::vec::Vec;
5use core::mem::MaybeUninit;
6use core::num::NonZeroUsize;
7
8#[derive(Default)]
9pub struct F32Encoder(VecImpl<f32>);
10
11impl Encoder<f32> for F32Encoder {
12    #[inline(always)]
13    fn as_primitive(&mut self) -> Option<&mut VecImpl<f32>> {
14        Some(&mut self.0)
15    }
16
17    #[inline(always)]
18    fn encode(&mut self, t: &f32) {
19        unsafe { self.0.push_unchecked(*t) };
20    }
21}
22
/// Reinterprets a slice of uninitialized `A`s as a slice of uninitialized `B`s, where each
/// `B` spans a whole number of `A`s.
///
/// [`bytemuck`] doesn't implement [`MaybeUninit`] casts. Unlike
/// [`bytemuck::cast_slice_mut`], trailing `A`s that do not fill a complete `B` are
/// truncated instead of causing a panic.
///
/// # Panics
///
/// Panics if `A` and `B` have different alignments, or if `size_of::<B>()` is not a
/// multiple of `size_of::<A>()`.
fn chunks_uninit<A, B>(m: &mut [MaybeUninit<A>]) -> &mut [MaybeUninit<B>] {
    use core::mem::{align_of, size_of};

    let (src_size, dst_size) = (size_of::<A>(), size_of::<B>());
    assert_eq!(align_of::<B>(), align_of::<A>());
    assert_eq!(0, dst_size % src_size);

    // Number of `A`s that make up one `B`; leftover `A`s are dropped by the division below.
    let per_chunk = dst_size / src_size;
    let dst_len = m.len() / per_chunk;
    let dst_ptr = m.as_mut_ptr().cast::<MaybeUninit<B>>();

    // SAFETY: `A` and `B` share an alignment, one `B` covers exactly `per_chunk` `A`s, and
    // `dst_len * per_chunk <= m.len()`, so the new slice stays inside the memory borrowed
    // through `m`. `MaybeUninit` places no validity requirement on the bytes.
    unsafe { core::slice::from_raw_parts_mut(dst_ptr, dst_len) }
}
35
36impl Buffer for F32Encoder {
37    fn collect_into(&mut self, out: &mut Vec<u8>) {
38        let floats = self.0.as_slice();
39        let byte_len = core::mem::size_of_val(floats);
40        out.reserve(byte_len);
41        let uninit = &mut out.spare_capacity_mut()[..byte_len];
42
43        let (mantissa, sign_exp) = uninit.split_at_mut(floats.len() * 3);
44        let mantissa: &mut [MaybeUninit<[u8; 3]>] = chunks_uninit(mantissa);
45
46        // TODO SIMD version with PSHUFB.
47        const CHUNK_SIZE: usize = 4;
48        let chunks_len = floats.len() / CHUNK_SIZE;
49        let chunks_floats = chunks_len * CHUNK_SIZE;
50        let chunks: &[[u32; CHUNK_SIZE]] = bytemuck::cast_slice(&floats[..chunks_floats]);
51        let mantissa_chunks: &mut [MaybeUninit<[[u8; 4]; 3]>] = chunks_uninit(mantissa);
52        let sign_exp_chunks: &mut [MaybeUninit<[u8; 4]>] = chunks_uninit(sign_exp);
53
54        for ci in 0..chunks_len {
55            let [a, b, c, d] = chunks[ci];
56
57            let m0 = a & 0xFF_FF_FF | (b << 24);
58            let m1 = ((b >> 8) & 0xFF_FF) | (c << 16);
59            let m2 = (c >> 16) & 0xFF | (d << 8);
60            let mantissa_chunk = &mut mantissa_chunks[ci];
61            mantissa_chunk.write([m0.to_le_bytes(), m1.to_le_bytes(), m2.to_le_bytes()]);
62
63            let se = (a >> 24) | ((b >> 24) << 8) | ((c >> 24) << 16) | ((d >> 24) << 24);
64            let sign_exp_chunk = &mut sign_exp_chunks[ci];
65            sign_exp_chunk.write(se.to_le_bytes());
66        }
67
68        for i in chunks_floats..floats.len() {
69            let [m @ .., se] = floats[i].to_le_bytes();
70            mantissa[i].write(m);
71            sign_exp[i].write(se);
72        }
73
74        // Safety: We just initialized these elements in the loops above.
75        unsafe { out.set_len(out.len() + byte_len) };
76        self.0.clear();
77    }
78
79    fn reserve(&mut self, additional: NonZeroUsize) {
80        self.0.reserve(additional.get());
81    }
82}
83
84#[derive(Default)]
85pub struct F32Decoder<'a> {
86    // While it is true that this contains 1 bit of the exp we still call it mantissa.
87    mantissa: FastSlice<'a, [u8; 3]>,
88    sign_exp: FastSlice<'a, u8>,
89}
90
91impl<'a> View<'a> for F32Decoder<'a> {
92    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
93        let total: &[u8] = bytemuck::must_cast_slice(consume_byte_arrays::<4>(input, length)?);
94        let (mantissa, sign_exp) = total.split_at(length * 3);
95        let mantissa: &[[u8; 3]] = bytemuck::cast_slice(mantissa);
96        // Equivalent to `mantissa.into()` but satisfies miri when we read extra in decode.
97        self.mantissa =
98            unsafe { FastSlice::from_raw_parts(total.as_ptr() as *const [u8; 3], mantissa.len()) };
99        self.sign_exp = sign_exp.into();
100        Ok(())
101    }
102}
103
104impl<'a> Decoder<'a, f32> for F32Decoder<'a> {
105    #[inline(always)]
106    fn decode(&mut self) -> f32 {
107        let mantissa_ptr = unsafe { self.mantissa.next_unchecked_as_ptr() };
108
109        // Loading 4 bytes instead of 3 is 30% faster, so we read 1 extra byte after mantissa_ptr.
110        // Safety: The extra byte is within bounds because sign_exp comes after mantissa.
111        let mantissa_extended = unsafe { *(mantissa_ptr as *const [u8; 4]) };
112        let mantissa = u32::from_le_bytes(mantissa_extended) & 0xFF_FF_FF;
113
114        let sign_exp = unsafe { self.sign_exp.next_unchecked() };
115        f32::from_bits(mantissa | ((sign_exp as u32) << 24))
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use rand_chacha::ChaCha20Rng;

    /// Round-trips 1 to 15 random floats through the encoder and decoder and compares bit
    /// patterns, covering both the 4-float fast path and every possible remainder.
    #[test]
    fn test() {
        for len in 1..16 {
            // Re-seeded every iteration, so each length sees the same prefix of values.
            let mut rng = ChaCha20Rng::from_seed(Default::default());
            let floats: Vec<f32> = core::iter::repeat_with(|| f32::from_bits(rng.gen()))
                .take(len)
                .collect();

            let mut encoder = F32Encoder::default();
            encoder.reserve(NonZeroUsize::new(floats.len()).unwrap());
            floats.iter().for_each(|f| encoder.encode(f));
            let bytes = encoder.collect();

            let mut decoder = F32Decoder::default();
            let mut remaining = bytes.as_slice();
            decoder.populate(&mut remaining, floats.len()).unwrap();
            // The decoder must have consumed the entire encoding.
            assert!(remaining.is_empty());
            for expected in &floats {
                // Compare bits rather than values so NaN payloads are checked too.
                assert_eq!(expected.to_bits(), decoder.decode().to_bits());
            }
        }
    }

    /// Benchmark input: one large vector of random floats.
    fn bench_data() -> Vec<f32> {
        crate::random_data::<f32>(1500001)
    }
    crate::bench_encode_decode!(f32_vec: Vec<f32>);
}
153
#[cfg(test)]
mod tests2 {
    use alloc::vec::Vec;

    /// Benchmark input: 125 short vectors of zeros, each with `n / 16` elements for a
    /// random `u8` value `n` (so between 0 and 15 floats per vector).
    fn bench_data() -> Vec<Vec<f32>> {
        let sizes = crate::random_data::<u8>(125);
        let mut vecs = Vec::with_capacity(sizes.len());
        for size in sizes {
            let zeros: Vec<f32> = core::iter::repeat(0.0)
                .take(usize::from(size / 16))
                .collect();
            vecs.push(zeros);
        }
        vecs
    }
    crate::bench_encode_decode!(f32_vecs: Vec<Vec<f32>>);
}
165}