aws_smithy_http/event_stream/
sender.rs
1use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_runtime_api::client::result::SdkError;
8use aws_smithy_types::error::ErrorMetadata;
9use bytes::Bytes;
10use futures_core::Stream;
11use std::error::Error as StdError;
12use std::fmt;
13use std::fmt::Debug;
14use std::marker::PhantomData;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tracing::trace;
18
/// Input type for an event stream: a type-erased, pinned stream of `Result<T, E>` items.
///
/// Any `Stream<Item = Result<T, E>> + Send + Sync + 'static` converts into this type via
/// `From`. Once turned into a body stream with `into_body_stream`, `Ok(T)` items are encoded
/// with the regular marshaller and `Err(E)` items with the error marshaller.
pub struct EventStreamSender<T, E> {
    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
}
23
24impl<T, E> Debug for EventStreamSender<T, E> {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 let name_t = std::any::type_name::<T>();
27 let name_e = std::any::type_name::<E>();
28 write!(f, "EventStreamSender<{name_t}, {name_e}>")
29 }
30}
31
32impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
33 #[doc(hidden)]
34 pub fn into_body_stream(
35 self,
36 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
37 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
38 signer: impl SignMessage + Send + Sync + 'static,
39 ) -> MessageStreamAdapter<T, E> {
40 MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
41 }
42}
43
44impl<T, E, S> From<S> for EventStreamSender<T, E>
45where
46 S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
47{
48 fn from(stream: S) -> Self {
49 EventStreamSender {
50 input_stream: Box::pin(stream),
51 }
52 }
53}
54
/// Error that wraps an underlying error together with [`ErrorMetadata`].
///
/// Built with [`MessageStreamError::unhandled`] (arbitrary error, default metadata) or
/// [`MessageStreamError::generic`] (from `ErrorMetadata`). The wrapped error is exposed via
/// [`std::error::Error::source`].
// NOTE(review): `MessageStreamAdapter` in this file yields `SdkError`, not this type;
// presumably constructed by generated/external code — confirm.
#[derive(Debug)]
pub struct MessageStreamError {
    kind: MessageStreamErrorKind,
    /// Metadata for this error; `ErrorMetadata::default()` when built via `unhandled`.
    pub(crate) meta: ErrorMetadata,
}
61
/// Private classification of a [`MessageStreamError`]; callers cannot match on it directly.
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// An arbitrary boxed error, returned as the `source()` of the `MessageStreamError`.
    Unhandled(Box<dyn std::error::Error + Send + Sync + 'static>),
}
66
67impl MessageStreamError {
68 pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
70 Self {
71 meta: Default::default(),
72 kind: MessageStreamErrorKind::Unhandled(err.into()),
73 }
74 }
75
76 pub fn generic(err: ErrorMetadata) -> Self {
78 Self {
79 meta: err.clone(),
80 kind: MessageStreamErrorKind::Unhandled(err.into()),
81 }
82 }
83
84 pub fn meta(&self) -> &ErrorMetadata {
87 &self.meta
88 }
89}
90
91impl StdError for MessageStreamError {
92 fn source(&self) -> Option<&(dyn StdError + 'static)> {
93 match &self.kind {
94 MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
95 }
96 }
97}
98
99impl fmt::Display for MessageStreamError {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 match &self.kind {
102 MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
103 }
104 }
105}
106
/// Adapts a stream of `Result<T, E>` into a stream of encoded, signed event stream frames.
///
/// Each `Ok(T)` item is marshalled with `marshaller` and each `Err(E)` item with
/// `error_marshaller`. The resulting message is signed with `signer` and written to bytes.
/// When the input stream ends, one extra signed empty message is emitted if the signer's
/// `sign_empty` produces one.
// `Debug` is not implemented because the boxed `dyn Stream` field carries no `Debug` bound.
#[allow(missing_debug_implementations)]
pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
    /// Encodes successful (`Ok`) input items.
    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
    /// Encodes error (`Err`) input items.
    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
    /// Signs every encoded message, and optionally produces the terminating empty message.
    signer: Box<dyn SignMessage + Send + Sync>,
    /// The wrapped input stream.
    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
    /// Set once the input stream has returned `None`, whether or not an end message was emitted.
    end_signal_sent: bool,
    // NOTE(review): appears redundant — `E` is already used by `error_marshaller` and `stream`.
    _phantom: PhantomData<E>,
}
121
// All fields are `Box`/`Pin<Box<..>>` (always `Unpin`), a `bool`, or `PhantomData<E>`, so
// nothing is pinned in place. The auto impl would require `E: Unpin` because of
// `PhantomData<E>`; this manual impl drops that requirement and lets `poll_next` mutate
// fields through `Pin<&mut Self>` without unsafe projection.
impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
123
124impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
125 pub fn new(
127 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
128 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
129 signer: impl SignMessage + Send + Sync + 'static,
130 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
131 ) -> Self {
132 MessageStreamAdapter {
133 marshaller: Box::new(marshaller),
134 error_marshaller: Box::new(error_marshaller),
135 signer: Box::new(signer),
136 stream,
137 end_signal_sent: false,
138 _phantom: Default::default(),
139 }
140 }
141}
142
143impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
144 type Item =
145 Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
146
147 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
148 match self.stream.as_mut().poll_next(cx) {
149 Poll::Ready(message_option) => {
150 if let Some(message_result) = message_option {
151 let message = match message_result {
152 Ok(message) => self
153 .marshaller
154 .marshall(message)
155 .map_err(SdkError::construction_failure)?,
156 Err(message) => self
157 .error_marshaller
158 .marshall(message)
159 .map_err(SdkError::construction_failure)?,
160 };
161
162 trace!(unsigned_message = ?message, "signing event stream message");
163 let message = self
164 .signer
165 .sign(message)
166 .map_err(SdkError::construction_failure)?;
167
168 let mut buffer = Vec::new();
169 write_message_to(&message, &mut buffer)
170 .map_err(SdkError::construction_failure)?;
171 trace!(signed_message = ?buffer, "sending signed event stream message");
172 Poll::Ready(Some(Ok(Bytes::from(buffer))))
173 } else if !self.end_signal_sent {
174 self.end_signal_sent = true;
175 let mut buffer = Vec::new();
176 match self.signer.sign_empty() {
177 Some(sign) => {
178 let message = sign.map_err(SdkError::construction_failure)?;
179 write_message_to(&message, &mut buffer)
180 .map_err(SdkError::construction_failure)?;
181 trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
182 Poll::Ready(Some(Ok(Bytes::from(buffer))))
183 }
184 None => Poll::Ready(None),
185 }
186 } else {
187 Poll::Ready(None)
188 }
189 }
190 Poll::Pending => Poll::Pending,
191 }
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Plain-text event used as the `Ok` item type in these tests.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshals a `TestMessage` into a `Message` whose payload is the string's UTF-8 bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            Ok(Message::new(input.0.as_bytes().to_vec()))
        }
    }
    /// Error marshaller that always fails, used to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Obtain an `EventStreamError` value by parsing an empty buffer, which must fail.
            Err(read_message_from(&b""[..]).expect_err("this should always fail"))
        }
    }

    /// Error type used as the `Err` item type in these tests.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// Signer that encodes the input message and uses those bytes as the payload of a new
    /// message carrying a `signed: true` header.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut buffer = Vec::new();
            write_message_to(&message, &mut buffer).unwrap();
            Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true))))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            // Always produce an end-of-stream message: empty payload plus the `signed` header.
            Some(Ok(
                Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))
            ))
        }
    }

    /// Compile-time assertion that `value` is `Send + Sync`; returns it unchanged.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        check_send_sync(EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        }));
    }

    /// Compile-time check that `stream` satisfies the bounds in the `where` clause.
    // Presumably these mirror hyper's `Body::wrap_stream` bounds — confirm against the
    // hyper version in use.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let stream = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(stream),
        ));

        // First frame: the signed wrapper whose payload is the encoded "test" message.
        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Second frame: the signed, empty end-of-stream message from `TestSigner::sign_empty`.
        let mut end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value());
        assert_eq!(0, end_signal.payload().len());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let stream = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(stream),
        ));

        // `ErrorMarshaller` always fails, so the first item must be a construction failure.
        let result = adapter.next().await.unwrap();
        assert!(result.is_err());
        assert!(matches!(
            result.err().unwrap(),
            SdkError::ConstructionFailure(_)
        ));
    }

    // Compile-only check (never called, hence `allow(unused)`): `stream!` blocks yielding
    // either `Ok` or `Err` values convert into `EventStreamSender` via `Into`.
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}