1use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
31use crate::client::result::SdkError;
32use aws_smithy_types::config_bag::ConfigBag;
33use aws_smithy_types::type_erasure::{TypeErasedBox, TypeErasedError};
34use phase::Phase;
35use std::fmt::Debug;
36use std::{fmt, mem};
37use tracing::{debug, error, trace};
38
39macro_rules! new_type_box {
40 ($name:ident, $doc:literal) => {
41 new_type_box!($name, TypeErasedBox, $doc, Send, Sync, fmt::Debug,);
42 };
43 ($name:ident, $underlying:ident, $doc:literal, $($additional_bound:path,)*) => {
44 #[doc = $doc]
45 #[derive(Debug)]
46 pub struct $name($underlying);
47
48 impl $name {
49 #[doc = concat!("Creates a new `", stringify!($name), "` with the provided concrete input value.")]
50 pub fn erase<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(input: T) -> Self {
51 Self($underlying::new(input))
52 }
53
54 #[doc = concat!("Downcasts to the concrete input value.")]
55 pub fn downcast_ref<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&self) -> Option<&T> {
56 self.0.downcast_ref()
57 }
58
59 #[doc = concat!("Downcasts to the concrete input value.")]
60 pub fn downcast_mut<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&mut self) -> Option<&mut T> {
61 self.0.downcast_mut()
62 }
63
64 #[doc = concat!("Downcasts to the concrete input value.")]
65 pub fn downcast<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(self) -> Result<T, Self> {
66 self.0.downcast::<T>().map(|v| *v).map_err(Self)
67 }
68
69 #[doc = concat!("Returns a `", stringify!($name), "` with a fake/test value with the expectation that it won't be downcast in the test.")]
70 #[cfg(feature = "test-util")]
71 pub fn doesnt_matter() -> Self {
72 Self($underlying::doesnt_matter())
73 }
74 }
75 };
76}
77
78new_type_box!(Input, "Type-erased operation input.");
79new_type_box!(Output, "Type-erased operation output.");
80new_type_box!(
81 Error,
82 TypeErasedError,
83 "Type-erased operation error.",
84 std::error::Error,
85);
86
87impl fmt::Display for Error {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 fmt::Display::fmt(&self.0, f)
90 }
91}
92impl std::error::Error for Error {
93 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
94 self.0.source()
95 }
96}
97
98pub type OutputOrError = Result<Output, OrchestratorError<Error>>;
100
101type Request = HttpRequest;
102type Response = HttpResponse;
103
104pub use wrappers::{
105 AfterDeserializationInterceptorContextRef, BeforeDeserializationInterceptorContextMut,
106 BeforeDeserializationInterceptorContextRef, BeforeSerializationInterceptorContextMut,
107 BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
108 BeforeTransmitInterceptorContextRef, FinalizerInterceptorContextMut,
109 FinalizerInterceptorContextRef,
110};
111
112mod wrappers;
113
114pub(crate) mod phase;
116
117#[derive(Debug)]
123pub struct InterceptorContext<I = Input, O = Output, E = Error> {
124 pub(crate) input: Option<I>,
125 pub(crate) output_or_error: Option<Result<O, OrchestratorError<E>>>,
126 pub(crate) request: Option<Request>,
127 pub(crate) response: Option<Response>,
128 phase: Phase,
129 tainted: bool,
130 request_checkpoint: Option<HttpRequest>,
131}
132
133impl InterceptorContext<Input, Output, Error> {
134 pub fn new(input: Input) -> InterceptorContext<Input, Output, Error> {
136 InterceptorContext {
137 input: Some(input),
138 output_or_error: None,
139 request: None,
140 response: None,
141 phase: Phase::BeforeSerialization,
142 tainted: false,
143 request_checkpoint: None,
144 }
145 }
146}
147
148impl<I, O, E> InterceptorContext<I, O, E> {
149 pub fn input(&self) -> Option<&I> {
153 self.input.as_ref()
154 }
155
156 pub fn input_mut(&mut self) -> Option<&mut I> {
160 self.input.as_mut()
161 }
162
163 pub fn take_input(&mut self) -> Option<I> {
167 self.input.take()
168 }
169
170 pub fn set_request(&mut self, request: Request) {
174 self.request = Some(request);
175 }
176
177 pub fn request(&self) -> Option<&Request> {
182 self.request.as_ref()
183 }
184
185 pub fn request_mut(&mut self) -> Option<&mut Request> {
190 self.request.as_mut()
191 }
192
193 pub fn take_request(&mut self) -> Option<Request> {
197 self.request.take()
198 }
199
200 pub fn set_response(&mut self, response: Response) {
204 self.response = Some(response);
205 }
206
207 pub fn response(&self) -> Option<&Response> {
211 self.response.as_ref()
212 }
213
214 pub fn response_mut(&mut self) -> Option<&mut Response> {
218 self.response.as_mut()
219 }
220
221 pub fn set_output_or_error(&mut self, output: Result<O, OrchestratorError<E>>) {
225 self.output_or_error = Some(output);
226 }
227
228 pub fn output_or_error(&self) -> Option<Result<&O, &OrchestratorError<E>>> {
232 self.output_or_error.as_ref().map(Result::as_ref)
233 }
234
235 pub fn output_or_error_mut(&mut self) -> Option<&mut Result<O, OrchestratorError<E>>> {
239 self.output_or_error.as_mut()
240 }
241
242 pub fn take_output_or_error(&mut self) -> Option<Result<O, OrchestratorError<E>>> {
246 self.output_or_error.take()
247 }
248
249 pub fn is_failed(&self) -> bool {
253 self.output_or_error
254 .as_ref()
255 .map(Result::is_err)
256 .unwrap_or_default()
257 }
258
259 pub fn enter_serialization_phase(&mut self) {
263 debug!("entering \'serialization\' phase");
264 debug_assert!(
265 self.phase.is_before_serialization(),
266 "called enter_serialization_phase but phase is not before 'serialization'"
267 );
268 self.phase = Phase::Serialization;
269 }
270
271 pub fn enter_before_transmit_phase(&mut self) {
275 debug!("entering \'before transmit\' phase");
276 debug_assert!(
277 self.phase.is_serialization(),
278 "called enter_before_transmit_phase but phase is not 'serialization'"
279 );
280 debug_assert!(
281 self.input.is_none(),
282 "input must be taken before calling enter_before_transmit_phase"
283 );
284 debug_assert!(
285 self.request.is_some(),
286 "request must be set before calling enter_before_transmit_phase"
287 );
288 self.request_checkpoint = self.request().expect("checked above").try_clone();
289 self.phase = Phase::BeforeTransmit;
290 }
291
292 pub fn enter_transmit_phase(&mut self) {
296 debug!("entering \'transmit\' phase");
297 debug_assert!(
298 self.phase.is_before_transmit(),
299 "called enter_transmit_phase but phase is not before transmit"
300 );
301 self.phase = Phase::Transmit;
302 }
303
304 pub fn enter_before_deserialization_phase(&mut self) {
308 debug!("entering \'before deserialization\' phase");
309 debug_assert!(
310 self.phase.is_transmit(),
311 "called enter_before_deserialization_phase but phase is not 'transmit'"
312 );
313 debug_assert!(
314 self.request.is_none(),
315 "request must be taken before entering the 'before deserialization' phase"
316 );
317 debug_assert!(
318 self.response.is_some(),
319 "response must be set to before entering the 'before deserialization' phase"
320 );
321 self.phase = Phase::BeforeDeserialization;
322 }
323
324 pub fn enter_deserialization_phase(&mut self) {
328 debug!("entering \'deserialization\' phase");
329 debug_assert!(
330 self.phase.is_before_deserialization(),
331 "called enter_deserialization_phase but phase is not 'before deserialization'"
332 );
333 self.phase = Phase::Deserialization;
334 }
335
336 pub fn enter_after_deserialization_phase(&mut self) {
340 debug!("entering \'after deserialization\' phase");
341 debug_assert!(
342 self.phase.is_deserialization(),
343 "called enter_after_deserialization_phase but phase is not 'deserialization'"
344 );
345 debug_assert!(
346 self.output_or_error.is_some(),
347 "output must be set to before entering the 'after deserialization' phase"
348 );
349 self.phase = Phase::AfterDeserialization;
350 }
351
352 pub fn save_checkpoint(&mut self) {
356 trace!("saving request checkpoint...");
357 self.request_checkpoint = self.request().and_then(|r| r.try_clone());
358 match self.request_checkpoint.as_ref() {
359 Some(_) => trace!("successfully saved request checkpoint"),
360 None => trace!("failed to save request checkpoint: request body could not be cloned"),
361 }
362 }
363
364 pub fn rewind(&mut self, _cfg: &mut ConfigBag) -> RewindResult {
368 let request_checkpoint = match (self.request_checkpoint.as_ref(), self.tainted) {
371 (None, true) => return RewindResult::Impossible,
372 (_, false) => {
373 self.tainted = true;
374 return RewindResult::Unnecessary;
375 }
376 (Some(req), _) => req.try_clone(),
377 };
378
379 self.phase = Phase::BeforeTransmit;
381 self.request = request_checkpoint;
382 assert!(
383 self.request.is_some(),
384 "if the request wasn't cloneable, then we should have already returned from this method."
385 );
386 self.response = None;
387 self.output_or_error = None;
388 RewindResult::Occurred
389 }
390}
391
392impl<I, O, E> InterceptorContext<I, O, E>
393where
394 E: Debug,
395{
396 #[allow(clippy::type_complexity)]
400 pub fn into_parts(
401 self,
402 ) -> (
403 Option<I>,
404 Option<Result<O, OrchestratorError<E>>>,
405 Option<Request>,
406 Option<Response>,
407 ) {
408 (
409 self.input,
410 self.output_or_error,
411 self.request,
412 self.response,
413 )
414 }
415
416 pub fn finalize(mut self) -> Result<O, SdkError<E, HttpResponse>> {
420 let output_or_error = self
421 .output_or_error
422 .take()
423 .expect("output_or_error must always be set before finalize is called.");
424 self.finalize_result(output_or_error)
425 }
426
427 pub fn finalize_result(
431 &mut self,
432 result: Result<O, OrchestratorError<E>>,
433 ) -> Result<O, SdkError<E, HttpResponse>> {
434 let response = self.response.take();
435 result.map_err(|error| OrchestratorError::into_sdk_error(error, &self.phase, response))
436 }
437
438 pub fn fail(&mut self, error: OrchestratorError<E>) {
443 if !self.is_failed() {
444 trace!(
445 "orchestrator is transitioning to the 'failure' phase from the '{:?}' phase",
446 self.phase
447 );
448 }
449 if let Some(Err(existing_err)) = mem::replace(&mut self.output_or_error, Some(Err(error))) {
450 error!("orchestrator context received an error but one was already present; Throwing away previous error: {:?}", existing_err);
451 }
452 }
453}
454
/// Outcome of an attempt to rewind a request back to its saved checkpoint.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum RewindResult {
    /// The request couldn't be rewound because it wasn't cloneable.
    Impossible,
    /// The request wasn't rewound because it was unnecessary.
    Unnecessary,
    /// The request was rewound successfully.
    Occurred,
}

/// Renders a one-sentence, human-readable explanation of the outcome.
impl fmt::Display for RewindResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Impossible => "The request couldn't be rewound because it wasn't cloneable.",
            Self::Unnecessary => "The request wasn't rewound because it was unnecessary.",
            Self::Occurred => "The request was rewound successfully.",
        };
        f.write_str(message)
    }
}
483
#[cfg(all(test, feature = "test-util", feature = "http-02x"))]
mod tests {
    use super::*;
    use aws_smithy_types::body::SdkBody;
    use http_02x::header::{AUTHORIZATION, CONTENT_LENGTH};
    use http_02x::{HeaderValue, Uri};

    /// Builds a response with an empty body, converted from an `http` 0.2.x response.
    fn empty_response() -> HttpResponse {
        http_02x::Response::builder()
            .body(SdkBody::empty())
            .unwrap()
            .try_into()
            .unwrap()
    }

    /// Walks a context through every phase of a successful operation and checks that the
    /// accessors can be called in each phase and that the output survives to the end.
    #[test]
    fn test_success_transitions() {
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());

        // Before serialization: only the input is available.
        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());
        context.input_mut();

        // Serialization: the input is consumed and a request is produced.
        context.enter_serialization_phase();
        let _ = context.take_input();
        context.set_request(HttpRequest::empty());

        // Before transmit: the request can be inspected and modified.
        context.enter_before_transmit_phase();
        context.request();
        context.request_mut();

        // Transmit: the request is consumed and a response is produced.
        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(empty_response());

        // Before deserialization: the response can be inspected and modified.
        context.enter_before_deserialization_phase();
        context.response();
        context.response_mut();

        // Deserialization: the output is produced.
        context.enter_deserialization_phase();
        context.response();
        context.response_mut();
        context.set_output_or_error(Ok(output));

        // After deserialization: both the response and the output are available.
        context.enter_after_deserialization_phase();
        context.response();
        context.response_mut();
        let _ = context.output_or_error();
        let _ = context.output_or_error_mut();

        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    /// Fails a first attempt, rewinds, and checks that the request seen by the second attempt
    /// is the one saved at the checkpoint rather than the one modified afterwards.
    #[test]
    fn test_rewind_for_retry() {
        let mut cfg = ConfigBag::base();
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());
        let error = Error::doesnt_matter();

        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());

        context.enter_serialization_phase();
        let _ = context.take_input();
        let original_request: HttpRequest = http_02x::Request::builder()
            .header("test", "the-original-un-mutated-request")
            .body(SdkBody::empty())
            .unwrap()
            .try_into()
            .unwrap();
        context.set_request(original_request);
        context.enter_before_transmit_phase();
        context.save_checkpoint();

        // The first rewind only marks the context as tainted.
        assert_eq!(context.rewind(&mut cfg), RewindResult::Unnecessary);

        // Modify the request after the checkpoint was saved.
        context.request_mut().unwrap().headers_mut().remove("test");
        context.request_mut().unwrap().headers_mut().insert(
            "test",
            HeaderValue::from_static("request-modified-after-signing"),
        );

        context.enter_transmit_phase();
        let request = context.take_request().unwrap();
        assert_eq!(
            "request-modified-after-signing",
            request.headers().get("test").unwrap()
        );
        context.set_response(empty_response());

        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Err(OrchestratorError::operation(error)));

        // The second rewind actually restores the checkpoint.
        assert_eq!(context.rewind(&mut cfg), RewindResult::Occurred);

        // The restored request must not carry the post-checkpoint modification.
        assert_eq!(
            "the-original-un-mutated-request",
            context.request().unwrap().headers().get("test").unwrap()
        );

        // Second attempt: this time the operation succeeds.
        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(empty_response());

        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Ok(output));

        context.enter_after_deserialization_phase();

        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    /// Checks that `try_clone` carries over the URI, method, headers and body.
    #[test]
    fn try_clone_clones_all_data() {
        let request: HttpRequest = http_02x::Request::builder()
            .uri(Uri::from_static("https://www.amazon.com"))
            .method("POST")
            .header(CONTENT_LENGTH, 456)
            .header(AUTHORIZATION, "Token: hello")
            .body(SdkBody::from("hello world!"))
            .expect("valid request")
            .try_into()
            .unwrap();
        let cloned = request.try_clone().expect("request is cloneable");

        assert_eq!(&Uri::from_static("https://www.amazon.com"), cloned.uri());
        assert_eq!("POST", cloned.method());

        let headers = cloned.headers();
        assert_eq!(2, headers.len());
        assert_eq!("Token: hello", headers.get(AUTHORIZATION).unwrap(),);
        assert_eq!("456", headers.get(CONTENT_LENGTH).unwrap());

        assert_eq!("hello world!".as_bytes(), cloned.body().bytes().unwrap());
    }
}