aws_smithy_http/event_stream/
sender.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_runtime_api::client::result::SdkError;
8use aws_smithy_types::error::ErrorMetadata;
9use bytes::Bytes;
10use futures_core::Stream;
11use std::error::Error as StdError;
12use std::fmt;
13use std::fmt::Debug;
14use std::marker::PhantomData;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tracing::trace;
18
19/// Input type for Event Streams.
20pub struct EventStreamSender<T, E> {
21    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
22}
23
24impl<T, E> Debug for EventStreamSender<T, E> {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        let name_t = std::any::type_name::<T>();
27        let name_e = std::any::type_name::<E>();
28        write!(f, "EventStreamSender<{name_t}, {name_e}>")
29    }
30}
31
32impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
33    #[doc(hidden)]
34    pub fn into_body_stream(
35        self,
36        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
37        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
38        signer: impl SignMessage + Send + Sync + 'static,
39    ) -> MessageStreamAdapter<T, E> {
40        MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
41    }
42}
43
44impl<T, E, S> From<S> for EventStreamSender<T, E>
45where
46    S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
47{
48    fn from(stream: S) -> Self {
49        EventStreamSender {
50            input_stream: Box::pin(stream),
51        }
52    }
53}
54
55/// An error that occurs within a message stream.
56#[derive(Debug)]
57pub struct MessageStreamError {
58    kind: MessageStreamErrorKind,
59    pub(crate) meta: ErrorMetadata,
60}
61
/// The private set of causes a [`MessageStreamError`] can carry.
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// Any error, stored as a boxed trait object.
    Unhandled(Box<dyn StdError + Send + Sync + 'static>),
}
66
67impl MessageStreamError {
68    /// Creates the `MessageStreamError::Unhandled` variant from any error type.
69    pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
70        Self {
71            meta: Default::default(),
72            kind: MessageStreamErrorKind::Unhandled(err.into()),
73        }
74    }
75
76    /// Creates the `MessageStreamError::Unhandled` variant from an [`ErrorMetadata`].
77    pub fn generic(err: ErrorMetadata) -> Self {
78        Self {
79            meta: err.clone(),
80            kind: MessageStreamErrorKind::Unhandled(err.into()),
81        }
82    }
83
84    /// Returns error metadata, which includes the error code, message,
85    /// request ID, and potentially additional information.
86    pub fn meta(&self) -> &ErrorMetadata {
87        &self.meta
88    }
89}
90
91impl StdError for MessageStreamError {
92    fn source(&self) -> Option<&(dyn StdError + 'static)> {
93        match &self.kind {
94            MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
95        }
96    }
97}
98
99impl fmt::Display for MessageStreamError {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        match &self.kind {
102            MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
103        }
104    }
105}
106
107/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
108/// message marshaller and signer implementations.
109///
110/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be
111/// marshalled into an Event Stream frame, (e.g., if the message payload was too large).
112#[allow(missing_debug_implementations)]
113pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
114    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
115    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
116    signer: Box<dyn SignMessage + Send + Sync>,
117    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
118    end_signal_sent: bool,
119    _phantom: PhantomData<E>,
120}
121
122impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
123
124impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
125    /// Create a new `MessageStreamAdapter`.
126    pub fn new(
127        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
128        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
129        signer: impl SignMessage + Send + Sync + 'static,
130        stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
131    ) -> Self {
132        MessageStreamAdapter {
133            marshaller: Box::new(marshaller),
134            error_marshaller: Box::new(error_marshaller),
135            signer: Box::new(signer),
136            stream,
137            end_signal_sent: false,
138            _phantom: Default::default(),
139        }
140    }
141}
142
143impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
144    type Item =
145        Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
146
147    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
148        match self.stream.as_mut().poll_next(cx) {
149            Poll::Ready(message_option) => {
150                if let Some(message_result) = message_option {
151                    let message = match message_result {
152                        Ok(message) => self
153                            .marshaller
154                            .marshall(message)
155                            .map_err(SdkError::construction_failure)?,
156                        Err(message) => self
157                            .error_marshaller
158                            .marshall(message)
159                            .map_err(SdkError::construction_failure)?,
160                    };
161
162                    trace!(unsigned_message = ?message, "signing event stream message");
163                    let message = self
164                        .signer
165                        .sign(message)
166                        .map_err(SdkError::construction_failure)?;
167
168                    let mut buffer = Vec::new();
169                    write_message_to(&message, &mut buffer)
170                        .map_err(SdkError::construction_failure)?;
171                    trace!(signed_message = ?buffer, "sending signed event stream message");
172                    Poll::Ready(Some(Ok(Bytes::from(buffer))))
173                } else if !self.end_signal_sent {
174                    self.end_signal_sent = true;
175                    let mut buffer = Vec::new();
176                    match self.signer.sign_empty() {
177                        Some(sign) => {
178                            let message = sign.map_err(SdkError::construction_failure)?;
179                            write_message_to(&message, &mut buffer)
180                                .map_err(SdkError::construction_failure)?;
181                            trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
182                            Poll::Ready(Some(Ok(Bytes::from(buffer))))
183                        }
184                        None => Poll::Ready(None),
185                    }
186                } else {
187                    Poll::Ready(None)
188                }
189            }
190            Poll::Pending => Poll::Pending,
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Event type used as the `Ok` item of the test streams.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshaller that turns a `TestMessage` into a message whose payload is the string bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            Ok(Message::new(input.0.as_bytes().to_vec()))
        }
    }

    /// Marshaller that always fails, used to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Reading a message from an empty buffer is a convenient way to obtain an
            // `EventStreamError` value.
            Err(read_message_from(&b""[..]).expect_err("this should always fail"))
        }
    }

    /// Error type used as the `Err` item of the test streams.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// Signer that wraps the serialized input message in a new message carrying a
    /// `signed: true` header, and signs the end of the stream with an empty payload.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut buffer = Vec::new();
            write_message_to(&message, &mut buffer).unwrap();
            Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true))))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            Some(Ok(
                Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))
            ))
        }
    }

    /// Compile-time check that `T` is `Send + Sync`.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        check_send_sync(EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        }));
    }

    /// Compile-time check that `S` satisfies the bounds listed in the `where` clause.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let stream = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(stream),
        ));

        // First frame: the signed wrapper around the marshalled test message.
        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Second frame: the signed empty message that ends the event stream.
        let mut end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value());
        assert_eq!(0, end_signal.payload().len());

        // Once the end signal has been emitted, the adapter must report termination.
        assert!(adapter.next().await.is_none());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let stream = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(stream),
        ));

        let result = adapter.next().await.unwrap();
        assert!(result.is_err());
        assert!(matches!(
            result.err().unwrap(),
            SdkError::ConstructionFailure(_)
        ));
    }

    // Verify the developer experience for this compiles
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}