aws_smithy_eventstream/
frame.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Event Stream message frame types and serialization/deserialization logic.
7
8use crate::buf::count::CountBuf;
9use crate::buf::crc::{CrcBuf, CrcBufMut};
10use crate::error::{Error, ErrorKind};
11use aws_smithy_types::config_bag::{Storable, StoreReplace};
12use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
13use aws_smithy_types::str_bytes::StrBytes;
14use aws_smithy_types::DateTime;
15use bytes::{Buf, BufMut};
16use std::error::Error as StdError;
17use std::fmt;
18use std::mem::size_of;
19use std::sync::{mpsc, Mutex};
20
21const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::<u32>() as u32;
22const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize;
23const MESSAGE_CRC_LENGTH_BYTES: u32 = size_of::<u32>() as u32;
24const MAX_HEADER_NAME_LEN: usize = 255;
25const MIN_HEADER_LEN: usize = 2;
26
27pub(crate) const TYPE_TRUE: u8 = 0;
28pub(crate) const TYPE_FALSE: u8 = 1;
29pub(crate) const TYPE_BYTE: u8 = 2;
30pub(crate) const TYPE_INT16: u8 = 3;
31pub(crate) const TYPE_INT32: u8 = 4;
32pub(crate) const TYPE_INT64: u8 = 5;
33pub(crate) const TYPE_BYTE_ARRAY: u8 = 6;
34pub(crate) const TYPE_STRING: u8 = 7;
35pub(crate) const TYPE_TIMESTAMP: u8 = 8;
36pub(crate) const TYPE_UUID: u8 = 9;
37
38pub type SignMessageError = Box<dyn StdError + Send + Sync + 'static>;
39
40/// Signs an Event Stream message.
41pub trait SignMessage: fmt::Debug {
42    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError>;
43
44    /// SigV4 requires an empty last signed message to be sent.
45    /// Other protocols do not require one.
46    /// Return `Some(_)` to send a signed last empty message, before completing the stream.
47    /// Return `None` to not send one and terminate the stream immediately.
48    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>>;
49}
50
51/// A sender that gets placed in the request config to wire up an event stream signer after signing.
52#[derive(Debug)]
53#[non_exhaustive]
54pub struct DeferredSignerSender(Mutex<mpsc::Sender<Box<dyn SignMessage + Send + Sync>>>);
55
56impl DeferredSignerSender {
57    /// Creates a new `DeferredSignerSender`
58    fn new(tx: mpsc::Sender<Box<dyn SignMessage + Send + Sync>>) -> Self {
59        Self(Mutex::new(tx))
60    }
61
62    /// Sends a signer on the channel
63    pub fn send(
64        &self,
65        signer: Box<dyn SignMessage + Send + Sync>,
66    ) -> Result<(), mpsc::SendError<Box<dyn SignMessage + Send + Sync>>> {
67        self.0.lock().unwrap().send(signer)
68    }
69}
70
71impl Storable for DeferredSignerSender {
72    type Storer = StoreReplace<Self>;
73}
74
75/// Deferred event stream signer to allow a signer to be wired up later.
76///
77/// HTTP request signing takes place after serialization, and the event stream
78/// message stream body is established during serialization. Since event stream
79/// signing may need context from the initial HTTP signing operation, this
80/// [`DeferredSigner`] is needed to wire up the signer later in the request lifecycle.
81///
82/// This signer basically just establishes a MPSC channel so that the sender can
83/// be placed in the request's config. Then the HTTP signer implementation can
84/// retrieve the sender from that config and send an actual signing implementation
85/// with all the context needed.
86///
87/// When an event stream implementation needs to sign a message, the first call to
88/// sign will acquire a signing implementation off of the channel and cache it
89/// for the remainder of the operation.
90#[derive(Debug)]
91pub struct DeferredSigner {
92    rx: Option<Mutex<mpsc::Receiver<Box<dyn SignMessage + Send + Sync>>>>,
93    signer: Option<Box<dyn SignMessage + Send + Sync>>,
94}
95
96impl DeferredSigner {
97    pub fn new() -> (Self, DeferredSignerSender) {
98        let (tx, rx) = mpsc::channel();
99        (
100            Self {
101                rx: Some(Mutex::new(rx)),
102                signer: None,
103            },
104            DeferredSignerSender::new(tx),
105        )
106    }
107
108    fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) {
109        // Can't use `if let Some(signer) = &mut self.signer` because the borrow checker isn't smart enough
110        if self.signer.is_some() {
111            return self.signer.as_mut().unwrap().as_mut();
112        } else {
113            self.signer = Some(
114                self.rx
115                    .take()
116                    .expect("only taken once")
117                    .lock()
118                    .unwrap()
119                    .try_recv()
120                    .ok()
121                    // TODO(enableNewSmithyRuntimeCleanup): When the middleware implementation is removed,
122                    // this should panic rather than default to the `NoOpSigner`. The reason it defaults
123                    // is because middleware-based generic clients don't have any default middleware,
124                    // so there is no way to send a `NoOpSigner` by default when there is no other
125                    // auth scheme. The orchestrator auth setup is a lot more robust and will make
126                    // this problem trivial.
127                    .unwrap_or_else(|| Box::new(NoOpSigner {}) as _),
128            );
129            self.acquire()
130        }
131    }
132}
133
134impl SignMessage for DeferredSigner {
135    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
136        self.acquire().sign(message)
137    }
138
139    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
140        self.acquire().sign_empty()
141    }
142}
143
144#[derive(Debug)]
145pub struct NoOpSigner {}
146impl SignMessage for NoOpSigner {
147    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
148        Ok(message)
149    }
150
151    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
152        None
153    }
154}
155
156/// Converts a Smithy modeled Event Stream type into a [`Message`].
157pub trait MarshallMessage: fmt::Debug {
158    /// Smithy modeled input type to convert from.
159    type Input;
160
161    fn marshall(&self, input: Self::Input) -> Result<Message, Error>;
162}
163
164/// A successfully unmarshalled message that is either an `Event` or an `Error`.
165#[derive(Debug)]
166pub enum UnmarshalledMessage<T, E> {
167    Event(T),
168    Error(E),
169}
170
171/// Converts an Event Stream [`Message`] into a Smithy modeled type.
172pub trait UnmarshallMessage: fmt::Debug {
173    /// Smithy modeled type to convert into.
174    type Output;
175    /// Smithy modeled error to convert into.
176    type Error;
177
178    fn unmarshall(
179        &self,
180        message: &Message,
181    ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, Error>;
182}
183
184macro_rules! read_value {
185    ($buf:ident, $typ:ident, $size_typ:ident, $read_fn:ident) => {
186        if $buf.remaining() >= size_of::<$size_typ>() {
187            Ok(HeaderValue::$typ($buf.$read_fn()))
188        } else {
189            Err(ErrorKind::InvalidHeaderValue.into())
190        }
191    };
192}
193
194fn read_header_value_from<B: Buf>(mut buffer: B) -> Result<HeaderValue, Error> {
195    let value_type = buffer.get_u8();
196    match value_type {
197        TYPE_TRUE => Ok(HeaderValue::Bool(true)),
198        TYPE_FALSE => Ok(HeaderValue::Bool(false)),
199        TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8),
200        TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16),
201        TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32),
202        TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64),
203        TYPE_BYTE_ARRAY | TYPE_STRING => {
204            if buffer.remaining() > size_of::<u16>() {
205                let len = buffer.get_u16() as usize;
206                if buffer.remaining() < len {
207                    return Err(ErrorKind::InvalidHeaderValue.into());
208                }
209                let bytes = buffer.copy_to_bytes(len);
210                if value_type == TYPE_STRING {
211                    Ok(HeaderValue::String(
212                        bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?,
213                    ))
214                } else {
215                    Ok(HeaderValue::ByteArray(bytes))
216                }
217            } else {
218                Err(ErrorKind::InvalidHeaderValue.into())
219            }
220        }
221        TYPE_TIMESTAMP => {
222            if buffer.remaining() >= size_of::<i64>() {
223                let epoch_millis = buffer.get_i64();
224                Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis)))
225            } else {
226                Err(ErrorKind::InvalidHeaderValue.into())
227            }
228        }
229        TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128),
230        _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()),
231    }
232}
233
234fn write_header_value_to<B: BufMut>(value: &HeaderValue, mut buffer: B) -> Result<(), Error> {
235    use HeaderValue::*;
236    match value {
237        Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }),
238        Byte(val) => {
239            buffer.put_u8(TYPE_BYTE);
240            buffer.put_i8(*val);
241        }
242        Int16(val) => {
243            buffer.put_u8(TYPE_INT16);
244            buffer.put_i16(*val);
245        }
246        Int32(val) => {
247            buffer.put_u8(TYPE_INT32);
248            buffer.put_i32(*val);
249        }
250        Int64(val) => {
251            buffer.put_u8(TYPE_INT64);
252            buffer.put_i64(*val);
253        }
254        ByteArray(val) => {
255            buffer.put_u8(TYPE_BYTE_ARRAY);
256            buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?);
257            buffer.put_slice(&val[..]);
258        }
259        String(val) => {
260            buffer.put_u8(TYPE_STRING);
261            buffer.put_u16(checked(
262                val.as_bytes().len(),
263                ErrorKind::HeaderValueTooLong.into(),
264            )?);
265            buffer.put_slice(&val.as_bytes()[..]);
266        }
267        Timestamp(time) => {
268            buffer.put_u8(TYPE_TIMESTAMP);
269            buffer.put_i64(
270                time.to_millis()
271                    .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?,
272            );
273        }
274        Uuid(val) => {
275            buffer.put_u8(TYPE_UUID);
276            buffer.put_u128(*val);
277        }
278        _ => {
279            panic!("matched on unexpected variant in `aws_smithy_types::event_stream::HeaderValue`")
280        }
281    }
282    Ok(())
283}
284
285/// Reads a header from the given `buffer`.
286fn read_header_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> {
287    if buffer.remaining() < MIN_HEADER_LEN {
288        return Err(ErrorKind::InvalidHeadersLength.into());
289    }
290
291    let mut counting_buf = CountBuf::new(&mut buffer);
292    let name_len = counting_buf.get_u8();
293    if name_len as usize >= counting_buf.remaining() {
294        return Err(ErrorKind::InvalidHeaderNameLength.into());
295    }
296
297    let name: StrBytes = counting_buf
298        .copy_to_bytes(name_len as usize)
299        .try_into()
300        .map_err(|_| ErrorKind::InvalidUtf8String)?;
301    let value = read_header_value_from(&mut counting_buf)?;
302    Ok((Header::new(name, value), counting_buf.into_count()))
303}
304
305/// Writes the header to the given `buffer`.
306fn write_header_to<B: BufMut>(header: &Header, mut buffer: B) -> Result<(), Error> {
307    if header.name().as_bytes().len() > MAX_HEADER_NAME_LEN {
308        return Err(ErrorKind::InvalidHeaderNameLength.into());
309    }
310
311    buffer.put_u8(u8::try_from(header.name().as_bytes().len()).expect("bounds check above"));
312    buffer.put_slice(&header.name().as_bytes()[..]);
313    write_header_value_to(header.value(), buffer)
314}
315
316/// Writes the given `headers` to a `buffer`.
317pub fn write_headers_to<B: BufMut>(headers: &[Header], mut buffer: B) -> Result<(), Error> {
318    for header in headers {
319        write_header_to(header, &mut buffer)?;
320    }
321    Ok(())
322}
323
324// Returns (total_len, header_len)
325fn read_prelude_from<B: Buf>(mut buffer: B) -> Result<(u32, u32), Error> {
326    let mut crc_buffer = CrcBuf::new(&mut buffer);
327
328    // If the buffer doesn't have the entire, then error
329    let total_len = crc_buffer.get_u32();
330    if crc_buffer.remaining() + size_of::<u32>() < total_len as usize {
331        return Err(ErrorKind::InvalidMessageLength.into());
332    }
333
334    // Validate the prelude
335    let header_len = crc_buffer.get_u32();
336    let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32());
337    if expected_crc != prelude_crc {
338        return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into());
339    }
340    // The header length can be 0 or >= 2, but must fit within the frame size
341    if header_len == 1 || header_len > max_header_len(total_len)? {
342        return Err(ErrorKind::InvalidHeadersLength.into());
343    }
344    Ok((total_len, header_len))
345}
346
347/// Reads a message from the given `buffer`. For streaming use cases, use
348/// the [`MessageFrameDecoder`] instead of this.
349pub fn read_message_from<B: Buf>(mut buffer: B) -> Result<Message, Error> {
350    if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE {
351        return Err(ErrorKind::InvalidMessageLength.into());
352    }
353
354    // Calculate a CRC as we go and read the prelude
355    let mut crc_buffer = CrcBuf::new(&mut buffer);
356    let (total_len, header_len) = read_prelude_from(&mut crc_buffer)?;
357
358    // Verify we have the full frame before continuing
359    let remaining_len = total_len
360        .checked_sub(PRELUDE_LENGTH_BYTES)
361        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
362    if crc_buffer.remaining() < remaining_len as usize {
363        return Err(ErrorKind::InvalidMessageLength.into());
364    }
365
366    // Read headers
367    let mut header_bytes_read = 0;
368    let mut headers = Vec::new();
369    while header_bytes_read < header_len as usize {
370        let (header, bytes_read) = read_header_from(&mut crc_buffer)?;
371        header_bytes_read += bytes_read;
372        if header_bytes_read > header_len as usize {
373            return Err(ErrorKind::InvalidHeaderValue.into());
374        }
375        headers.push(header);
376    }
377
378    // Read payload
379    let payload_len = payload_len(total_len, header_len)?;
380    let payload = crc_buffer.copy_to_bytes(payload_len as usize);
381
382    let expected_crc = crc_buffer.into_crc();
383    let message_crc = buffer.get_u32();
384    if expected_crc != message_crc {
385        return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into());
386    }
387
388    Ok(Message::new_from_parts(headers, payload))
389}
390
391/// Writes the `message` to the given `buffer`.
392pub fn write_message_to(message: &Message, buffer: &mut dyn BufMut) -> Result<(), Error> {
393    let mut headers = Vec::new();
394    for header in message.headers() {
395        write_header_to(header, &mut headers)?;
396    }
397
398    let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?;
399    let payload_len = checked(message.payload().len(), ErrorKind::PayloadTooLong.into())?;
400    let message_len = [
401        PRELUDE_LENGTH_BYTES,
402        headers_len,
403        payload_len,
404        MESSAGE_CRC_LENGTH_BYTES,
405    ]
406    .iter()
407    .try_fold(0u32, |acc, v| {
408        acc.checked_add(*v)
409            .ok_or_else(|| Error::from(ErrorKind::MessageTooLong))
410    })?;
411
412    let mut crc_buffer = CrcBufMut::new(buffer);
413    crc_buffer.put_u32(message_len);
414    crc_buffer.put_u32(headers_len);
415    crc_buffer.put_crc();
416    crc_buffer.put(&headers[..]);
417    crc_buffer.put(&message.payload()[..]);
418    crc_buffer.put_crc();
419    Ok(())
420}
421
422fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> {
423    T::try_from(from).map_err(|_| err)
424}
425
426fn max_header_len(total_len: u32) -> Result<u32, Error> {
427    total_len
428        .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
429        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
430}
431
432fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> {
433    total_len
434        .checked_sub(
435            header_len
436                .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
437                .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?,
438        )
439        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
440}
441
442#[cfg(test)]
443mod message_tests {
444    use super::read_message_from;
445    use crate::error::ErrorKind;
446    use crate::frame::{write_message_to, Header, HeaderValue, Message};
447    use aws_smithy_types::DateTime;
448    use bytes::Bytes;
449
450    macro_rules! read_message_expect_err {
451        ($bytes:expr, $err:pat) => {
452            let result = read_message_from(&mut Bytes::from_static($bytes));
453            let result = result.as_ref();
454            assert!(result.is_err(), "Expected error, got {:?}", result);
455            assert!(
456                matches!(result.err().unwrap().kind(), $err),
457                "Expected {}, got {:?}",
458                stringify!($err),
459                result
460            );
461        };
462    }
463
464    #[test]
465    fn invalid_messages() {
466        read_message_expect_err!(
467            include_bytes!("../test_data/invalid_header_string_value_length"),
468            ErrorKind::InvalidHeaderValue
469        );
470        read_message_expect_err!(
471            include_bytes!("../test_data/invalid_header_string_length_cut_off"),
472            ErrorKind::InvalidHeaderValue
473        );
474        read_message_expect_err!(
475            include_bytes!("../test_data/invalid_header_value_type"),
476            ErrorKind::InvalidHeaderValueType(0x60)
477        );
478        read_message_expect_err!(
479            include_bytes!("../test_data/invalid_header_name_length"),
480            ErrorKind::InvalidHeaderNameLength
481        );
482        read_message_expect_err!(
483            include_bytes!("../test_data/invalid_headers_length"),
484            ErrorKind::InvalidHeadersLength
485        );
486        read_message_expect_err!(
487            include_bytes!("../test_data/invalid_prelude_checksum"),
488            ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
489        );
490        read_message_expect_err!(
491            include_bytes!("../test_data/invalid_message_checksum"),
492            ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
493        );
494        read_message_expect_err!(
495            include_bytes!("../test_data/invalid_header_name_length_too_long"),
496            ErrorKind::InvalidUtf8String
497        );
498    }
499
500    #[test]
501    fn read_message_no_headers() {
502        // Test message taken from the CRT:
503        // https://github.com/awslabs/aws-c-event-stream/blob/main/tests/message_deserializer_test.c
504        let data: &'static [u8] = &[
505            0x00, 0x00, 0x00, 0x1D, 0x00, 0x00, 0x00, 0x00, 0xfd, 0x52, 0x8c, 0x5a, 0x7b, 0x27,
506            0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27, 0x7d, 0xc3, 0x65, 0x39,
507            0x36,
508        ];
509
510        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
511        assert_eq!(result.headers(), Vec::new());
512
513        let expected_payload = b"{'foo':'bar'}";
514        assert_eq!(expected_payload, result.payload().as_ref());
515    }
516
517    #[test]
518    fn read_message_one_header() {
519        // Test message taken from the CRT:
520        // https://github.com/awslabs/aws-c-event-stream/blob/main/tests/message_deserializer_test.c
521        let data: &'static [u8] = &[
522            0x00, 0x00, 0x00, 0x3D, 0x00, 0x00, 0x00, 0x20, 0x07, 0xFD, 0x83, 0x96, 0x0C, b'c',
523            b'o', b'n', b't', b'e', b'n', b't', b'-', b't', b'y', b'p', b'e', 0x07, 0x00, 0x10,
524            b'a', b'p', b'p', b'l', b'i', b'c', b'a', b't', b'i', b'o', b'n', b'/', b'j', b's',
525            b'o', b'n', 0x7b, 0x27, 0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27,
526            0x7d, 0x8D, 0x9C, 0x08, 0xB1,
527        ];
528
529        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
530        assert_eq!(
531            result.headers(),
532            vec![Header::new(
533                "content-type",
534                HeaderValue::String("application/json".into())
535            )]
536        );
537
538        let expected_payload = b"{'foo':'bar'}";
539        assert_eq!(expected_payload, result.payload().as_ref());
540    }
541
542    #[test]
543    fn read_all_headers_and_payload() {
544        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");
545        let result = read_message_from(&mut Bytes::from_static(message)).unwrap();
546        assert_eq!(
547            result.headers(),
548            vec![
549                Header::new("true", HeaderValue::Bool(true)),
550                Header::new("false", HeaderValue::Bool(false)),
551                Header::new("byte", HeaderValue::Byte(50)),
552                Header::new("short", HeaderValue::Int16(20_000)),
553                Header::new("int", HeaderValue::Int32(500_000)),
554                Header::new("long", HeaderValue::Int64(50_000_000_000)),
555                Header::new(
556                    "bytes",
557                    HeaderValue::ByteArray(Bytes::from(&b"some bytes"[..]))
558                ),
559                Header::new("str", HeaderValue::String("some str".into())),
560                Header::new(
561                    "time",
562                    HeaderValue::Timestamp(DateTime::from_secs(5_000_000))
563                ),
564                Header::new(
565                    "uuid",
566                    HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b)
567                ),
568            ]
569        );
570
571        assert_eq!(b"some payload", result.payload().as_ref());
572    }
573
574    #[test]
575    fn round_trip_all_headers_payload() {
576        let message = Message::new(&b"some payload"[..])
577            .add_header(Header::new("true", HeaderValue::Bool(true)))
578            .add_header(Header::new("false", HeaderValue::Bool(false)))
579            .add_header(Header::new("byte", HeaderValue::Byte(50)))
580            .add_header(Header::new("short", HeaderValue::Int16(20_000)))
581            .add_header(Header::new("int", HeaderValue::Int32(500_000)))
582            .add_header(Header::new("long", HeaderValue::Int64(50_000_000_000)))
583            .add_header(Header::new(
584                "bytes",
585                HeaderValue::ByteArray((&b"some bytes"[..]).into()),
586            ))
587            .add_header(Header::new("str", HeaderValue::String("some str".into())))
588            .add_header(Header::new(
589                "time",
590                HeaderValue::Timestamp(DateTime::from_secs(5_000_000)),
591            ))
592            .add_header(Header::new(
593                "uuid",
594                HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b),
595            ));
596
597        let mut actual = Vec::new();
598        write_message_to(&message, &mut actual).unwrap();
599
600        let expected = include_bytes!("../test_data/valid_with_all_headers_and_payload").to_vec();
601        assert_eq!(expected, actual);
602
603        let result = read_message_from(&mut Bytes::from(actual)).unwrap();
604        assert_eq!(message.headers(), result.headers());
605        assert_eq!(message.payload().as_ref(), result.payload().as_ref());
606    }
607}
608
609/// Return value from [`MessageFrameDecoder`].
610#[derive(Debug)]
611pub enum DecodedFrame {
612    /// There wasn't enough data in the buffer to decode a full message.
613    Incomplete,
614    /// There was enough data in the buffer to decode.
615    Complete(Message),
616}
617
618/// Streaming decoder for decoding a [`Message`] from a stream.
619#[non_exhaustive]
620#[derive(Default, Debug)]
621pub struct MessageFrameDecoder {
622    prelude: [u8; PRELUDE_LENGTH_BYTES_USIZE],
623    prelude_read: bool,
624}
625
626impl MessageFrameDecoder {
627    /// Returns a new `MessageFrameDecoder`.
628    pub fn new() -> Self {
629        Default::default()
630    }
631
632    /// Determines if the `buffer` has enough data in it to read a full frame.
633    /// Returns `Ok(None)` if there's not enough data, or `Some(remaining)` where
634    /// `remaining` is the number of bytes after the prelude that belong to the
635    /// message that's in the buffer.
636    fn remaining_bytes_if_frame_available<B: Buf>(
637        &self,
638        buffer: &B,
639    ) -> Result<Option<usize>, Error> {
640        if self.prelude_read {
641            let remaining_len = (&self.prelude[..])
642                .get_u32()
643                .checked_sub(PRELUDE_LENGTH_BYTES)
644                .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
645            if buffer.remaining() >= remaining_len as usize {
646                return Ok(Some(remaining_len as usize));
647            }
648        }
649        Ok(None)
650    }
651
652    /// Resets the decoder.
653    fn reset(&mut self) {
654        self.prelude_read = false;
655        self.prelude = [0u8; PRELUDE_LENGTH_BYTES_USIZE];
656    }
657
658    /// Attempts to decode a [`Message`] from the given `buffer`. This function expects
659    /// to be called over and over again with more data in the buffer each time its called.
660    /// When there's not enough data to decode a message, it returns `Ok(None)`.
661    ///
662    /// Once there is enough data to read a message prelude, then it will mutate the `Buf`
663    /// position. The state from the reading of the prelude is stored in the decoder so that
664    /// the next call will be able to decode the entire message, even though the prelude
665    /// is no longer available in the `Buf`.
666    pub fn decode_frame<B: Buf>(&mut self, mut buffer: B) -> Result<DecodedFrame, Error> {
667        if !self.prelude_read && buffer.remaining() >= PRELUDE_LENGTH_BYTES_USIZE {
668            buffer.copy_to_slice(&mut self.prelude);
669            self.prelude_read = true;
670        }
671
672        if let Some(remaining_len) = self.remaining_bytes_if_frame_available(&buffer)? {
673            let mut message_buf = (&self.prelude[..]).chain(buffer.take(remaining_len));
674            let result = read_message_from(&mut message_buf).map(DecodedFrame::Complete);
675            self.reset();
676            return result;
677        }
678
679        Ok(DecodedFrame::Incomplete)
680    }
681}
682
683#[cfg(test)]
684mod message_frame_decoder_tests {
685    use super::{DecodedFrame, MessageFrameDecoder};
686    use crate::frame::read_message_from;
687    use bytes::Bytes;
688    use bytes_utils::SegmentedBuf;
689
690    #[test]
691    fn single_streaming_message() {
692        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");
693
694        let mut decoder = MessageFrameDecoder::new();
695        let mut segmented = SegmentedBuf::new();
696        for i in 0..(message.len() - 1) {
697            segmented.push(&message[i..(i + 1)]);
698            if let DecodedFrame::Complete(_) = decoder.decode_frame(&mut segmented).unwrap() {
699                panic!("incomplete frame shouldn't result in message");
700            }
701        }
702
703        segmented.push(&message[(message.len() - 1)..]);
704        match decoder.decode_frame(&mut segmented).unwrap() {
705            DecodedFrame::Incomplete => panic!("frame should be complete now"),
706            DecodedFrame::Complete(actual) => {
707                let expected = read_message_from(&mut Bytes::from_static(message)).unwrap();
708                assert_eq!(expected, actual);
709            }
710        }
711    }
712
713    fn multiple_streaming_messages_chunk_size(chunk_size: usize) {
714        let message1 = include_bytes!("../test_data/valid_with_all_headers_and_payload");
715        let message2 = include_bytes!("../test_data/valid_empty_payload");
716        let message3 = include_bytes!("../test_data/valid_no_headers");
717        let mut repeated = message1.to_vec();
718        repeated.extend_from_slice(message2);
719        repeated.extend_from_slice(message3);
720
721        let mut decoder = MessageFrameDecoder::new();
722        let mut segmented = SegmentedBuf::new();
723        let mut decoded = Vec::new();
724        for window in repeated.chunks(chunk_size) {
725            segmented.push(window);
726            match dbg!(decoder.decode_frame(&mut segmented)).unwrap() {
727                DecodedFrame::Incomplete => {}
728                DecodedFrame::Complete(message) => {
729                    decoded.push(message);
730                }
731            }
732        }
733
734        let expected1 = read_message_from(&mut Bytes::from_static(message1)).unwrap();
735        let expected2 = read_message_from(&mut Bytes::from_static(message2)).unwrap();
736        let expected3 = read_message_from(&mut Bytes::from_static(message3)).unwrap();
737        assert_eq!(3, decoded.len());
738        assert_eq!(expected1, decoded[0]);
739        assert_eq!(expected2, decoded[1]);
740        assert_eq!(expected3, decoded[2]);
741    }
742
743    #[test]
744    fn multiple_streaming_messages() {
745        for chunk_size in 1..=11 {
746            println!("chunk size: {}", chunk_size);
747            multiple_streaming_messages_chunk_size(chunk_size);
748        }
749    }
750}
751
752#[cfg(test)]
753mod deferred_signer_tests {
754    use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage};
755    use bytes::Bytes;
756
757    fn check_send_sync<T: Send + Sync>(value: T) -> T {
758        value
759    }
760
761    #[test]
762    fn deferred_signer() {
763        #[derive(Default, Debug)]
764        struct TestSigner {
765            call_num: i32,
766        }
767        impl SignMessage for TestSigner {
768            fn sign(
769                &mut self,
770                message: Message,
771            ) -> Result<Message, crate::frame::SignMessageError> {
772                self.call_num += 1;
773                Ok(message.add_header(Header::new("call_num", HeaderValue::Int32(self.call_num))))
774            }
775
776            fn sign_empty(&mut self) -> Option<Result<Message, crate::frame::SignMessageError>> {
777                None
778            }
779        }
780
781        let (mut signer, sender) = check_send_sync(DeferredSigner::new());
782
783        sender.send(Box::<TestSigner>::default()).expect("success");
784
785        let message = signer.sign(Message::new(Bytes::new())).expect("success");
786        assert_eq!(1, message.headers()[0].value().as_int32().unwrap());
787
788        let message = signer.sign(Message::new(Bytes::new())).expect("success");
789        assert_eq!(2, message.headers()[0].value().as_int32().unwrap());
790
791        assert!(signer.sign_empty().is_none());
792    }
793
794    #[test]
795    fn deferred_signer_defaults_to_noop_signer() {
796        let (mut signer, _sender) = DeferredSigner::new();
797        assert_eq!(
798            Message::new(Bytes::new()),
799            signer.sign(Message::new(Bytes::new())).unwrap()
800        );
801        assert!(signer.sign_empty().is_none());
802    }
803}