bitcode/
length.rs
1use crate::coder::{Buffer, Decoder, Encoder, Result, View};
2use crate::error::{err, error};
3use crate::fast::{CowSlice, NextUnchecked, VecImpl};
4use crate::int::{IntDecoder, IntEncoder};
5use crate::pack::{pack_bytes, unpack_bytes};
6use alloc::vec::Vec;
7use core::num::NonZeroUsize;
8
9#[derive(Default)]
10pub struct LengthEncoder {
11 small: VecImpl<u8>,
12 large: IntEncoder<usize>,
13}
14
15impl Encoder<usize> for LengthEncoder {
16 #[inline(always)]
17 fn encode(&mut self, &v: &usize) {
18 unsafe {
19 let end_ptr = self.small.end_ptr();
20 if v < 255 {
21 *end_ptr = v as u8;
22 } else {
23 #[cold]
24 #[inline(never)]
25 unsafe fn encode_slow(end_ptr: *mut u8, large: &mut IntEncoder<usize>, v: usize) {
26 *end_ptr = 255;
27 large.reserve(NonZeroUsize::new(1).unwrap());
28 large.encode(&v);
29 }
30 encode_slow(end_ptr, &mut self.large, v);
31 }
32 self.small.increment_len();
33 }
34 }
35}
36
/// Something with an element count that `LengthEncoder`'s vectored encoders can record.
pub trait Len {
    /// Returns the number of elements.
    fn len(&self) -> usize;
}

impl<T> Len for &[T] {
    #[inline(always)]
    fn len(&self) -> usize {
        // Coerce `&&[T]` down to `&[T]` so the call below resolves to the inherent slice `len`
        // rather than recursing into this trait method.
        let slice: &[T] = self;
        slice.len()
    }
}
47
48impl LengthEncoder {
49 #[cfg(feature = "arrayvec")]
51 #[inline(always)]
52 pub fn encode_less_than_255(&mut self, n: usize) {
53 use crate::fast::PushUnchecked;
54 debug_assert!(n < 255);
55 unsafe { self.small.push_unchecked(n as u8) };
56 }
57
58 #[inline(always)]
61 pub fn encode_vectored_max_len<T: Len, const N: usize>(
62 &mut self,
63 i: impl Iterator<Item = T>,
64 mut encode: impl FnMut(T),
65 ) -> bool {
66 debug_assert!(N <= 64);
67 let mut ptr = self.small.end_ptr();
68 for t in i {
69 let n = t.len();
70 unsafe {
71 *ptr = n as u8;
72 ptr = ptr.add(1);
73 }
74 if n == 0 {
75 continue;
76 }
77 if n > N {
78 return true;
80 }
81 encode(t);
82 }
83 self.small.set_end_ptr(ptr);
84 false
85 }
86
87 #[inline(always)]
88 pub fn encode_vectored_fallback<T: Len>(
89 &mut self,
90 i: impl Iterator<Item = T>,
91 mut reserve_and_encode_large: impl FnMut(T),
92 ) {
93 for v in i {
94 let n = v.len();
95 self.encode(&n);
96 reserve_and_encode_large(v);
97 }
98 }
99}
100
101impl Buffer for LengthEncoder {
102 fn collect_into(&mut self, out: &mut Vec<u8>) {
103 pack_bytes(self.small.as_mut_slice(), out);
104 self.small.clear();
105 self.large.collect_into(out);
106 }
107
108 fn reserve(&mut self, additional: NonZeroUsize) {
109 self.small.reserve(additional.get()); }
111}
112
113#[derive(Default)]
114pub struct LengthDecoder<'a> {
115 small: CowSlice<'a, u8>,
116 large: IntDecoder<'a, usize>,
117 sum: usize,
118}
119
120impl<'a> LengthDecoder<'a> {
121 pub fn length(&self) -> usize {
122 self.sum
123 }
124
125 pub fn borrowed_clone<'me: 'a>(&'me self) -> LengthDecoder<'me> {
127 let mut small = CowSlice::default();
128 small.set_borrowed_slice_impl(self.small.ref_slice().clone());
129 Self {
130 small,
131 large: self.large.borrowed_clone(),
132 sum: self.sum,
133 }
134 }
135
136 #[cfg_attr(not(feature = "arrayvec"), allow(unused))]
139 pub unsafe fn any_greater_than<const N: usize>(&self, length: usize) -> bool {
140 if N < 255 {
141 self.small
144 .as_slice(length)
145 .iter()
146 .copied()
147 .max()
148 .unwrap_or(0) as usize
149 > N
150 } else {
151 let mut decoder = self.borrowed_clone();
152 (0..length).any(|_| decoder.decode() > N)
153 }
154 }
155}
156
157impl<'a> View<'a> for LengthDecoder<'a> {
158 fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
159 unpack_bytes(input, length, &mut self.small)?;
160 let small = unsafe { self.small.as_slice(length) };
161
162 let mut sum: u64 = small.iter().map(|&v| v as u64).sum();
164
165 if sum < 255 {
167 self.sum = sum as usize;
168 return Ok(());
169 }
170
171 let large_length = small.iter().filter(|&&v| v == 255).count();
173 self.large.populate(input, large_length)?;
174
175 sum -= large_length as u64 * 255;
177
178 let mut decoder = self.large.borrowed_clone();
180 for _ in 0..large_length {
181 let v: usize = decoder.decode();
182 sum = sum
183 .checked_add(v as u64)
184 .ok_or_else(|| error("length overflow"))?;
185 }
186 if sum >= HUGE_LEN {
187 return err("huge length"); }
189 self.sum = sum.try_into().map_err(|_| error("length > usize::MAX"))?;
190 Ok(())
191 }
192}
193
/// Exclusive upper bound on the total of all lengths accepted by `LengthDecoder::populate`.
///
/// The value is `i64::MAX / 4096`, i.e. 2^51 - 1. `LengthDecoder::decode` relies on every
/// decoded length being below it.
///
/// NOTE(review): the divisor 4096 presumably leaves room to multiply a length by an element
/// size of up to 4096 bytes without overflowing — confirm with callers.
const HUGE_LEN: u64 = i64::MAX as u64 / 4096;
196
197impl<'a> Decoder<'a, usize> for LengthDecoder<'a> {
198 #[inline(always)]
199 fn decode(&mut self) -> usize {
200 let length = unsafe {
201 let v = self.small.mut_slice().next_unchecked();
202
203 if v < 255 {
204 v as usize
205 } else {
206 #[cold]
207 unsafe fn cold(large: &mut IntDecoder<'_, usize>) -> usize {
208 large.decode()
209 }
210 cold(&mut self.large)
211 }
212 };
213
214 if length as u64 >= HUGE_LEN {
217 unsafe { core::hint::unreachable_unchecked() }
218 }
219 length
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::{LengthDecoder, LengthEncoder};
    use crate::coder::{Buffer, Decoder, Encoder, View};
    use core::num::NonZeroUsize;

    /// Round-trips a mix of inline and escaped lengths.
    #[test]
    fn test() {
        let mut encoder = LengthEncoder::default();
        encoder.reserve(NonZeroUsize::new(3).unwrap());
        encoder.encode(&1);
        encoder.encode(&255);
        encoder.encode(&2);
        let bytes = encoder.collect();

        let mut decoder = LengthDecoder::default();
        decoder.populate(&mut bytes.as_slice(), 3).unwrap();
        assert_eq!(decoder.decode(), 1);
        assert_eq!(decoder.decode(), 255);
        assert_eq!(decoder.decode(), 2);
    }

    /// Round-trips values on both sides of the 255 escape threshold (including zero) and checks
    /// that `length()` reports their total.
    #[test]
    fn escape_boundaries() {
        let values = [0usize, 1, 254, 255, 256, 1 << 20, 3];
        let mut encoder = LengthEncoder::default();
        encoder.reserve(NonZeroUsize::new(values.len()).unwrap());
        for v in &values {
            encoder.encode(v);
        }
        let bytes = encoder.collect();

        let mut decoder = LengthDecoder::default();
        decoder.populate(&mut bytes.as_slice(), values.len()).unwrap();
        assert_eq!(decoder.length(), values.iter().sum::<usize>());
        for &v in &values {
            assert_eq!(decoder.decode(), v);
        }
    }

    /// Exercises both branches of `any_greater_than`: the small-byte fast path (`N < 255`) and
    /// the full-decode path (`N >= 255`).
    #[test]
    fn any_greater_than() {
        let mut encoder = LengthEncoder::default();
        encoder.reserve(NonZeroUsize::new(3).unwrap());
        for v in [1usize, 300, 2] {
            encoder.encode(&v);
        }
        let bytes = encoder.collect();

        let mut decoder = LengthDecoder::default();
        decoder.populate(&mut bytes.as_slice(), 3).unwrap();
        // SAFETY: 3 is the `length` passed to `populate`, and nothing has been decoded yet.
        unsafe {
            assert!(decoder.any_greater_than::<2>(3));
            assert!(decoder.any_greater_than::<299>(3));
            assert!(!decoder.any_greater_than::<300>(3));
        }
    }

    /// `populate` accepts totals just below `HUGE_LEN` and rejects `HUGE_LEN` itself.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn huge_len() {
        for (x, is_ok) in [(super::HUGE_LEN - 1, true), (super::HUGE_LEN, false)] {
            let mut encoder = LengthEncoder::default();
            encoder.reserve(NonZeroUsize::new(1).unwrap());
            encoder.encode(&(x as usize));
            let bytes = encoder.collect();

            let mut decoder = LengthDecoder::default();
            assert_eq!(decoder.populate(&mut bytes.as_slice(), 1).is_ok(), is_ok);
        }
    }
}