aws_smithy_http/event_stream/
receiver.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
6use aws_smithy_eventstream::frame::{
7    DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21/// Wrapper around SegmentedBuf that tracks the state of the stream.
22#[derive(Debug)]
23enum RecvBuf {
24    /// Nothing has been buffered yet.
25    Empty,
26    /// Some data has been buffered.
27    /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary.
28    Partial(SegmentedBuf<Bytes>),
29    /// The end of the stream has been reached, but there may still be some buffered data.
30    EosPartial(SegmentedBuf<Bytes>),
31    /// An exception terminated this stream.
32    Terminated,
33}
34
35impl RecvBuf {
36    /// Returns true if there's more buffered data.
37    fn has_data(&self) -> bool {
38        match self {
39            RecvBuf::Empty | RecvBuf::Terminated => false,
40            RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41        }
42    }
43
44    /// Returns true if the stream has ended.
45    fn is_eos(&self) -> bool {
46        matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47    }
48
49    /// Returns a mutable reference to the underlying buffered data.
50    fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51        match self {
52            RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53            RecvBuf::Partial(segmented) => segmented,
54            RecvBuf::EosPartial(segmented) => segmented,
55            RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56        }
57    }
58
59    /// Returns a new `RecvBuf` with additional data buffered. This will only allocate
60    /// if the `RecvBuf` was previously empty.
61    fn with_partial(self, partial: Bytes) -> Self {
62        match self {
63            RecvBuf::Empty => {
64                let mut segmented = SegmentedBuf::new();
65                segmented.push(partial);
66                RecvBuf::Partial(segmented)
67            }
68            RecvBuf::Partial(mut segmented) => {
69                segmented.push(partial);
70                RecvBuf::Partial(segmented)
71            }
72            RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73                panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74            }
75        }
76    }
77
78    /// Returns a `RecvBuf` that has reached end of stream.
79    fn ended(self) -> Self {
80        match self {
81            RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82            RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83            RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84            RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85        }
86    }
87}
88
/// The specific reason a [`ReceiverError`] was raised.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The stream finished while a message frame was still incomplete.
    UnexpectedEndOfStream,
}

/// An error raised by the event stream receiver itself (as opposed to the
/// transport or the service).
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Resolve the human-readable text first, then emit it in one place.
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
/// Receives Smithy-modeled messages out of an Event Stream.
///
/// `T` is the modeled event type produced by the unmarshaller and `E` is the
/// modeled error type it can report.
#[derive(Debug)]
pub struct Receiver<T, E> {
    /// Turns decoded messages into modeled events (`T`) or modeled errors (`E`).
    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
    /// Decodes message frames out of the buffered bytes.
    decoder: MessageFrameDecoder,
    /// Bytes read from `body` that have not been decoded yet, together with
    /// the end-of-stream / terminated state of the stream.
    buffer: RecvBuf,
    /// The body that chunks of the event stream are read from.
    body: SdkBody,
    /// An Event Stream may start with an optional initial message, identified by an
    /// `:event-type` header of `initial-request` or `initial-response`. If
    /// `try_recv_initial()` is called and the next message isn't the requested initial
    /// message, then the message will be stored in `buffered_message` so that it can
    /// be returned with the next call of `recv()`.
    buffered_message: Option<Message>,
    /// Marker for the modeled error type `E`.
    _phantom: PhantomData<E>,
}
125
// Used by `Receiver::try_recv_initial`, hence this enum is also doc hidden
/// Selects which kind of initial message `Receiver::try_recv_initial` should look for.
#[doc(hidden)]
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialMessageType {
    /// The initial message of a request stream (`initial-request`).
    Request,
    /// The initial message of a response stream (`initial-response`).
    Response,
}

impl InitialMessageType {
    /// Returns the `:event-type` header value that identifies this kind of
    /// initial message on the wire.
    fn as_str(&self) -> &'static str {
        match self {
            InitialMessageType::Request => "initial-request",
            InitialMessageType::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144    /// Creates a new `Receiver` with the given message unmarshaller and SDK body.
145    pub fn new(
146        unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147        body: SdkBody,
148    ) -> Self {
149        Receiver {
150            unmarshaller: Box::new(unmarshaller),
151            decoder: MessageFrameDecoder::new(),
152            buffer: RecvBuf::Empty,
153            body,
154            buffered_message: None,
155            _phantom: Default::default(),
156        }
157    }
158
159    fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160        match self.unmarshaller.unmarshall(&message) {
161            Ok(unmarshalled) => match unmarshalled {
162                UnmarshalledMessage::Event(event) => Ok(Some(event)),
163                UnmarshalledMessage::Error(err) => {
164                    Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165                }
166            },
167            Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168        }
169    }
170
171    async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172        use http_body_04x::Body;
173
174        if !self.buffer.is_eos() {
175            let next_chunk = self
176                .body
177                .data()
178                .await
179                .transpose()
180                .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181            let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182            if let Some(chunk) = next_chunk {
183                self.buffer = buffer.with_partial(chunk);
184            } else {
185                self.buffer = buffer.ended();
186            }
187        }
188        Ok(())
189    }
190
191    async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
192        while !self.buffer.is_eos() {
193            if self.buffer.has_data() {
194                if let DecodedFrame::Complete(message) = self
195                    .decoder
196                    .decode_frame(self.buffer.buffered())
197                    .map_err(|err| {
198                        SdkError::response_error(
199                            err,
200                            // the buffer has been consumed
201                            RawMessage::Invalid(None),
202                        )
203                    })?
204                {
205                    trace!(message = ?message, "received complete event stream message");
206                    return Ok(Some(message));
207                }
208            }
209
210            self.buffer_next_chunk().await?;
211        }
212        if self.buffer.has_data() {
213            trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
214            let buf = self.buffer.buffered();
215            return Err(SdkError::response_error(
216                ReceiverError {
217                    kind: ReceiverErrorKind::UnexpectedEndOfStream,
218                },
219                RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
220            ));
221        }
222        Ok(None)
223    }
224
225    /// Tries to receive the initial response message that has `:event-type` of a given `message_type`.
226    /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
227    #[doc(hidden)]
228    pub async fn try_recv_initial(
229        &mut self,
230        message_type: InitialMessageType,
231    ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
232        if let Some(message) = self.next_message().await? {
233            if let Some(event_type) = message
234                .headers()
235                .iter()
236                .find(|h| h.name().as_str() == ":event-type")
237            {
238                if event_type
239                    .value()
240                    .as_string()
241                    .map(|s| s.as_str() == message_type.as_str())
242                    .unwrap_or(false)
243                {
244                    return Ok(Some(message));
245                }
246            }
247            // Buffer the message so that it can be returned by the next call to `recv()`
248            self.buffered_message = Some(message);
249        }
250        Ok(None)
251    }
252
253    /// Asynchronously tries to receive a message from the stream. If the stream has ended,
254    /// it returns an `Ok(None)`. If there is a transport layer error, it will return
255    /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
256    /// messages.
257    pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
258        if let Some(buffered) = self.buffered_message.take() {
259            return match self.unmarshall(buffered) {
260                Ok(message) => Ok(message),
261                Err(error) => {
262                    self.buffer = RecvBuf::Terminated;
263                    Err(error)
264                }
265            };
266        }
267        if let Some(message) = self.next_message().await? {
268            match self.unmarshall(message) {
269                Ok(message) => Ok(message),
270                Err(error) => {
271                    self.buffer = RecvBuf::Terminated;
272                    Err(error)
273                }
274            }
275        } else {
276            Ok(None)
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use hyper::body::Body;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    /// The receiver type exercised by every test in this module.
    type TestReceiver = Receiver<TestMessage, EventStreamError>;

    /// Encodes an empty-payload message tagged as an `initial-response` event.
    fn encode_initial_response() -> Bytes {
        let message_type = Header::new(":message-type", HeaderValue::String("event".into()));
        let event_type = Header::new(
            ":event-type",
            HeaderValue::String("initial-response".into()),
        );
        let message = Message::new(Bytes::new())
            .add_header(message_type)
            .add_header(event_type);

        let mut encoded = Vec::new();
        write_message_to(&message, &mut encoded).unwrap();
        Bytes::from(encoded)
    }

    /// Encodes a header-less message whose payload is the given text.
    fn encode_message(message: &str) -> Bytes {
        let payload = Bytes::copy_from_slice(message.as_bytes());
        let mut encoded = Vec::new();
        write_message_to(&Message::new(payload), &mut encoded).unwrap();
        Bytes::from(encoded)
    }

    /// Builds a receiver whose body yields the given chunks in order.
    fn receiver_over(chunks: Vec<Result<Bytes, IOError>>) -> TestReceiver {
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        Receiver::new(Unmarshaller, body)
    }

    /// Asserts that the next received event carries `expected` as its payload.
    async fn assert_next_payload(receiver: &mut TestReceiver, expected: &str) {
        assert_eq!(
            TestMessage(expected.into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    /// Asserts that the receiver reports a clean end of stream.
    async fn assert_end_of_stream(receiver: &mut TestReceiver) {
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[derive(Debug)]
    struct FakeError;
    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "FakeError")
        }
    }
    impl StdError for FakeError {}

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Unmarshaller that turns every message payload into a `TestMessage`.
    #[derive(Debug)]
    struct Unmarshaller;
    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            let payload = std::str::from_utf8(&message.payload()[..]).unwrap();
            Ok(UnmarshalledMessage::Event(TestMessage(payload.into())))
        }
    }

    #[tokio::test]
    async fn receive_success() {
        let mut receiver =
            receiver_over(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        assert_next_payload(&mut receiver, "one").await;
        assert_next_payload(&mut receiver, "two").await;
        assert_end_of_stream(&mut receiver).await;
    }

    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ]);
        assert_next_payload(&mut receiver, "one").await;
        assert_next_payload(&mut receiver, "two").await;
        assert_end_of_stream(&mut receiver).await;
    }

    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ]);
        assert_next_payload(&mut receiver, "one").await;
        assert_next_payload(&mut receiver, "two").await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let last_chunk = Bytes::from([encode_message("three"), encode_message("four")].concat());
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(last_chunk),
        ]);
        for expected in ["one", "two", "three", "four"].iter() {
            assert_next_payload(&mut receiver, *expected).await;
        }
        assert_end_of_stream(&mut receiver).await;
    }

    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            const PAYLOADS: [&str; 8] =
                ["one", "two", "three", "four", "five", "six", "seven", "eight"];

            let encoded: Vec<Bytes> = PAYLOADS.iter().map(|p| encode_message(p)).collect();
            let combined = Bytes::from(encoded.concat());

            // One split point in each half of the combined bytes.
            let midpoint = combined.len() / 2;
            let start = 0;
            let boundary1 = b1 % midpoint;
            let boundary2 = midpoint + b2 % midpoint;
            let end = combined.len();
            println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end);

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let chunks: Vec<Result<Bytes, IOError>> =
                    [(start, boundary1), (boundary1, boundary2), (boundary2, end)]
                        .iter()
                        .map(|&(from, to)| Ok(Bytes::copy_from_slice(&combined[from..to])))
                        .collect();

                let mut receiver = receiver_over(chunks);
                for payload in PAYLOADS.iter() {
                    assert_next_payload(&mut receiver, *payload).await;
                }
                assert_end_of_stream(&mut receiver).await;
            });
        }
    }

    #[tokio::test]
    async fn receive_network_failure() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ]);
        assert_next_payload(&mut receiver, "one").await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    #[tokio::test]
    async fn receive_message_parse_failure() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            // A zero length message will be invalid. We need to provide a minimum of 12 bytes
            // for the MessageFrameDecoder to actually start parsing it.
            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
        ]);
        assert_next_payload(&mut receiver, "one").await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    #[tokio::test]
    async fn receive_initial_response() {
        let mut receiver = receiver_over(vec![
            Ok(encode_initial_response()),
            Ok(encode_message("one")),
        ]);
        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_some());
        assert_next_payload(&mut receiver, "one").await;
    }

    #[tokio::test]
    async fn receive_no_initial_response() {
        let mut receiver =
            receiver_over(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_none());
        assert_next_payload(&mut receiver, "one").await;
        assert_next_payload(&mut receiver, "two").await;
    }

    fn assert_send_and_sync<T: Send + Sync>() {}

    #[tokio::test]
    async fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}