rustls/msgs/
codec.rs

1use std::fmt::Debug;
2
3use crate::error::InvalidMessage;
4
5/// Wrapper over a slice of bytes that allows reading chunks from
6/// with the current position state held using a cursor.
7///
8/// A new reader for a sub section of the the buffer can be created
9/// using the `sub` function or a section of a certain length can
10/// be obtained using the `take` function
11pub struct Reader<'a> {
12    /// The underlying buffer storing the readers content
13    buffer: &'a [u8],
14    /// Stores the current reading position for the buffer
15    cursor: usize,
16}
17
18impl<'a> Reader<'a> {
19    /// Creates a new Reader of the provided `bytes` slice with
20    /// the initial cursor position of zero.
21    pub fn init(bytes: &[u8]) -> Reader {
22        Reader {
23            buffer: bytes,
24            cursor: 0,
25        }
26    }
27
28    /// Attempts to create a new Reader on a sub section of this
29    /// readers bytes by taking a slice of the provided `length`
30    /// will return None if there is not enough bytes
31    pub fn sub(&mut self, length: usize) -> Result<Reader, InvalidMessage> {
32        match self.take(length) {
33            Some(bytes) => Ok(Reader::init(bytes)),
34            None => Err(InvalidMessage::MessageTooShort),
35        }
36    }
37
38    /// Borrows a slice of all the remaining bytes
39    /// that appear after the cursor position.
40    ///
41    /// Moves the cursor to the end of the buffer length.
42    pub fn rest(&mut self) -> &[u8] {
43        let rest = &self.buffer[self.cursor..];
44        self.cursor = self.buffer.len();
45        rest
46    }
47
48    /// Attempts to borrow a slice of bytes from the current
49    /// cursor position of `length` if there is not enough
50    /// bytes remaining after the cursor to take the length
51    /// then None is returned instead.
52    pub fn take(&mut self, length: usize) -> Option<&[u8]> {
53        if self.left() < length {
54            return None;
55        }
56        let current = self.cursor;
57        self.cursor += length;
58        Some(&self.buffer[current..current + length])
59    }
60
61    /// Used to check whether the reader has any content left
62    /// after the cursor (cursor has not reached end of buffer)
63    pub fn any_left(&self) -> bool {
64        self.cursor < self.buffer.len()
65    }
66
67    pub fn expect_empty(&self, name: &'static str) -> Result<(), InvalidMessage> {
68        match self.any_left() {
69            true => Err(InvalidMessage::TrailingData(name)),
70            false => Ok(()),
71        }
72    }
73
74    /// Returns the cursor position which is also the number
75    /// of bytes that have been read from the buffer.
76    pub fn used(&self) -> usize {
77        self.cursor
78    }
79
80    /// Returns the number of bytes that are still able to be
81    /// read (The number of remaining takes)
82    pub fn left(&self) -> usize {
83        self.buffer.len() - self.cursor
84    }
85}
86
87/// Trait for implementing encoding and decoding functionality
88/// on something.
89pub trait Codec: Debug + Sized {
90    /// Function for encoding itself by appending itself to
91    /// the provided vec of bytes.
92    fn encode(&self, bytes: &mut Vec<u8>);
93
94    /// Function for decoding itself from the provided reader
95    /// will return Some if the decoding was successful or
96    /// None if it was not.
97    fn read(_: &mut Reader) -> Result<Self, InvalidMessage>;
98
99    /// Convenience function for encoding the implementation
100    /// into a vec and returning it
101    fn get_encoding(&self) -> Vec<u8> {
102        let mut bytes = Vec::new();
103        self.encode(&mut bytes);
104        bytes
105    }
106
107    /// Function for wrapping a call to the read function in
108    /// a Reader for the slice of bytes provided
109    fn read_bytes(bytes: &[u8]) -> Result<Self, InvalidMessage> {
110        let mut reader = Reader::init(bytes);
111        Self::read(&mut reader)
112    }
113}
114
115impl Codec for u8 {
116    fn encode(&self, bytes: &mut Vec<u8>) {
117        bytes.push(*self);
118    }
119
120    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
121        match r.take(1) {
122            Some(&[byte]) => Ok(byte),
123            _ => Err(InvalidMessage::MissingData("u8")),
124        }
125    }
126}
127
/// Writes `v` in big-endian byte order into the first two bytes of `out`.
///
/// Bytes of `out` beyond the first two are left untouched.
///
/// # Panics
/// Panics if `out` is shorter than two bytes.
pub fn put_u16(v: u16, out: &mut [u8]) {
    out[..2].copy_from_slice(&v.to_be_bytes());
}
132
133impl Codec for u16 {
134    fn encode(&self, bytes: &mut Vec<u8>) {
135        let mut b16 = [0u8; 2];
136        put_u16(*self, &mut b16);
137        bytes.extend_from_slice(&b16);
138    }
139
140    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
141        match r.take(2) {
142            Some(&[b1, b2]) => Ok(Self::from_be_bytes([b1, b2])),
143            _ => Err(InvalidMessage::MissingData("u8")),
144        }
145    }
146}
147
/// A 24-bit unsigned integer, represented by a `u32`.
///
/// Kept as a distinct type (rather than a bare `u32`) so it can carry its own
/// three-byte encoding. The wrapped value is expected to fit in 24 bits; this
/// is not checked here.
#[allow(non_camel_case_types)] // named to read like the primitive integer types
#[derive(Clone, Copy, Debug)]
pub struct u24(pub u32);
152
153#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
154impl From<u24> for usize {
155    #[inline]
156    fn from(v: u24) -> Self {
157        v.0 as Self
158    }
159}
160
161impl Codec for u24 {
162    fn encode(&self, bytes: &mut Vec<u8>) {
163        let be_bytes = u32::to_be_bytes(self.0);
164        bytes.extend_from_slice(&be_bytes[1..]);
165    }
166
167    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
168        match r.take(3) {
169            Some(&[a, b, c]) => Ok(Self(u32::from_be_bytes([0, a, b, c]))),
170            _ => Err(InvalidMessage::MissingData("u24")),
171        }
172    }
173}
174
175impl Codec for u32 {
176    fn encode(&self, bytes: &mut Vec<u8>) {
177        bytes.extend(Self::to_be_bytes(*self));
178    }
179
180    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
181        match r.take(4) {
182            Some(&[a, b, c, d]) => Ok(Self::from_be_bytes([a, b, c, d])),
183            _ => Err(InvalidMessage::MissingData("u32")),
184        }
185    }
186}
187
/// Writes `v` in big-endian byte order into the first eight bytes of `bytes`.
///
/// Bytes beyond the first eight are left untouched.
///
/// # Panics
/// Panics if `bytes` is shorter than eight bytes.
pub fn put_u64(v: u64, bytes: &mut [u8]) {
    let encoded = v.to_be_bytes();
    bytes[..encoded.len()].copy_from_slice(&encoded);
}
192
193impl Codec for u64 {
194    fn encode(&self, bytes: &mut Vec<u8>) {
195        let mut b64 = [0u8; 8];
196        put_u64(*self, &mut b64);
197        bytes.extend_from_slice(&b64);
198    }
199
200    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
201        match r.take(8) {
202            Some(&[a, b, c, d, e, f, g, h]) => Ok(Self::from_be_bytes([a, b, c, d, e, f, g, h])),
203            _ => Err(InvalidMessage::MissingData("u64")),
204        }
205    }
206}
207
208/// Implement `Codec` for lists of elements that implement `TlsListElement`.
209///
210/// `TlsListElement` provides the size of the length prefix for the list.
211impl<T: Codec + TlsListElement + Debug> Codec for Vec<T> {
212    fn encode(&self, bytes: &mut Vec<u8>) {
213        let len_offset = bytes.len();
214        bytes.extend(match T::SIZE_LEN {
215            ListLength::U8 => &[0][..],
216            ListLength::U16 => &[0, 0],
217            ListLength::U24 { .. } => &[0, 0, 0],
218        });
219
220        for i in self {
221            i.encode(bytes);
222        }
223
224        match T::SIZE_LEN {
225            ListLength::U8 => {
226                let len = bytes.len() - len_offset - 1;
227                debug_assert!(len <= 0xff);
228                bytes[len_offset] = len as u8;
229            }
230            ListLength::U16 => {
231                let len = bytes.len() - len_offset - 2;
232                debug_assert!(len <= 0xffff);
233                let out: &mut [u8; 2] = (&mut bytes[len_offset..len_offset + 2])
234                    .try_into()
235                    .unwrap();
236                *out = u16::to_be_bytes(len as u16);
237            }
238            ListLength::U24 { .. } => {
239                let len = bytes.len() - len_offset - 3;
240                debug_assert!(len <= 0xff_ffff);
241                let len_bytes = u32::to_be_bytes(len as u32);
242                let out: &mut [u8; 3] = (&mut bytes[len_offset..len_offset + 3])
243                    .try_into()
244                    .unwrap();
245                out.copy_from_slice(&len_bytes[1..]);
246            }
247        }
248    }
249
250    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
251        let len = match T::SIZE_LEN {
252            ListLength::U8 => usize::from(u8::read(r)?),
253            ListLength::U16 => usize::from(u16::read(r)?),
254            ListLength::U24 { max } => Ord::min(usize::from(u24::read(r)?), max),
255        };
256
257        let mut sub = r.sub(len)?;
258        let mut ret = Self::new();
259        while sub.any_left() {
260            ret.push(T::read(&mut sub)?);
261        }
262
263        Ok(ret)
264    }
265}
266
267/// A trait for types that can be encoded and decoded in a list.
268///
269/// This trait is used to implement `Codec` for `Vec<T>`. Lists in the TLS wire format are
270/// prefixed with a length, the size of which depends on the type of the list elements.
271/// As such, the `Codec` implementation for `Vec<T>` requires an implementation of this trait
272/// for its element type `T`.
273///
274// TODO: make this `pub(crate)` once our MSRV allows it?
275pub trait TlsListElement {
276    const SIZE_LEN: ListLength;
277}
278
/// The width of the length prefix used for a TLS list.
///
/// List element types use one of three prefix widths: 1, 2 or 3 bytes.
/// For the 3-byte kind, a `TlsListElement` implementer must also specify a
/// maximum length.
// TODO: make this `pub(crate)` once our MSRV allows it?
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ListLength {
    /// One-byte length prefix.
    U8,
    /// Two-byte, big-endian length prefix.
    U16,
    /// Three-byte, big-endian length prefix; `max` bounds the decoded length.
    U24 { max: usize },
}
288    U16,
289    U24 { max: usize },
290}