bitcode/
buffer.rs

1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::any::TypeId;
4
5/// A buffer for reusing allocations between calls to [`Buffer::encode`] and/or [`Buffer::decode`].
6///
7/// ```rust
8/// use bitcode::{Buffer, Encode, Decode};
9///
10/// let original = "Hello world!";
11///
12/// let mut buffer = Buffer::new();
13/// buffer.encode(&original);
14/// let encoded: &[u8] = buffer.encode(&original); // Won't allocate
15///
16/// let mut buffer = Buffer::new();
17/// buffer.decode::<&str>(&encoded).unwrap();
18/// let decoded: &str = buffer.decode(&encoded).unwrap(); // Won't allocate
19/// assert_eq!(original, decoded);
20/// ```
21#[derive(Default)]
22pub struct Buffer {
23    pub(crate) registry: Registry,
24    pub(crate) out: Vec<u8>, // Isn't stored in registry because all encoders can share this.
25}
26
27impl Buffer {
28    /// Constructs a new buffer.
29    pub fn new() -> Self {
30        Self::default()
31    }
32}
33
34// Set of arbitrary types.
35#[derive(Default)]
36pub(crate) struct Registry(Vec<(TypeId, ErasedBox)>);
37
38impl Registry {
39    /// Gets a `&mut T` if it already exists or initializes one with [`Default`].
40    #[cfg(test)]
41    pub(crate) fn get<T: Default + Send + Sync + 'static>(&mut self) -> &mut T {
42        // Safety: T is static.
43        unsafe { self.get_non_static::<T>() }
44    }
45
46    /// Like [`Registry::get`] but can get non-static types.
47    /// # Safety
48    /// Lifetimes are the responsibility of the caller. `&'static [u8]` and `&'a [u8]` are the same
49    /// type from the perspective of this function.
50    pub(crate) unsafe fn get_non_static<T: Default + Send + Sync>(&mut self) -> &mut T {
51        // Use non-generic function to avoid monomorphization.
52        #[inline(never)]
53        fn inner(me: &mut Registry, type_id: TypeId, create: fn() -> ErasedBox) -> *mut () {
54            // Use sorted Vec + binary search because we expect fewer insertions than lookups.
55            // We could use a HashMap, but that seems like overkill.
56            match me.0.binary_search_by_key(&type_id, |(k, _)| *k) {
57                Ok(i) => me.0[i].1.ptr,
58                Err(i) => {
59                    #[cold]
60                    #[inline(never)]
61                    fn cold(
62                        me: &mut Registry,
63                        i: usize,
64                        type_id: TypeId,
65                        create: fn() -> ErasedBox,
66                    ) -> *mut () {
67                        me.0.insert(i, (type_id, create()));
68                        me.0[i].1.ptr
69                    }
70                    cold(me, i, type_id, create)
71                }
72            }
73        }
74        let erased_ptr = inner(self, non_static_type_id::<T>(), || {
75            // Safety: Caller upholds any lifetime requirements.
76            ErasedBox::new(T::default())
77        });
78
79        // Safety: type_id uniquely identifies the type, so the entry with equal TypeId is a T.
80        &mut *(erased_ptr as *mut T)
81    }
82}
83
/// Computes the [`TypeId`] of `T` while disregarding any lifetimes `T` contains, so that
/// e.g. `&'a [u8]` and `&'static [u8]` yield the same id.
/// https://github.com/rust-lang/rust/issues/41875#issuecomment-317292888
fn non_static_type_id<T: ?Sized>() -> TypeId {
    use core::marker::PhantomData;

    /// Helper whose only method is callable solely on `'static` implementors.
    trait LifetimeBlindId {
        fn lifetime_blind_id(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<U: ?Sized> LifetimeBlindId for PhantomData<U> {
        fn lifetime_blind_id(&self) -> TypeId
        where
            Self: 'static,
        {
            // The `Self: 'static` bound above is what makes `TypeId::of` callable here.
            TypeId::of::<U>()
        }
    }

    let marker = PhantomData::<T>;
    let borrowed: &dyn LifetimeBlindId = &marker;
    // Pretend the trait object is `'static` so that the bounded method can be invoked.
    let as_static: &(dyn LifetimeBlindId + 'static) = unsafe { core::mem::transmute(borrowed) };
    as_static.lifetime_blind_id()
}
106
/// `Box<T>` for a `T` that is only known at runtime; getting the `T` back out requires unsafe.
///
/// The value is freed through `drop`, which remembers the concrete type on our behalf.
struct ErasedBox {
    /// Pointer produced by `Box::<T>::into_raw`.
    ptr: *mut (),
    /// Destructor for `ptr`; must only ever be called with `ptr`, and only once.
    drop: unsafe fn(*mut ()),
}

// SAFETY: `ErasedBox::new` ensures `T: Send + Sync`.
unsafe impl Send for ErasedBox {}
unsafe impl Sync for ErasedBox {}

impl ErasedBox {
    /// Allocates a [`Box<T>`] which doesn't know its own type. Only works on `T: Sized`.
    ///
    /// # Safety
    /// Lifetimes in `T` are ignored, so the value may be dropped after `T`'s lifetime has
    /// expired. The caller must make sure that is acceptable.
    unsafe fn new<T: Send + Sync>(t: T) -> Self {
        /// Rebuilds the `Box<U>` behind `ptr` and drops it.
        ///
        /// Using a typed shim instead of transmuting `drop::<Box<U>>` means we do not depend
        /// on `fn(Box<U>)` and `unsafe fn(*mut ())` being ABI-compatible.
        ///
        /// # Safety
        /// `ptr` must come from `Box::<U>::into_raw` and must not have been freed already.
        unsafe fn drop_erased<U>(ptr: *mut ()) {
            // SAFETY: guaranteed by this function's contract.
            drop(unsafe { Box::from_raw(ptr as *mut U) });
        }

        Self {
            ptr: Box::into_raw(Box::new(t)) as *mut (),
            drop: drop_erased::<T>,
        }
    }
}

impl Drop for ErasedBox {
    fn drop(&mut self) {
        // SAFETY: `ErasedBox::new` stored a pointer from `Box::<T>::into_raw` in `self.ptr`
        // together with the destructor for that same `T` in `self.drop`, and `Drop::drop` runs
        // at most once.
        unsafe { (self.drop)(self.ptr) };
    }
}
134
#[cfg(test)]
mod tests {
    use super::{non_static_type_id, Buffer, ErasedBox, Registry};
    use test::{black_box, Bencher};

    /// Round-trips booleans through a single `Buffer` and checks that it is `Send + Sync`.
    #[test]
    fn buffer() {
        fn assert_send_sync<T: Send + Sync>() {}

        let mut buf = Buffer::new();
        assert_eq!(buf.encode(&false), &[0]);
        assert_eq!(buf.encode(&true), &[1]);
        assert!(!buf.decode::<bool>(&[0]).unwrap());
        assert!(buf.decode::<bool>(&[1]).unwrap());

        assert_send_sync::<Buffer>();
    }

    /// Entries start out as `Default`, keep what is written, and are independent per type.
    #[test]
    fn registry() {
        let mut reg = Registry::default();

        assert_eq!(*reg.get::<u8>(), 0);
        *reg.get::<u8>() = 1;
        assert_eq!(*reg.get::<u8>(), 1);

        assert_eq!(*reg.get::<u16>(), 0);
        *reg.get::<u16>() = 5;
        assert_eq!(*reg.get::<u16>(), 5);

        // Adding the `u16` entry must not have disturbed the `u8` one.
        assert_eq!(*reg.get::<u8>(), 1);
    }

    /// Different types get different ids; types differing only by lifetime share one.
    #[test]
    fn type_id() {
        assert_ne!(non_static_type_id::<u8>(), non_static_type_id::<i8>());
        assert_ne!(non_static_type_id::<()>(), non_static_type_id::<[(); 1]>());
        assert_ne!(
            non_static_type_id::<&'static mut [u8]>(),
            non_static_type_id::<&'static [u8]>()
        );
        assert_ne!(
            non_static_type_id::<*mut u8>(),
            non_static_type_id::<*const u8>()
        );

        fn with_lifetime<'a>(_: &'a ()) {
            assert_eq!(
                non_static_type_id::<&'static [u8]>(),
                non_static_type_id::<&'a [u8]>()
            );
            assert_eq!(
                non_static_type_id::<&'static ()>(),
                non_static_type_id::<&'a ()>()
            );
        }
        with_lifetime(&());
    }

    /// Dropping an `ErasedBox` drops the value inside it.
    #[test]
    fn erased_box() {
        use alloc::sync::Arc;
        struct TestDrop(#[allow(unused)] Arc<()>);

        let counter = Arc::new(());
        let erased = unsafe { ErasedBox::new(TestDrop(Arc::clone(&counter))) };
        assert_eq!(Arc::strong_count(&counter), 2);
        drop(erased);
        assert_eq!(Arc::strong_count(&counter), 1);
    }

    // For each length literal given, registers ten array types (one per integer element type).
    macro_rules! register10 {
        ($reg:ident $(, $len:literal)*) => {
            $(
                $reg.get::<[u8; $len]>();
                $reg.get::<[i8; $len]>();
                $reg.get::<[u16; $len]>();
                $reg.get::<[i16; $len]>();
                $reg.get::<[u32; $len]>();
                $reg.get::<[i32; $len]>();
                $reg.get::<[u64; $len]>();
                $reg.get::<[i64; $len]>();
                $reg.get::<[u128; $len]>();
                $reg.get::<[i128; $len]>();
            )*
        }
    }

    /// The type looked up by every benchmark below.
    type Probe = [u8; 1];

    /// Lookup in a registry containing a single entry.
    #[bench]
    fn bench_registry1_get(bencher: &mut Bencher) {
        let mut reg = Registry::default();
        reg.get::<Probe>();
        assert_eq!(reg.0.len(), 1);
        bencher.iter(|| {
            black_box(*black_box(&mut reg).get::<Probe>());
        })
    }

    /// Lookup in a registry containing 10 entries.
    #[bench]
    fn bench_registry10_get(bencher: &mut Bencher) {
        let mut reg = Registry::default();
        reg.get::<Probe>();
        register10!(reg, 1);
        assert_eq!(reg.0.len(), 10);
        bencher.iter(|| {
            black_box(*black_box(&mut reg).get::<Probe>());
        })
    }

    /// Lookup in a registry containing 100 entries.
    #[bench]
    fn bench_registry100_get(bencher: &mut Bencher) {
        let mut reg = Registry::default();
        reg.get::<Probe>();
        register10!(reg, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        assert_eq!(reg.0.len(), 100);
        bencher.iter(|| {
            black_box(*black_box(&mut reg).get::<Probe>());
        })
    }
}