// bitcode/serde/de.rs

1use crate::bool::BoolDecoder;
2use crate::coder::{Decoder, Result, View};
3use crate::consume::expect_eof;
4use crate::error::{err, error, Error};
5use crate::f32::F32Decoder;
6use crate::int::IntDecoder;
7use crate::length::LengthDecoder;
8use crate::serde::guard::guard_zst;
9use crate::serde::variant::VariantDecoder;
10use crate::serde::{default_box_slice, get_mut_or_resize, type_changed};
11use crate::str::StrDecoder;
12use alloc::boxed::Box;
13use alloc::vec::Vec;
14use serde::de::{
15    DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess, Visitor,
16};
17use serde::{Deserialize, Deserializer};
18
19// Redefine Result from crate::coder::Result to std::result::Result since the former isn't public.
20mod inner {
21    use super::*;
22    use core::result::Result;
23
24    /// Deserializes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Deserialize`].
25    ///
26    /// **Warning:** The format is incompatible with [`encode`][`crate::encode`] and subject to
27    /// change between major versions.
28    pub fn deserialize<'de, T: Deserialize<'de>>(mut bytes: &'de [u8]) -> Result<T, Error> {
29        let mut decoder = SerdeDecoder::Unspecified { length: 1 };
30        let t = T::deserialize(DecoderWrapper {
31            decoder: &mut decoder,
32            input: &mut bytes,
33        })?;
34        expect_eof(bytes)?;
35        Ok(t)
36    }
37}
38pub use inner::deserialize;
39
40enum SerdeDecoder<'a> {
41    Bool(BoolDecoder<'a>),
42    Enum((VariantDecoder<'a>, Vec<SerdeDecoder<'a>>)), // (variants, values)
43    F32(F32Decoder<'a>),
44    // We don't need signed integer decoders here because unsigned ones work the same.
45    Map((LengthDecoder<'a>, Box<(SerdeDecoder<'a>, SerdeDecoder<'a>)>)), // (lengths, (keys, values))
46    Seq((LengthDecoder<'a>, Box<SerdeDecoder<'a>>)),                     // (lengths, values)
47    Str(StrDecoder<'a>),
48    Tuple(Box<[SerdeDecoder<'a>]>), // [field0, field1, ..]
49    U8(IntDecoder<'a, u8>),
50    U16(IntDecoder<'a, u16>),
51    U32(IntDecoder<'a, u32>),
52    U64(IntDecoder<'a, u64>),
53    U128(IntDecoder<'a, u128>),
54    Unpopulated,
55    Unspecified { length: usize },
56}
57
58impl Default for SerdeDecoder<'_> {
59    fn default() -> Self {
60        Self::Unpopulated
61    }
62}
63
64impl<'a> View<'a> for SerdeDecoder<'a> {
65    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
66        match self {
67            Self::Bool(d) => d.populate(input, length),
68            Self::Enum(d) => {
69                d.0.populate(input, length)?;
70                if let Some(max_variant_index) = d.0.max_variant_index() {
71                    get_mut_or_resize(&mut d.1, max_variant_index as usize);
72                    d.1.iter_mut()
73                        .enumerate()
74                        .try_for_each(|(i, variant)| variant.populate(input, d.0.length(i as u8)))
75                } else {
76                    Ok(())
77                }
78            }
79            Self::F32(d) => d.populate(input, length),
80            Self::Map(d) => {
81                d.0.populate(input, length)?;
82                let length = d.0.length();
83                d.1 .0.populate(input, length)?;
84                d.1 .1.populate(input, length)
85            }
86            Self::Seq(d) => {
87                d.0.populate(input, length)?;
88                let length = d.0.length();
89                d.1.populate(input, length)
90            }
91            Self::Str(d) => d.populate(input, length),
92            Self::Tuple(d) => d.iter_mut().try_for_each(|d| d.populate(input, length)),
93            Self::U8(d) => d.populate(input, length),
94            Self::U16(d) => d.populate(input, length),
95            Self::U32(d) => d.populate(input, length),
96            Self::U64(d) => d.populate(input, length),
97            Self::U128(d) => d.populate(input, length),
98            Self::Unpopulated => {
99                *self = Self::Unspecified { length };
100                Ok(())
101            }
102            Self::Unspecified { .. } => unreachable!(),
103        }
104    }
105}
106
107struct DecoderWrapper<'a, 'de> {
108    decoder: &'a mut SerdeDecoder<'de>,
109    input: &'a mut &'de [u8],
110}
111
112macro_rules! specify {
113    ($self:ident, $variant:ident) => {{
114        match &mut $self.decoder {
115            // Check if it's already the correct decoder. This results in 1 branch in the hot path.
116            SerdeDecoder::$variant(_) => (),
117            _ => {
118                // Either create the correct decoder if unspecified or diverge via panic/error.
119                #[cold]
120                #[rustfmt::skip]
121                fn cold<'de>(decoder: &mut SerdeDecoder<'de>, input: &mut &'de [u8]) -> Result<()> {
122                    let &mut SerdeDecoder::Unspecified { length } = decoder else {
123                        type_changed!()
124                    };
125                    *decoder = SerdeDecoder::$variant(Default::default());
126                    decoder.populate(input, length)
127                }
128                cold(&mut *$self.decoder, &mut *$self.input)?;
129            }
130        }
131        #[rustfmt::skip]
132        let SerdeDecoder::$variant(d) = &mut *$self.decoder else {
133            // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either
134            // errors or sets lazy to the correct decoder.
135            unsafe { core::hint::unreachable_unchecked() };
136        };
137        d
138    }};
139}
140
141macro_rules! impl_de {
142    ($deserialize:ident, $visit:ident, $t:ty, $variant:ident) => {
143        #[inline(always)]
144        fn $deserialize<V>(mut self, v: V) -> Result<V::Value>
145        where
146            V: Visitor<'de>,
147        {
148            v.$visit(specify!(self, $variant).decode())
149        }
150    };
151}
152
153impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> {
154    type Error = Error;
155
156    fn deserialize_any<V>(self, _: V) -> Result<V::Value>
157    where
158        V: Visitor<'de>,
159    {
160        err("deserialize_any is not supported")
161    }
162
163    // Use native decoders.
164    impl_de!(deserialize_bool, visit_bool, bool, Bool);
165    impl_de!(deserialize_f32, visit_f32, f32, F32);
166    impl_de!(deserialize_u8, visit_u8, u8, U8);
167    impl_de!(deserialize_u16, visit_u16, u16, U16);
168    impl_de!(deserialize_u32, visit_u32, u32, U32);
169    impl_de!(deserialize_u64, visit_u64, u64, U64);
170    impl_de!(deserialize_u128, visit_u128, u128, U128);
171    impl_de!(deserialize_str, visit_borrowed_str, &str, Str);
172
173    // IntDecoder<unsigned> works on signed integers/f64 (but not chars).
174    impl_de!(deserialize_i8, visit_i8, i8, U8);
175    impl_de!(deserialize_i16, visit_i16, i16, U16);
176    impl_de!(deserialize_i32, visit_i32, i32, U32);
177    impl_de!(deserialize_i64, visit_i64, i64, U64);
178    impl_de!(deserialize_i128, visit_i128, i128, U128);
179    impl_de!(deserialize_f64, visit_f64, f64, U64);
180
181    #[inline(always)]
182    fn deserialize_char<V>(self, v: V) -> Result<V::Value>
183    where
184        V: Visitor<'de>,
185    {
186        v.visit_char(char::from_u32(u32::deserialize(self)?).ok_or_else(|| error("invalid char"))?)
187    }
188
189    #[inline(always)]
190    fn deserialize_string<V>(self, v: V) -> Result<V::Value>
191    where
192        V: Visitor<'de>,
193    {
194        self.deserialize_str(v)
195    }
196
197    #[inline(always)]
198    fn deserialize_bytes<V>(self, v: V) -> Result<V::Value>
199    where
200        V: Visitor<'de>,
201    {
202        self.deserialize_byte_buf(v) // TODO avoid allocation.
203    }
204
205    #[inline(always)]
206    fn deserialize_byte_buf<V>(self, v: V) -> Result<V::Value>
207    where
208        V: Visitor<'de>,
209    {
210        v.visit_byte_buf(<Vec<u8>>::deserialize(self)?)
211    }
212
213    #[inline(always)]
214    fn deserialize_option<V>(mut self, v: V) -> Result<V::Value>
215    where
216        V: Visitor<'de>,
217    {
218        let (variant_decoder, decoders) = specify!(self, Enum);
219        let variant_index = variant_decoder.decode();
220        // Safety: populate guarantees `variant_decoder.max_variant_index() < decoders.len()`.
221        let decoder = unsafe { decoders.get_unchecked_mut(variant_index as usize) };
222
223        match variant_index {
224            0 => v.visit_none(),
225            1 => v.visit_some(DecoderWrapper {
226                decoder,
227                input: &mut *self.input,
228            }),
229            _ => err("invalid option"),
230        }
231    }
232
233    #[inline(always)]
234    fn deserialize_unit<V>(self, v: V) -> Result<V::Value>
235    where
236        V: Visitor<'de>,
237    {
238        v.visit_unit()
239    }
240
241    #[inline(always)]
242    fn deserialize_unit_struct<V>(self, _: &'static str, v: V) -> Result<V::Value>
243    where
244        V: Visitor<'de>,
245    {
246        v.visit_unit()
247    }
248
249    #[inline(always)]
250    fn deserialize_newtype_struct<V>(self, _: &'static str, v: V) -> Result<V::Value>
251    where
252        V: Visitor<'de>,
253    {
254        v.visit_newtype_struct(self)
255    }
256
257    fn deserialize_seq<V>(mut self, v: V) -> Result<V::Value>
258    where
259        V: Visitor<'de>,
260    {
261        let (length_decoder, decoder) = specify!(self, Seq);
262        let len = length_decoder.decode();
263
264        struct Access<'a, 'de> {
265            wrapper: DecoderWrapper<'a, 'de>,
266            len: usize,
267        }
268        impl<'de> SeqAccess<'de> for Access<'_, 'de> {
269            type Error = Error;
270
271            #[inline(always)]
272            fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
273            where
274                T: DeserializeSeed<'de>,
275            {
276                guard_zst::<T::Value>(self.len)?;
277                if self.len != 0 {
278                    self.len -= 1;
279                    Ok(Some(DeserializeSeed::deserialize(
280                        seed,
281                        DecoderWrapper {
282                            decoder: &mut *self.wrapper.decoder,
283                            input: &mut *self.wrapper.input,
284                        },
285                    )?))
286                } else {
287                    Ok(None)
288                }
289            }
290
291            #[inline(always)]
292            fn size_hint(&self) -> Option<usize> {
293                Some(self.len)
294            }
295        }
296        v.visit_seq(Access {
297            wrapper: DecoderWrapper {
298                decoder,
299                input: self.input,
300            },
301            len,
302        })
303    }
304
305    #[inline(always)]
306    fn deserialize_tuple<V>(mut self, tuple_len: usize, v: V) -> Result<V::Value>
307    where
308        V: Visitor<'de>,
309    {
310        // Fast path: avoid overhead of tuple for 1 element.
311        if tuple_len == 1 {
312            return v.visit_seq(Access {
313                decoders: core::slice::from_mut(self.decoder),
314                input: self.input,
315                index: 0,
316            });
317        }
318
319        // Copy of specify! macro that takes an additional tuple_len parameter to cold.
320        match &mut self.decoder {
321            SerdeDecoder::Tuple(_) => (),
322            _ => {
323                #[cold]
324                fn cold<'de>(
325                    decoder: &mut SerdeDecoder<'de>,
326                    input: &mut &'de [u8],
327                    tuple_len: usize,
328                ) -> Result<()> {
329                    let &mut SerdeDecoder::Unspecified { length } = decoder else {
330                        type_changed!()
331                    };
332                    *decoder = SerdeDecoder::Tuple(default_box_slice(tuple_len));
333                    decoder.populate(input, length)
334                }
335                cold(&mut *self.decoder, &mut *self.input, tuple_len)?;
336            }
337        }
338        let SerdeDecoder::Tuple(decoders) = &mut *self.decoder else {
339            // Safety: see specify! macro which this is based on.
340            unsafe { core::hint::unreachable_unchecked() };
341        };
342        if decoders.len() != tuple_len {
343            type_changed!() // Removes multiple bounds checks.
344        }
345
346        struct Access<'a, 'de> {
347            decoders: &'a mut [SerdeDecoder<'de>],
348            input: &'a mut &'de [u8],
349            index: usize,
350        }
351        impl<'de> SeqAccess<'de> for Access<'_, 'de> {
352            type Error = Error;
353
354            #[inline(always)]
355            fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
356            where
357                T: DeserializeSeed<'de>,
358            {
359                if let Some(decoder) = self.decoders.get_mut(self.index) {
360                    self.index += 1;
361                    Ok(Some(DeserializeSeed::deserialize(
362                        seed,
363                        DecoderWrapper {
364                            decoder,
365                            input: &mut *self.input,
366                        },
367                    )?))
368                } else {
369                    Ok(None)
370                }
371            }
372
373            #[inline(always)]
374            fn size_hint(&self) -> Option<usize> {
375                Some(self.decoders.len())
376            }
377        }
378
379        v.visit_seq(Access {
380            decoders,
381            input: &mut *self.input,
382            index: 0,
383        })
384    }
385
386    #[inline(always)]
387    fn deserialize_tuple_struct<V>(self, _: &'static str, len: usize, v: V) -> Result<V::Value>
388    where
389        V: Visitor<'de>,
390    {
391        self.deserialize_tuple(len, v)
392    }
393
394    fn deserialize_map<V>(mut self, v: V) -> Result<V::Value>
395    where
396        V: Visitor<'de>,
397    {
398        let (length_decoder, decoders) = specify!(self, Map);
399        let len = length_decoder.decode();
400
401        struct Access<'a, 'de> {
402            decoders: &'a mut (SerdeDecoder<'de>, SerdeDecoder<'de>),
403            input: &'a mut &'de [u8],
404            len: usize,
405            key_deserialized: bool,
406        }
407
408        impl<'de> MapAccess<'de> for Access<'_, 'de> {
409            type Error = Error;
410
411            #[inline(always)]
412            fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
413            where
414                K: DeserializeSeed<'de>,
415            {
416                guard_zst::<K::Value>(self.len)?;
417                if self.len != 0 {
418                    self.len -= 1;
419                    // Safety: Make sure next_value_seed is called at most once after each len decrement.
420                    // We don't care if DeserializeSeed fails after this (not critical to safety).
421                    self.key_deserialized = true;
422                    Ok(Some(DeserializeSeed::deserialize(
423                        seed,
424                        DecoderWrapper {
425                            decoder: &mut self.decoders.0,
426                            input: &mut *self.input,
427                        },
428                    )?))
429                } else {
430                    Ok(None)
431                }
432            }
433
434            #[inline(always)]
435            fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
436            where
437                V: DeserializeSeed<'de>,
438            {
439                // Safety: Make sure next_value_seed is called at most once after each len decrement
440                // since only len values exist.
441                assert!(
442                    core::mem::take(&mut self.key_deserialized),
443                    "next_value_seed before next_key_seed"
444                );
445                DeserializeSeed::deserialize(
446                    seed,
447                    DecoderWrapper {
448                        decoder: &mut self.decoders.1,
449                        input: &mut *self.input,
450                    },
451                )
452            }
453            // TODO implement next_entry_seed to avoid checking key_deserialized.
454
455            #[inline(always)]
456            fn size_hint(&self) -> Option<usize> {
457                Some(self.len)
458            }
459        }
460
461        v.visit_map(Access {
462            decoders,
463            input: self.input,
464            len,
465            key_deserialized: false, // No keys have been deserialized yet, so next_value_seed can't be called.
466        })
467    }
468
469    #[inline(always)]
470    fn deserialize_struct<V>(
471        self,
472        _: &'static str,
473        fields: &'static [&'static str],
474        v: V,
475    ) -> Result<V::Value>
476    where
477        V: Visitor<'de>,
478    {
479        self.deserialize_tuple(fields.len(), v)
480    }
481
482    #[inline(always)]
483    fn deserialize_enum<V>(
484        self,
485        _: &'static str,
486        _: &'static [&'static str],
487        v: V,
488    ) -> Result<V::Value>
489    where
490        V: Visitor<'de>,
491    {
492        v.visit_enum(self)
493    }
494
495    fn deserialize_identifier<V>(self, _: V) -> Result<V::Value>
496    where
497        V: Visitor<'de>,
498    {
499        err("deserialize_identifier is not supported")
500    }
501
502    fn deserialize_ignored_any<V>(self, _: V) -> Result<V::Value>
503    where
504        V: Visitor<'de>,
505    {
506        err("deserialize_ignored_any is not supported")
507    }
508
509    #[inline(always)]
510    fn is_human_readable(&self) -> bool {
511        false
512    }
513}
514
515impl<'a, 'de> EnumAccess<'de> for DecoderWrapper<'a, 'de> {
516    type Error = Error;
517    type Variant = DecoderWrapper<'a, 'de>;
518
519    #[inline(always)]
520    fn variant_seed<V>(mut self, seed: V) -> Result<(V::Value, Self::Variant)>
521    where
522        V: DeserializeSeed<'de>,
523    {
524        let (variant_decoder, decoders) = specify!(self, Enum);
525        let variant_index = variant_decoder.decode();
526        // Safety: populate guarantees `variant_decoder.max_variant_index() < decoders.len()`.
527        let decoder = unsafe { decoders.get_unchecked_mut(variant_index as usize) };
528        let variant_index = variant_index as u32;
529
530        let val: Result<_> = seed.deserialize(variant_index.into_deserializer());
531        Ok((
532            val?,
533            DecoderWrapper {
534                decoder,
535                input: &mut *self.input,
536            },
537        ))
538    }
539}
540
541impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> {
542    type Error = Error;
543
544    #[inline(always)]
545    fn unit_variant(self) -> Result<()> {
546        Ok(())
547    }
548
549    #[inline(always)]
550    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
551    where
552        T: DeserializeSeed<'de>,
553    {
554        seed.deserialize(self)
555    }
556
557    #[inline(always)]
558    fn tuple_variant<V>(self, len: usize, v: V) -> Result<V::Value>
559    where
560        V: Visitor<'de>,
561    {
562        self.deserialize_tuple(len, v)
563    }
564
565    #[inline(always)]
566    fn struct_variant<V>(self, fields: &'static [&'static str], v: V) -> Result<V::Value>
567    where
568        V: Visitor<'de>,
569    {
570        self.deserialize_tuple(fields.len(), v)
571    }
572}
573
574#[cfg(test)]
575mod tests {
576    use alloc::borrow::ToOwned;
577    use alloc::collections::BTreeMap;
578    use alloc::string::String;
579    use alloc::vec::Vec;
580    use serde::de::MapAccess;
581    use serde::Deserializer;
582
583    #[test]
584    fn deserialize() {
585        macro_rules! test {
586            ($v:expr, $t:ty) => {
587                let v = $v;
588                let ser = crate::serialize::<$t>(&v).unwrap();
589                #[cfg(feature = "std")]
590                println!("{:<24} {ser:?}", stringify!($t));
591                assert_eq!(v, crate::deserialize::<$t>(&ser).unwrap());
592            };
593        }
594        // Primitives
595        test!(5, u8);
596        test!(5, u16);
597        test!(5, u32);
598        test!(5, u64);
599        test!(5, u128);
600        test!(5, i8);
601        test!(5, i16);
602        test!(5, i32);
603        test!(5, i64);
604        test!(5, i128);
605        test!(true, bool);
606        test!('a', char);
607
608        // Enums
609        test!(Some(true), Option<bool>);
610        test!(Ok(true), Result<bool, u32>);
611        test!(vec![Ok(true), Err(2)], Vec<Result<bool, u32>>);
612        test!(vec![Err(1), Ok(false)], Vec<Result<bool, u32>>);
613
614        // Maps
615        let mut map = BTreeMap::new();
616        map.insert(1u8, 11u8);
617        map.insert(2u8, 22u8);
618        test!(map, BTreeMap<u8, u8>);
619
620        // Sequences
621        test!("abc".to_owned(), String);
622        test!(vec![1u8, 2u8, 3u8], Vec<u8>);
623        // Make sure signed integers are being packed properly (output should end in 85).
624        test!(vec![0, -1, 0, -1, 0, -1, 0], Vec<i8>);
625        test!(vec![0, -1, 0, -1, 0, -1, 0], Vec<i16>);
626        test!(vec![0, -1, 0, -1, 0, -1, 0], Vec<i32>);
627        test!(vec![0, -1, 0, -1, 0, -1, 0], Vec<i64>);
628        test!(vec![0, -1, 0, -1, 0, -1, 0], Vec<i128>);
629        // Make sure f32 sign_exp is grouped (output should end in 4x 63).
630        test!(vec![1.0; 4], Vec<f32>);
631        test!(
632            vec!["abc".to_owned(), "def".to_owned(), "ghi".to_owned()],
633            Vec<String>
634        );
635
636        // Tuples
637        test!((1u8, 2u8, 3u8), (u8, u8, u8));
638        test!([1u8, 2u8, 3u8], [u8; 3]);
639        test!([], [u8; 0]);
640
641        // Complex.
642        test!(vec![(None, 3), (Some(4), 5)], Vec<(Option<u8>, u8)>);
643    }
644
645    #[test]
646    #[should_panic = "next_value_seed before next_key_seed"]
647    fn map_incorrect_len_values() {
648        let mut map = BTreeMap::new();
649        map.insert(1u8, 2u8);
650        let input = crate::serialize(&map).unwrap();
651
652        let w = super::DecoderWrapper {
653            decoder: &mut super::SerdeDecoder::Unspecified { length: 1 },
654            input: &mut input.as_slice(),
655        };
656
657        struct Visitor;
658        impl<'de> serde::de::Visitor<'de> for Visitor {
659            type Value = ();
660            fn expecting(&self, _: &mut core::fmt::Formatter) -> core::fmt::Result {
661                unreachable!()
662            }
663            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
664            where
665                A: MapAccess<'de>,
666            {
667                assert_eq!(map.next_key::<u8>().unwrap().unwrap(), 1u8);
668                assert_eq!(map.next_value::<u8>().unwrap(), 2u8);
669                map.next_value::<u8>().unwrap();
670                Ok(())
671            }
672        }
673        w.deserialize_map(Visitor).unwrap();
674    }
675}