bitcode/derive/map.rs

1use crate::coder::{Buffer, Decoder, Encoder, Result, View};
2use crate::derive::{Decode, Encode};
3use crate::length::{LengthDecoder, LengthEncoder};
4use alloc::collections::BTreeMap;
5use alloc::vec::Vec;
6use core::num::NonZeroUsize;
7
8#[cfg(feature = "std")]
9use core::hash::{BuildHasher, Hash};
10#[cfg(feature = "std")]
11use std::collections::HashMap;
12
13pub struct MapEncoder<K: Encode, V: Encode> {
14    lengths: LengthEncoder,
15    keys: K::Encoder,
16    values: V::Encoder,
17}
18
19// Can't derive since it would bound K + V: Default.
20impl<K: Encode, V: Encode> Default for MapEncoder<K, V> {
21    fn default() -> Self {
22        Self {
23            lengths: Default::default(),
24            keys: Default::default(),
25            values: Default::default(),
26        }
27    }
28}
29
30impl<K: Encode, V: Encode> Buffer for MapEncoder<K, V> {
31    fn collect_into(&mut self, out: &mut Vec<u8>) {
32        self.lengths.collect_into(out);
33        self.keys.collect_into(out);
34        self.values.collect_into(out);
35    }
36
37    fn reserve(&mut self, additional: NonZeroUsize) {
38        self.lengths.reserve(additional);
39        // We don't know the lengths of the maps, so we can't reserve more.
40    }
41}
42
43pub struct MapDecoder<'a, K: Decode<'a>, V: Decode<'a>> {
44    lengths: LengthDecoder<'a>,
45    keys: K::Decoder,
46    values: V::Decoder,
47}
48
49// Can't derive since it would bound K + V: Default.
50impl<'a, K: Decode<'a>, V: Decode<'a>> Default for MapDecoder<'a, K, V> {
51    fn default() -> Self {
52        Self {
53            lengths: Default::default(),
54            keys: Default::default(),
55            values: Default::default(),
56        }
57    }
58}
59
60impl<'a, K: Decode<'a>, V: Decode<'a>> View<'a> for MapDecoder<'a, K, V> {
61    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
62        self.lengths.populate(input, length)?;
63        self.keys.populate(input, self.lengths.length())?;
64        self.values.populate(input, self.lengths.length())
65    }
66}
67
/// Expands to an `Encoder::encode` implementation for a map type `$t`.
///
/// The generated method writes the map's length, then every key to `self.keys` and every
/// value to `self.values`, in the map's iteration order.
macro_rules! encode_body {
    ($t:ty) => {
        #[inline(always)]
        fn encode(&mut self, map: &$t) {
            let count = map.len();
            self.lengths.encode(&count);

            // An empty map is fully described by its length.
            let additional = match NonZeroUsize::new(count) {
                Some(additional) => additional,
                None => return,
            };
            self.keys.reserve(additional);
            self.values.reserve(additional);

            for (key, value) in map.iter() {
                self.keys.encode(key);
                self.values.encode(value);
            }
        }
    };
}
/// Expands to a `Decoder::decode` implementation for a map type `$t`.
///
/// The generated method reads one length, then that many key/value pairs. Within each
/// pair the key is decoded before the value.
macro_rules! decode_body {
    ($t:ty) => {
        #[inline(always)]
        fn decode(&mut self) -> $t {
            let len = self.lengths.decode();
            // Collect instead of calling `insert` per entry: BTreeMap's FromIterator can
            // bulk-build once the items are sorted, which beats repeated insertion.
            // For HashMap the two approaches perform about the same.
            (0..len)
                .map(|_| {
                    let key = self.keys.decode();
                    let value = self.values.decode();
                    (key, value)
                })
                .collect()
        }
    };
}
98
99impl<K: Encode, V: Encode> Encoder<BTreeMap<K, V>> for MapEncoder<K, V> {
100    encode_body!(BTreeMap<K, V>);
101}
102impl<'a, K: Decode<'a> + Ord, V: Decode<'a>> Decoder<'a, BTreeMap<K, V>> for MapDecoder<'a, K, V> {
103    decode_body!(BTreeMap<K, V>);
104}
105
// Entries are written in HashMap iteration order, which std leaves unspecified.
// The hasher type `S` needs no bounds just to iterate the map.
#[cfg(feature = "std")]
impl<K, V, S> Encoder<HashMap<K, V, S>> for MapEncoder<K, V>
where
    K: Encode,
    V: Encode,
{
    encode_body!(HashMap<K, V, S>);
}
// Collecting into a HashMap needs hashable, comparable keys and a hasher that can be
// default-constructed.
#[cfg(feature = "std")]
impl<'a, K, V, S> Decoder<'a, HashMap<K, V, S>> for MapDecoder<'a, K, V>
where
    K: Decode<'a> + Eq + Hash,
    V: Decode<'a>,
    S: BuildHasher + Default,
{
    decode_body!(HashMap<K, V, S>);
}
116
#[cfg(test)]
mod test {
    use alloc::collections::BTreeMap;

    /// Benchmark input: every `u8` key from 0 through 255, each mapped to the value 0.
    fn bench_data<T: FromIterator<(u8, u8)>>() -> T {
        (u8::MIN..=u8::MAX).map(|key| (key, 0)).collect()
    }

    crate::bench_encode_decode!(btree_map: BTreeMap<_, _>);
    #[cfg(feature = "std")]
    crate::bench_encode_decode!(hash_map: std::collections::HashMap<_, _>);
}