1use aws_smithy_eventstream::frame::{
7 DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21#[derive(Debug)]
23enum RecvBuf {
24 Empty,
26 Partial(SegmentedBuf<Bytes>),
29 EosPartial(SegmentedBuf<Bytes>),
31 Terminated,
33}
34
35impl RecvBuf {
36 fn has_data(&self) -> bool {
38 match self {
39 RecvBuf::Empty | RecvBuf::Terminated => false,
40 RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41 }
42 }
43
44 fn is_eos(&self) -> bool {
46 matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47 }
48
49 fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51 match self {
52 RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53 RecvBuf::Partial(segmented) => segmented,
54 RecvBuf::EosPartial(segmented) => segmented,
55 RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56 }
57 }
58
59 fn with_partial(self, partial: Bytes) -> Self {
62 match self {
63 RecvBuf::Empty => {
64 let mut segmented = SegmentedBuf::new();
65 segmented.push(partial);
66 RecvBuf::Partial(segmented)
67 }
68 RecvBuf::Partial(mut segmented) => {
69 segmented.push(partial);
70 RecvBuf::Partial(segmented)
71 }
72 RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73 panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74 }
75 }
76 }
77
78 fn ended(self) -> Self {
80 match self {
81 RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82 RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83 RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84 RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85 }
86 }
87}
88
/// Failure reasons that originate in the receiver itself rather than in the service.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The body ended while undecoded bytes were still buffered.
    UnexpectedEndOfStream,
}

/// Error produced by the event stream receiver while reading the response body.
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
111#[derive(Debug)]
113pub struct Receiver<T, E> {
114 unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
115 decoder: MessageFrameDecoder,
116 buffer: RecvBuf,
117 body: SdkBody,
118 buffered_message: Option<Message>,
123 _phantom: PhantomData<E>,
124}
125
/// Selects which initial message `Receiver::try_recv_initial` looks for.
#[doc(hidden)]
#[non_exhaustive]
pub enum InitialMessageType {
    /// A message whose `:event-type` header is `initial-request`.
    Request,
    /// A message whose `:event-type` header is `initial-response`.
    Response,
}

impl InitialMessageType {
    /// The `:event-type` header value that identifies this kind of initial message.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Request => "initial-request",
            Self::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144 pub fn new(
146 unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147 body: SdkBody,
148 ) -> Self {
149 Receiver {
150 unmarshaller: Box::new(unmarshaller),
151 decoder: MessageFrameDecoder::new(),
152 buffer: RecvBuf::Empty,
153 body,
154 buffered_message: None,
155 _phantom: Default::default(),
156 }
157 }
158
159 fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160 match self.unmarshaller.unmarshall(&message) {
161 Ok(unmarshalled) => match unmarshalled {
162 UnmarshalledMessage::Event(event) => Ok(Some(event)),
163 UnmarshalledMessage::Error(err) => {
164 Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165 }
166 },
167 Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168 }
169 }
170
171 async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172 use http_body_04x::Body;
173
174 if !self.buffer.is_eos() {
175 let next_chunk = self
176 .body
177 .data()
178 .await
179 .transpose()
180 .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181 let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182 if let Some(chunk) = next_chunk {
183 self.buffer = buffer.with_partial(chunk);
184 } else {
185 self.buffer = buffer.ended();
186 }
187 }
188 Ok(())
189 }
190
191 async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
192 while !self.buffer.is_eos() {
193 if self.buffer.has_data() {
194 if let DecodedFrame::Complete(message) = self
195 .decoder
196 .decode_frame(self.buffer.buffered())
197 .map_err(|err| {
198 SdkError::response_error(
199 err,
200 RawMessage::Invalid(None),
202 )
203 })?
204 {
205 trace!(message = ?message, "received complete event stream message");
206 return Ok(Some(message));
207 }
208 }
209
210 self.buffer_next_chunk().await?;
211 }
212 if self.buffer.has_data() {
213 trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
214 let buf = self.buffer.buffered();
215 return Err(SdkError::response_error(
216 ReceiverError {
217 kind: ReceiverErrorKind::UnexpectedEndOfStream,
218 },
219 RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
220 ));
221 }
222 Ok(None)
223 }
224
225 #[doc(hidden)]
228 pub async fn try_recv_initial(
229 &mut self,
230 message_type: InitialMessageType,
231 ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
232 if let Some(message) = self.next_message().await? {
233 if let Some(event_type) = message
234 .headers()
235 .iter()
236 .find(|h| h.name().as_str() == ":event-type")
237 {
238 if event_type
239 .value()
240 .as_string()
241 .map(|s| s.as_str() == message_type.as_str())
242 .unwrap_or(false)
243 {
244 return Ok(Some(message));
245 }
246 }
247 self.buffered_message = Some(message);
249 }
250 Ok(None)
251 }
252
253 pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
258 if let Some(buffered) = self.buffered_message.take() {
259 return match self.unmarshall(buffered) {
260 Ok(message) => Ok(message),
261 Err(error) => {
262 self.buffer = RecvBuf::Terminated;
263 Err(error)
264 }
265 };
266 }
267 if let Some(message) = self.next_message().await? {
268 match self.unmarshall(message) {
269 Ok(message) => Ok(message),
270 Err(error) => {
271 self.buffer = RecvBuf::Terminated;
272 Err(error)
273 }
274 }
275 } else {
276 Ok(None)
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use hyper::body::Body;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    /// The receiver type exercised by every test in this module.
    type TestReceiver = Receiver<TestMessage, EventStreamError>;

    /// Serializes `message` into its event stream wire format.
    fn to_wire(message: &Message) -> Bytes {
        let mut wire = Vec::new();
        write_message_to(message, &mut wire).unwrap();
        Bytes::from(wire)
    }

    /// Encodes an `event` message whose `:event-type` header is `initial-response`.
    fn encode_initial_response() -> Bytes {
        let message_type = Header::new(":message-type", HeaderValue::String("event".into()));
        let event_type = Header::new(
            ":event-type",
            HeaderValue::String("initial-response".into()),
        );
        to_wire(
            &Message::new(Bytes::new())
                .add_header(message_type)
                .add_header(event_type),
        )
    }

    /// Encodes a header-less message whose payload is the UTF-8 bytes of `payload`.
    fn encode_message(payload: &str) -> Bytes {
        to_wire(&Message::new(Bytes::copy_from_slice(payload.as_bytes())))
    }

    /// Builds a receiver whose response body yields `chunks` in order.
    fn receiver_for(chunks: Vec<Result<Bytes, IOError>>) -> TestReceiver {
        let body = SdkBody::from_body_0_4(Body::wrap_stream(futures_util::stream::iter(chunks)));
        Receiver::new(Unmarshaller, body)
    }

    /// Asserts that the next events from `receiver` carry exactly `payloads`, in order.
    async fn expect_events(receiver: &mut TestReceiver, payloads: &[&str]) {
        for payload in payloads {
            assert_eq!(
                TestMessage((*payload).into()),
                receiver.recv().await.unwrap().unwrap()
            );
        }
    }

    #[derive(Debug)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("FakeError")
        }
    }

    impl StdError for FakeError {}

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Unmarshaller that turns every message into a `TestMessage` holding its UTF-8 payload.
    #[derive(Debug)]
    struct Unmarshaller;

    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            let text = std::str::from_utf8(&message.payload()[..]).unwrap();
            Ok(UnmarshalledMessage::Event(TestMessage(text.into())))
        }
    }

    #[tokio::test]
    async fn receive_success() {
        let mut receiver =
            receiver_for(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        expect_events(&mut receiver, &["one", "two"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ]);
        expect_events(&mut receiver, &["one", "two"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ]);
        expect_events(&mut receiver, &["one", "two"]).await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from(
                [encode_message("three"), encode_message("four")].concat(),
            )),
        ]);
        expect_events(&mut receiver, &["one", "two", "three", "four"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let payloads = ["one", "two", "three", "four", "five", "six", "seven", "eight"];
            let encoded: Vec<Bytes> = payloads
                .iter()
                .map(|payload| encode_message(payload))
                .collect();
            let combined = Bytes::from(encoded.concat());

            let midpoint = combined.len() / 2;
            let (start, boundary1, boundary2, end) =
                (0, b1 % midpoint, midpoint + b2 % midpoint, combined.len());
            println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end);

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let mut receiver = receiver_for(vec![
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                ]);
                expect_events(&mut receiver, &payloads).await;
                assert_eq!(None, receiver.recv().await.unwrap());
            });
        }
    }

    #[tokio::test]
    async fn receive_network_failure() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ]);
        expect_events(&mut receiver, &["one"]).await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    #[tokio::test]
    async fn receive_message_parse_failure() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            // Twelve zero bytes that do not form a valid event stream frame.
            Ok(Bytes::from_static(&[0u8; 12])),
        ]);
        expect_events(&mut receiver, &["one"]).await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    #[tokio::test]
    async fn receive_initial_response() {
        let mut receiver =
            receiver_for(vec![Ok(encode_initial_response()), Ok(encode_message("one"))]);
        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_some());
        expect_events(&mut receiver, &["one"]).await;
    }

    #[tokio::test]
    async fn receive_no_initial_response() {
        let mut receiver =
            receiver_for(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_none());
        expect_events(&mut receiver, &["one", "two"]).await;
    }

    fn assert_send_and_sync<T: Send + Sync>() {}

    #[tokio::test]
    async fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}
569}