bitcode/
length.rs

1use crate::coder::{Buffer, Decoder, Encoder, Result, View};
2use crate::error::{err, error};
3use crate::fast::{CowSlice, NextUnchecked, VecImpl};
4use crate::int::{IntDecoder, IntEncoder};
5use crate::pack::{pack_bytes, unpack_bytes};
6use alloc::vec::Vec;
7use core::num::NonZeroUsize;
8
9#[derive(Default)]
10pub struct LengthEncoder {
11    small: VecImpl<u8>,
12    large: IntEncoder<usize>,
13}
14
15impl Encoder<usize> for LengthEncoder {
16    #[inline(always)]
17    fn encode(&mut self, &v: &usize) {
18        unsafe {
19            let end_ptr = self.small.end_ptr();
20            if v < 255 {
21                *end_ptr = v as u8;
22            } else {
23                #[cold]
24                #[inline(never)]
25                unsafe fn encode_slow(end_ptr: *mut u8, large: &mut IntEncoder<usize>, v: usize) {
26                    *end_ptr = 255;
27                    large.reserve(NonZeroUsize::new(1).unwrap());
28                    large.encode(&v);
29                }
30                encode_slow(end_ptr, &mut self.large, v);
31            }
32            self.small.increment_len();
33        }
34    }
35}
36
/// A value whose length can be queried, as needed by the vectored length encoders.
pub trait Len {
    /// Returns the number of elements in `self`.
    fn len(&self) -> usize;
}

impl<T> Len for &[T] {
    /// Forwards to the slice's own inherent `len`.
    #[inline(always)]
    fn len(&self) -> usize {
        let slice: &[T] = self;
        <[T]>::len(slice)
    }
}
47
48impl LengthEncoder {
49    /// Encodes a length known to be < `255`.
50    #[cfg(feature = "arrayvec")]
51    #[inline(always)]
52    pub fn encode_less_than_255(&mut self, n: usize) {
53        use crate::fast::PushUnchecked;
54        debug_assert!(n < 255);
55        unsafe { self.small.push_unchecked(n as u8) };
56    }
57
58    /// Encodes lengths less than `N`. Have to reserve `N * i.size_hint().1 elements`.
59    /// Skips calling encode for T::len() == 0. Returns `true` if it failed due to a length over `N`.
60    #[inline(always)]
61    pub fn encode_vectored_max_len<T: Len, const N: usize>(
62        &mut self,
63        i: impl Iterator<Item = T>,
64        mut encode: impl FnMut(T),
65    ) -> bool {
66        debug_assert!(N <= 64);
67        let mut ptr = self.small.end_ptr();
68        for t in i {
69            let n = t.len();
70            unsafe {
71                *ptr = n as u8;
72                ptr = ptr.add(1);
73            }
74            if n == 0 {
75                continue;
76            }
77            if n > N {
78                // Don't set end ptr (elements won't be saved).
79                return true;
80            }
81            encode(t);
82        }
83        self.small.set_end_ptr(ptr);
84        false
85    }
86
87    #[inline(always)]
88    pub fn encode_vectored_fallback<T: Len>(
89        &mut self,
90        i: impl Iterator<Item = T>,
91        mut reserve_and_encode_large: impl FnMut(T),
92    ) {
93        for v in i {
94            let n = v.len();
95            self.encode(&n);
96            reserve_and_encode_large(v);
97        }
98    }
99}
100
101impl Buffer for LengthEncoder {
102    fn collect_into(&mut self, out: &mut Vec<u8>) {
103        pack_bytes(self.small.as_mut_slice(), out);
104        self.small.clear();
105        self.large.collect_into(out);
106    }
107
108    fn reserve(&mut self, additional: NonZeroUsize) {
109        self.small.reserve(additional.get()); // All lengths inhabit small, only large ones inhabit large.
110    }
111}
112
113#[derive(Default)]
114pub struct LengthDecoder<'a> {
115    small: CowSlice<'a, u8>,
116    large: IntDecoder<'a, usize>,
117    sum: usize,
118}
119
120impl<'a> LengthDecoder<'a> {
121    pub fn length(&self) -> usize {
122        self.sum
123    }
124
125    // For decoding lengths multiple times (e.g. ArrayVec, utf8 validation).
126    pub fn borrowed_clone<'me: 'a>(&'me self) -> LengthDecoder<'me> {
127        let mut small = CowSlice::default();
128        small.set_borrowed_slice_impl(self.small.ref_slice().clone());
129        Self {
130            small,
131            large: self.large.borrowed_clone(),
132            sum: self.sum,
133        }
134    }
135
136    /// Returns if any of the decoded lengths are > `N`.
137    /// Safety: `length` must be the `length` passed to populate.
138    #[cfg_attr(not(feature = "arrayvec"), allow(unused))]
139    pub unsafe fn any_greater_than<const N: usize>(&self, length: usize) -> bool {
140        if N < 255 {
141            // Fast path: don't need to scan large lengths since there shouldn't be any.
142            // A large length will have a 255 in small which will be greater than N.
143            self.small
144                .as_slice(length)
145                .iter()
146                .copied()
147                .max()
148                .unwrap_or(0) as usize
149                > N
150        } else {
151            let mut decoder = self.borrowed_clone();
152            (0..length).any(|_| decoder.decode() > N)
153        }
154    }
155}
156
157impl<'a> View<'a> for LengthDecoder<'a> {
158    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
159        unpack_bytes(input, length, &mut self.small)?;
160        let small = unsafe { self.small.as_slice(length) };
161
162        // Summing &[u8] can't overflow since that would require > 2^56 bytes of memory.
163        let mut sum: u64 = small.iter().map(|&v| v as u64).sum();
164
165        // Fast path for small lengths: If sum(small) < 255 every small < 255 so large_length is 0.
166        if sum < 255 {
167            self.sum = sum as usize;
168            return Ok(());
169        }
170
171        // Every 255 byte indicates a large is present.
172        let large_length = small.iter().filter(|&&v| v == 255).count();
173        self.large.populate(input, large_length)?;
174
175        // Can't overflow since sum includes large_length many 255s.
176        sum -= large_length as u64 * 255;
177
178        // Summing &[u64] can overflow, so we check it.
179        let mut decoder = self.large.borrowed_clone();
180        for _ in 0..large_length {
181            let v: usize = decoder.decode();
182            sum = sum
183                .checked_add(v as u64)
184                .ok_or_else(|| error("length overflow"))?;
185        }
186        if sum >= HUGE_LEN {
187            return err("huge length"); // Lets us optimize decode with unreachable_unchecked.
188        }
189        self.sum = sum.try_into().map_err(|_| error("length > usize::MAX"))?;
190        Ok(())
191    }
192}
193
// Exclusive upper bound on the total of decoded lengths:
// (64-bit) isize::MAX / (largest type we want to allocate without possibility of overflow).
const HUGE_LEN: u64 = (i64::MAX as u64) / 4096;
196
197impl<'a> Decoder<'a, usize> for LengthDecoder<'a> {
198    #[inline(always)]
199    fn decode(&mut self) -> usize {
200        let length = unsafe {
201            let v = self.small.mut_slice().next_unchecked();
202
203            if v < 255 {
204                v as usize
205            } else {
206                #[cold]
207                unsafe fn cold(large: &mut IntDecoder<'_, usize>) -> usize {
208                    large.decode()
209                }
210                cold(&mut self.large)
211            }
212        };
213
214        // Allows some checks in Vec::with_capacity to be removed if lto = true.
215        // Safety: sum < HUGE_LEN is checked in populate so all elements have to be < HUGE_LEN.
216        if length as u64 >= HUGE_LEN {
217            unsafe { core::hint::unreachable_unchecked() }
218        }
219        length
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::{LengthDecoder, LengthEncoder};
    use crate::coder::{Buffer, Decoder, Encoder, View};
    use core::num::NonZeroUsize;

    /// Round-trips a mix of single-byte lengths and one length that needs the 255 marker.
    #[test]
    fn test() {
        let lengths = [1usize, 255, 2];

        let mut encoder = LengthEncoder::default();
        encoder.reserve(NonZeroUsize::new(lengths.len()).unwrap());
        for len in &lengths {
            encoder.encode(len);
        }
        let bytes = encoder.collect();

        let mut decoder = LengthDecoder::default();
        decoder
            .populate(&mut bytes.as_slice(), lengths.len())
            .unwrap();
        for &expected in &lengths {
            assert_eq!(decoder.decode(), expected);
        }
    }

    /// `populate` must accept the largest allowed length and reject `HUGE_LEN` itself.
    #[cfg(target_pointer_width = "64")] // HUGE_LEN > u32::MAX
    #[test]
    fn huge_len() {
        let cases = [(super::HUGE_LEN - 1, true), (super::HUGE_LEN, false)];
        for (x, is_ok) in cases {
            let mut encoder = LengthEncoder::default();
            encoder.reserve(NonZeroUsize::new(1).unwrap());
            encoder.encode(&(x as usize));
            let bytes = encoder.collect();

            let mut decoder = LengthDecoder::default();
            let populated = decoder.populate(&mut bytes.as_slice(), 1);
            assert_eq!(populated.is_ok(), is_ok);
        }
    }
}