alloy_rlp/
header.rs

1use crate::{decode::static_left_pad, Error, Result, EMPTY_LIST_CODE, EMPTY_STRING_CODE};
2use bytes::{Buf, BufMut};
3use core::hint::unreachable_unchecked;
4
/// Describes the prefix of a single RLP item: its kind and how long its payload is.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Header {
    /// `true` when the item is a list, `false` when it is a string.
    pub list: bool,
    /// Number of payload bytes that follow the header.
    pub payload_length: usize,
}
13
14impl Header {
15    /// Decodes an RLP header from the given buffer.
16    ///
17    /// # Errors
18    ///
19    /// Returns an error if the buffer is too short or the header is invalid.
20    #[inline]
21    pub fn decode(buf: &mut &[u8]) -> Result<Self> {
22        let payload_length;
23        let mut list = false;
24        match get_next_byte(buf)? {
25            0..=0x7F => payload_length = 1,
26
27            b @ EMPTY_STRING_CODE..=0xB7 => {
28                buf.advance(1);
29                payload_length = (b - EMPTY_STRING_CODE) as usize;
30                if payload_length == 1 && get_next_byte(buf)? < EMPTY_STRING_CODE {
31                    return Err(Error::NonCanonicalSingleByte);
32                }
33            }
34
35            b @ (0xB8..=0xBF | 0xF8..=0xFF) => {
36                buf.advance(1);
37
38                list = b >= 0xF8; // second range
39                let code = if list { 0xF7 } else { 0xB7 };
40
41                // SAFETY: `b - code` is always in the range `1..=8` in the current match arm.
42                // The compiler/LLVM apparently cannot prove this because of the `|` pattern +
43                // the above `if`, since it can do it in the other arms with only 1 range.
44                let len_of_len = unsafe { b.checked_sub(code).unwrap_unchecked() } as usize;
45                if len_of_len == 0 || len_of_len > 8 {
46                    unsafe { unreachable_unchecked() }
47                }
48
49                if buf.len() < len_of_len {
50                    return Err(Error::InputTooShort);
51                }
52                // SAFETY: length checked above
53                let len = unsafe { buf.get_unchecked(..len_of_len) };
54                buf.advance(len_of_len);
55
56                let len = u64::from_be_bytes(static_left_pad(len)?);
57                payload_length =
58                    usize::try_from(len).map_err(|_| Error::Custom("Input too big"))?;
59                if payload_length < 56 {
60                    return Err(Error::NonCanonicalSize);
61                }
62            }
63
64            b @ EMPTY_LIST_CODE..=0xF7 => {
65                buf.advance(1);
66                list = true;
67                payload_length = (b - EMPTY_LIST_CODE) as usize;
68            }
69        }
70
71        if buf.remaining() < payload_length {
72            return Err(Error::InputTooShort);
73        }
74
75        Ok(Self { list, payload_length })
76    }
77
78    /// Decodes the next payload from the given buffer, advancing it.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if the buffer is too short or the header is invalid.
83    #[inline]
84    pub fn decode_bytes<'a>(buf: &mut &'a [u8], is_list: bool) -> Result<&'a [u8]> {
85        let Self { list, payload_length } = Self::decode(buf)?;
86
87        if list != is_list {
88            return Err(if is_list { Error::UnexpectedString } else { Error::UnexpectedList });
89        }
90
91        // SAFETY: this is already checked in `decode`
92        let bytes = unsafe { advance_unchecked(buf, payload_length) };
93        Ok(bytes)
94    }
95
96    /// Decodes a string slice from the given buffer, advancing it.
97    ///
98    /// # Errors
99    ///
100    /// Returns an error if the buffer is too short or the header is invalid.
101    #[inline]
102    pub fn decode_str<'a>(buf: &mut &'a [u8]) -> Result<&'a str> {
103        let bytes = Self::decode_bytes(buf, false)?;
104        core::str::from_utf8(bytes).map_err(|_| Error::Custom("invalid string"))
105    }
106
107    /// Extracts the next payload from the given buffer, advancing it.
108    ///
109    /// The returned `PayloadView` provides a structured view of the payload, allowing for efficient
110    /// parsing of nested items without unnecessary allocations.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if:
115    /// - The buffer is too short
116    /// - The header is invalid
117    /// - Any nested headers (for list items) are invalid
118    #[inline]
119    pub fn decode_raw<'a>(buf: &mut &'a [u8]) -> Result<PayloadView<'a>> {
120        let Self { list, payload_length } = Self::decode(buf)?;
121        // SAFETY: this is already checked in `decode`
122        let mut payload = unsafe { advance_unchecked(buf, payload_length) };
123
124        if !list {
125            return Ok(PayloadView::String(payload));
126        }
127
128        let mut items = alloc::vec::Vec::new();
129        while !payload.is_empty() {
130            // store the start of the current item for later slice creation
131            let item_start = payload;
132
133            // decode the header of the next RLP item, advancing the payload
134            let Self { payload_length, .. } = Self::decode(&mut payload)?;
135            // SAFETY: this is already checked in `decode`
136            unsafe { advance_unchecked(&mut payload, payload_length) };
137
138            // calculate the total length of the item (header + payload) by subtracting the
139            // remaining payload length from the initial length
140            let item_length = item_start.len() - payload.len();
141            items.push(&item_start[..item_length]);
142        }
143
144        Ok(PayloadView::List(items))
145    }
146
147    /// Encodes the header into the `out` buffer.
148    #[inline]
149    pub fn encode(&self, out: &mut dyn BufMut) {
150        if self.payload_length < 56 {
151            let code = if self.list { EMPTY_LIST_CODE } else { EMPTY_STRING_CODE };
152            out.put_u8(code + self.payload_length as u8);
153        } else {
154            let len_be;
155            let len_be = crate::encode::to_be_bytes_trimmed!(len_be, self.payload_length);
156            let code = if self.list { 0xF7 } else { 0xB7 };
157            out.put_u8(code + len_be.len() as u8);
158            out.put_slice(len_be);
159        }
160    }
161
162    /// Returns the length of the encoded header.
163    #[inline]
164    pub const fn length(&self) -> usize {
165        crate::length_of_length(self.payload_length)
166    }
167
168    /// Returns the total length of the encoded header and payload.
169    pub const fn length_with_payload(&self) -> usize {
170        self.length() + self.payload_length
171    }
172}
173
/// Borrowed view of a decoded RLP payload.
#[derive(Debug)]
pub enum PayloadView<'a> {
    /// The item is a byte string; holds its payload bytes.
    String(&'a [u8]),
    /// The item is a list; holds one slice of still-encoded RLP data per element.
    List(alloc::vec::Vec<&'a [u8]>),
}
182
183/// Same as `buf.first().ok_or(Error::InputTooShort)`.
184#[inline(always)]
185fn get_next_byte(buf: &[u8]) -> Result<u8> {
186    if buf.is_empty() {
187        return Err(Error::InputTooShort);
188    }
189    // SAFETY: length checked above
190    Ok(*unsafe { buf.get_unchecked(0) })
191}
192
/// Splits the first `cnt` bytes off the front of `buf` and returns them, leaving `buf` at the
/// remainder.
///
/// # Safety
///
/// `cnt` must not exceed `buf.len()`. Violating this is undefined behavior, because the length
/// check below is an optimizer hint rather than a runtime guard.
#[inline(always)]
unsafe fn advance_unchecked<'a>(buf: &mut &'a [u8], cnt: usize) -> &'a [u8] {
    let whole: &'a [u8] = *buf;
    if cnt > whole.len() {
        unreachable_unchecked()
    }
    let (head, tail) = whole.split_at(cnt);
    *buf = tail;
    head
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Encodable;
    use alloc::vec::Vec;
    use core::fmt::Debug;

    /// Encodes `input` as an RLP list and checks that `decode_raw` returns exactly the
    /// individually encoded elements and consumes the whole buffer.
    fn check_decode_raw_list<T>(input: Vec<T>)
    where
        T: Encodable + Debug,
    {
        let encoded = crate::encode(&input);
        let expected: Vec<_> = input.iter().map(crate::encode).collect();
        let mut buf = encoded.as_slice();

        let is_expected_list = match Header::decode_raw(&mut buf) {
            Ok(PayloadView::List(items)) => items == expected,
            _ => false,
        };
        assert!(is_expected_list, "input: {:?}, expected list: {:?}", input, expected);
        assert!(buf.is_empty(), "buffer was not advanced");
    }

    /// Encodes `input` as an RLP string and checks that `decode_raw` returns the same payload
    /// as `decode_bytes` and consumes the whole buffer.
    fn check_decode_raw_string(input: &str) {
        let encoded = crate::encode(input);
        let expected = Header::decode_bytes(&mut &encoded[..], false).unwrap();
        let mut buf = encoded.as_slice();

        let is_expected_string = match Header::decode_raw(&mut buf) {
            Ok(PayloadView::String(bytes)) => bytes == expected,
            _ => false,
        };
        assert!(is_expected_string, "input: {}, expected string: {:?}", input, expected);
        assert!(buf.is_empty(), "buffer was not advanced");
    }

    #[test]
    fn decode_raw() {
        // empty list
        check_decode_raw_list(Vec::<u64>::new());
        // list holding a single empty RLP list
        check_decode_raw_list(vec![Vec::<u64>::new()]);
        // list holding a single empty RLP string
        check_decode_raw_list(vec![""]);
        // list of two RLP strings
        check_decode_raw_list(vec![0xBBCCB5_u64, 0xFFC0B5_u64]);
        // list of three RLP lists of various lengths
        check_decode_raw_list(vec![vec![0u64], vec![1u64, 2u64], vec![3u64, 4u64, 5u64]]);
        // list of four empty RLP strings
        check_decode_raw_list(vec![0u64; 4]);
        // list of the integers 0 up to (but excluding) 0xFF; some get an RLP header and some don't
        check_decode_raw_list((0u64..0xFF).collect());

        // strings of various lengths
        for text in &["", " ", "test1234"] {
            check_decode_raw_string(text);
        }
    }
}