aws_smithy_runtime_api/client/interceptors/
context.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Interceptor context.
7//!
8//! Interceptors have access to varying pieces of context during the course of an operation.
9//!
10//! An operation is composed of multiple phases. The initial phase is "before serialization", which
11//! has the original input as context. The next phase is "before transmit", which has the serialized
12//! request as context. Depending on which hook is being called with the dispatch context,
13//! the serialized request may or may not be signed (which should be apparent from the hook name).
14//! Following the "before transmit" phase is the "before deserialization" phase, which has
15//! the raw response available as context. Finally, the "after deserialization" phase
16//! has both the raw and parsed response available.
17//!
18//! To summarize:
19//! 1. Before serialization: Only has the operation input.
20//! 2. Before transmit: Only has the serialized request.
21//! 3. Before deserialization: Has the raw response.
//! 4. After deserialization: Has the raw response and the parsed response.
23//!
24//! When implementing hooks, if information from a previous phase is required, then implement
25//! an earlier hook to examine that context, and save off any necessary information into the
26//! [`ConfigBag`] for later hooks to examine.  Interior mutability is **NOT**
27//! recommended for storing request-specific information in your interceptor implementation.
28//! Use the [`ConfigBag`] instead.
29
30use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
31use crate::client::result::SdkError;
32use aws_smithy_types::config_bag::ConfigBag;
33use aws_smithy_types::type_erasure::{TypeErasedBox, TypeErasedError};
34use phase::Phase;
35use std::fmt::Debug;
36use std::{fmt, mem};
37use tracing::{debug, error, trace};
38
// Generates a public newtype wrapper around a type-erased container.
//
// Two invocation forms are accepted:
// - `new_type_box!(Name, "doc")` wraps a `TypeErasedBox`.
// - `new_type_box!(Name, Underlying, "doc", Bound, ...,)` wraps `Underlying` and requires
//   every listed bound on the concrete type, in addition to the always-required
//   `Send + Sync + fmt::Debug + 'static`.
macro_rules! new_type_box {
    ($name:ident, $doc:literal) => {
        // `Send + Sync + fmt::Debug` are hard-coded by the arm below, so there are no
        // additional bounds to forward from here.
        new_type_box!($name, TypeErasedBox, $doc,);
    };
    ($name:ident, $underlying:ident, $doc:literal, $($additional_bound:path,)*) => {
        #[doc = $doc]
        #[derive(Debug)]
        pub struct $name($underlying);

        impl $name {
            #[doc = concat!("Creates a new `", stringify!($name), "` by type-erasing the provided concrete value.")]
            pub fn erase<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(input: T) -> Self {
                Self($underlying::new(input))
            }

            #[doc = "Attempts to downcast to a shared reference of the concrete type `T`."]
            pub fn downcast_ref<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&self) -> Option<&T> {
                self.0.downcast_ref()
            }

            #[doc = "Attempts to downcast to a mutable reference of the concrete type `T`."]
            pub fn downcast_mut<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&mut self) -> Option<&mut T> {
                self.0.downcast_mut()
            }

            #[doc = concat!("Consumes this `", stringify!($name), "` and attempts to downcast it into the concrete type `T`, handing the `", stringify!($name), "` back as the error on failure.")]
            pub fn downcast<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(self) -> Result<T, Self> {
                // The underlying container yields a boxed value; unbox it on success and
                // re-wrap the container on failure so the caller can try another type.
                self.0.downcast::<T>().map(|v| *v).map_err(Self)
            }

            #[doc = concat!("Returns a `", stringify!($name), "` with a fake/test value with the expectation that it won't be downcast in the test.")]
            #[cfg(feature = "test-util")]
            pub fn doesnt_matter() -> Self {
                Self($underlying::doesnt_matter())
            }
        }
    };
}
77
// Type-erased wrappers for operation inputs and outputs; the two-argument macro form
// backs each of them with a `TypeErasedBox`.
new_type_box!(Input, "Type-erased operation input.");
new_type_box!(Output, "Type-erased operation output.");
// Errors are backed by `TypeErasedError` instead, and the concrete type must additionally
// implement `std::error::Error`.
new_type_box!(
    Error,
    TypeErasedError,
    "Type-erased operation error.",
    std::error::Error,
);
86
87impl fmt::Display for Error {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        fmt::Display::fmt(&self.0, f)
90    }
91}
92impl std::error::Error for Error {
93    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
94        self.0.source()
95    }
96}
97
/// Type-erased result for an operation.
pub type OutputOrError = Result<Output, OrchestratorError<Error>>;

// Module-local shorthand for the orchestrator's HTTP request and response types.
type Request = HttpRequest;
type Response = HttpResponse;
103
104pub use wrappers::{
105    AfterDeserializationInterceptorContextRef, BeforeDeserializationInterceptorContextMut,
106    BeforeDeserializationInterceptorContextRef, BeforeSerializationInterceptorContextMut,
107    BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
108    BeforeTransmitInterceptorContextRef, FinalizerInterceptorContextMut,
109    FinalizerInterceptorContextRef,
110};
111
112mod wrappers;
113
114/// Operation phases.
115pub(crate) mod phase;
116
/// A container for the data currently available to an interceptor.
///
/// Different context is available based on which phase the operation is currently in. For example,
/// context in the "before serialization" phase won't have a `request` yet since the input hasn't been
/// serialized at that point. But once it gets into the "before transmit" phase, the `request` will be set.
#[derive(Debug)]
pub struct InterceptorContext<I = Input, O = Output, E = Error> {
    /// The operation input; `None` once it has been taken with `take_input`.
    pub(crate) input: Option<I>,
    /// The operation's output or error; replaced by `set_output_or_error` and `fail`,
    /// and cleared by `rewind`.
    pub(crate) output_or_error: Option<Result<O, OrchestratorError<E>>>,
    /// The request; stored by `set_request`, removed by `take_request`, and restored from
    /// the checkpoint by `rewind`.
    pub(crate) request: Option<Request>,
    /// The response; stored by `set_response` and cleared by `rewind`.
    pub(crate) response: Option<Response>,
    /// The phase the operation is currently in; advanced by the `enter_*_phase` methods.
    phase: Phase,
    /// Set to `true` the first time `rewind` is called.
    tainted: bool,
    /// A clone of the request captured by `enter_before_transmit_phase` / `save_checkpoint`;
    /// `rewind` restores the request from it. `None` if the request could not be cloned.
    request_checkpoint: Option<HttpRequest>,
}
132
133impl InterceptorContext<Input, Output, Error> {
134    /// Creates a new interceptor context in the "before serialization" phase.
135    pub fn new(input: Input) -> InterceptorContext<Input, Output, Error> {
136        InterceptorContext {
137            input: Some(input),
138            output_or_error: None,
139            request: None,
140            response: None,
141            phase: Phase::BeforeSerialization,
142            tainted: false,
143            request_checkpoint: None,
144        }
145    }
146}
147
148impl<I, O, E> InterceptorContext<I, O, E> {
149    /// Retrieve the input for the operation being invoked.
150    ///
151    /// Note: This method is intended for internal use only.
152    pub fn input(&self) -> Option<&I> {
153        self.input.as_ref()
154    }
155
156    /// Retrieve the input for the operation being invoked.
157    ///
158    /// Note: This method is intended for internal use only.
159    pub fn input_mut(&mut self) -> Option<&mut I> {
160        self.input.as_mut()
161    }
162
163    /// Takes ownership of the input.
164    ///
165    /// Note: This method is intended for internal use only.
166    pub fn take_input(&mut self) -> Option<I> {
167        self.input.take()
168    }
169
170    /// Set the request for the operation being invoked.
171    ///
172    /// Note: This method is intended for internal use only.
173    pub fn set_request(&mut self, request: Request) {
174        self.request = Some(request);
175    }
176
177    /// Retrieve the transmittable request for the operation being invoked.
178    /// This will only be available once request marshalling has completed.
179    ///
180    /// Note: This method is intended for internal use only.
181    pub fn request(&self) -> Option<&Request> {
182        self.request.as_ref()
183    }
184
185    /// Retrieve the transmittable request for the operation being invoked.
186    /// This will only be available once request marshalling has completed.
187    ///
188    /// Note: This method is intended for internal use only.
189    pub fn request_mut(&mut self) -> Option<&mut Request> {
190        self.request.as_mut()
191    }
192
193    /// Takes ownership of the request.
194    ///
195    /// Note: This method is intended for internal use only.
196    pub fn take_request(&mut self) -> Option<Request> {
197        self.request.take()
198    }
199
200    /// Set the response for the operation being invoked.
201    ///
202    /// Note: This method is intended for internal use only.
203    pub fn set_response(&mut self, response: Response) {
204        self.response = Some(response);
205    }
206
207    /// Returns the response.
208    ///
209    /// Note: This method is intended for internal use only.
210    pub fn response(&self) -> Option<&Response> {
211        self.response.as_ref()
212    }
213
214    /// Returns a mutable reference to the response.
215    ///
216    /// Note: This method is intended for internal use only.
217    pub fn response_mut(&mut self) -> Option<&mut Response> {
218        self.response.as_mut()
219    }
220
221    /// Set the output or error for the operation being invoked.
222    ///
223    /// Note: This method is intended for internal use only.
224    pub fn set_output_or_error(&mut self, output: Result<O, OrchestratorError<E>>) {
225        self.output_or_error = Some(output);
226    }
227
228    /// Returns the deserialized output or error.
229    ///
230    /// Note: This method is intended for internal use only.
231    pub fn output_or_error(&self) -> Option<Result<&O, &OrchestratorError<E>>> {
232        self.output_or_error.as_ref().map(Result::as_ref)
233    }
234
235    /// Returns the mutable reference to the deserialized output or error.
236    ///
237    /// Note: This method is intended for internal use only.
238    pub fn output_or_error_mut(&mut self) -> Option<&mut Result<O, OrchestratorError<E>>> {
239        self.output_or_error.as_mut()
240    }
241
242    /// Grants ownership of the deserialized output/error.
243    ///
244    /// Note: This method is intended for internal use only.
245    pub fn take_output_or_error(&mut self) -> Option<Result<O, OrchestratorError<E>>> {
246        self.output_or_error.take()
247    }
248
249    /// Return `true` if this context's `output_or_error` is an error. Otherwise, return `false`.
250    ///
251    /// Note: This method is intended for internal use only.
252    pub fn is_failed(&self) -> bool {
253        self.output_or_error
254            .as_ref()
255            .map(Result::is_err)
256            .unwrap_or_default()
257    }
258
259    /// Advance to the Serialization phase.
260    ///
261    /// Note: This method is intended for internal use only.
262    pub fn enter_serialization_phase(&mut self) {
263        debug!("entering \'serialization\' phase");
264        debug_assert!(
265            self.phase.is_before_serialization(),
266            "called enter_serialization_phase but phase is not before 'serialization'"
267        );
268        self.phase = Phase::Serialization;
269    }
270
271    /// Advance to the BeforeTransmit phase.
272    ///
273    /// Note: This method is intended for internal use only.
274    pub fn enter_before_transmit_phase(&mut self) {
275        debug!("entering \'before transmit\' phase");
276        debug_assert!(
277            self.phase.is_serialization(),
278            "called enter_before_transmit_phase but phase is not 'serialization'"
279        );
280        debug_assert!(
281            self.input.is_none(),
282            "input must be taken before calling enter_before_transmit_phase"
283        );
284        debug_assert!(
285            self.request.is_some(),
286            "request must be set before calling enter_before_transmit_phase"
287        );
288        self.request_checkpoint = self.request().expect("checked above").try_clone();
289        self.phase = Phase::BeforeTransmit;
290    }
291
292    /// Advance to the Transmit phase.
293    ///
294    /// Note: This method is intended for internal use only.
295    pub fn enter_transmit_phase(&mut self) {
296        debug!("entering \'transmit\' phase");
297        debug_assert!(
298            self.phase.is_before_transmit(),
299            "called enter_transmit_phase but phase is not before transmit"
300        );
301        self.phase = Phase::Transmit;
302    }
303
304    /// Advance to the BeforeDeserialization phase.
305    ///
306    /// Note: This method is intended for internal use only.
307    pub fn enter_before_deserialization_phase(&mut self) {
308        debug!("entering \'before deserialization\' phase");
309        debug_assert!(
310            self.phase.is_transmit(),
311            "called enter_before_deserialization_phase but phase is not 'transmit'"
312        );
313        debug_assert!(
314            self.request.is_none(),
315            "request must be taken before entering the 'before deserialization' phase"
316        );
317        debug_assert!(
318            self.response.is_some(),
319            "response must be set to before entering the 'before deserialization' phase"
320        );
321        self.phase = Phase::BeforeDeserialization;
322    }
323
324    /// Advance to the Deserialization phase.
325    ///
326    /// Note: This method is intended for internal use only.
327    pub fn enter_deserialization_phase(&mut self) {
328        debug!("entering \'deserialization\' phase");
329        debug_assert!(
330            self.phase.is_before_deserialization(),
331            "called enter_deserialization_phase but phase is not 'before deserialization'"
332        );
333        self.phase = Phase::Deserialization;
334    }
335
336    /// Advance to the AfterDeserialization phase.
337    ///
338    /// Note: This method is intended for internal use only.
339    pub fn enter_after_deserialization_phase(&mut self) {
340        debug!("entering \'after deserialization\' phase");
341        debug_assert!(
342            self.phase.is_deserialization(),
343            "called enter_after_deserialization_phase but phase is not 'deserialization'"
344        );
345        debug_assert!(
346            self.output_or_error.is_some(),
347            "output must be set to before entering the 'after deserialization' phase"
348        );
349        self.phase = Phase::AfterDeserialization;
350    }
351
352    /// Set the request checkpoint. This should only be called once, right before entering the retry loop.
353    ///
354    /// Note: This method is intended for internal use only.
355    pub fn save_checkpoint(&mut self) {
356        trace!("saving request checkpoint...");
357        self.request_checkpoint = self.request().and_then(|r| r.try_clone());
358        match self.request_checkpoint.as_ref() {
359            Some(_) => trace!("successfully saved request checkpoint"),
360            None => trace!("failed to save request checkpoint: request body could not be cloned"),
361        }
362    }
363
364    /// Returns false if rewinding isn't possible
365    ///
366    /// Note: This method is intended for internal use only.
367    pub fn rewind(&mut self, _cfg: &mut ConfigBag) -> RewindResult {
368        // If request_checkpoint was never set, but we've already made one attempt,
369        // then this is not a retryable request
370        let request_checkpoint = match (self.request_checkpoint.as_ref(), self.tainted) {
371            (None, true) => return RewindResult::Impossible,
372            (_, false) => {
373                self.tainted = true;
374                return RewindResult::Unnecessary;
375            }
376            (Some(req), _) => req.try_clone(),
377        };
378
379        // Otherwise, rewind to the saved request checkpoint
380        self.phase = Phase::BeforeTransmit;
381        self.request = request_checkpoint;
382        assert!(
383            self.request.is_some(),
384            "if the request wasn't cloneable, then we should have already returned from this method."
385        );
386        self.response = None;
387        self.output_or_error = None;
388        RewindResult::Occurred
389    }
390}
391
392impl<I, O, E> InterceptorContext<I, O, E>
393where
394    E: Debug,
395{
396    /// Decomposes the context into its constituent parts.
397    ///
398    /// Note: This method is intended for internal use only.
399    #[allow(clippy::type_complexity)]
400    pub fn into_parts(
401        self,
402    ) -> (
403        Option<I>,
404        Option<Result<O, OrchestratorError<E>>>,
405        Option<Request>,
406        Option<Response>,
407    ) {
408        (
409            self.input,
410            self.output_or_error,
411            self.request,
412            self.response,
413        )
414    }
415
416    /// Convert this context into the final operation result that is returned in client's the public API.
417    ///
418    /// Note: This method is intended for internal use only.
419    pub fn finalize(mut self) -> Result<O, SdkError<E, HttpResponse>> {
420        let output_or_error = self
421            .output_or_error
422            .take()
423            .expect("output_or_error must always be set before finalize is called.");
424        self.finalize_result(output_or_error)
425    }
426
427    /// Convert the given output/error into a final operation result that is returned in the client's public API.
428    ///
429    /// Note: This method is intended for internal use only.
430    pub fn finalize_result(
431        &mut self,
432        result: Result<O, OrchestratorError<E>>,
433    ) -> Result<O, SdkError<E, HttpResponse>> {
434        let response = self.response.take();
435        result.map_err(|error| OrchestratorError::into_sdk_error(error, &self.phase, response))
436    }
437
438    /// Mark this context as failed due to errors during the operation. Any errors already contained
439    /// by the context will be replaced by the given error.
440    ///
441    /// Note: This method is intended for internal use only.
442    pub fn fail(&mut self, error: OrchestratorError<E>) {
443        if !self.is_failed() {
444            trace!(
445                "orchestrator is transitioning to the 'failure' phase from the '{:?}' phase",
446                self.phase
447            );
448        }
449        if let Some(Err(existing_err)) = mem::replace(&mut self.output_or_error, Some(Err(error))) {
450            error!("orchestrator context received an error but one was already present; Throwing away previous error: {:?}", existing_err);
451        }
452    }
453}
454
/// The outcome of an attempt to rewind a request.
///
/// Note: This is intended for internal use only.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum RewindResult {
    /// The request couldn't be rewound because it wasn't cloneable.
    Impossible,
    /// The request wasn't rewound because it was unnecessary.
    Unnecessary,
    /// The request was rewound successfully.
    Occurred,
}

impl fmt::Display for RewindResult {
    /// Writes a one-sentence, human-readable description of the outcome.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Impossible => "The request couldn't be rewound because it wasn't cloneable.",
            Self::Unnecessary => "The request wasn't rewound because it was unnecessary.",
            Self::Occurred => "The request was rewound successfully.",
        };
        f.write_str(description)
    }
}
483
// Unit tests for the phase transitions and the rewind behavior of `InterceptorContext`.
// Only compiled for test builds with both the `test-util` and `http-02x` features enabled.
#[cfg(all(test, feature = "test-util", feature = "http-02x"))]
mod tests {
    use super::*;
    use aws_smithy_types::body::SdkBody;
    use http_02x::header::{AUTHORIZATION, CONTENT_LENGTH};
    use http_02x::{HeaderValue, Uri};

    /// Walks a context through every phase in order on the success path, calling the
    /// accessors along the way, and checks the output that comes out at the end.
    #[test]
    fn test_success_transitions() {
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());

        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());
        context.input_mut();

        // The input must be taken and a request set before entering "before transmit".
        context.enter_serialization_phase();
        let _ = context.take_input();
        context.set_request(HttpRequest::empty());

        context.enter_before_transmit_phase();
        context.request();
        context.request_mut();

        // The request must be taken and a response set before entering "before deserialization".
        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(
            http_02x::Response::builder()
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );

        context.enter_before_deserialization_phase();
        context.response();
        context.response_mut();

        // The output must be set before entering "after deserialization".
        context.enter_deserialization_phase();
        context.response();
        context.response_mut();
        context.set_output_or_error(Ok(output));

        context.enter_after_deserialization_phase();
        context.response();
        context.response_mut();
        let _ = context.output_or_error();
        let _ = context.output_or_error_mut();

        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    /// Checks that `rewind` restores the checkpointed request (discarding later mutations)
    /// after a failed attempt, and that the second attempt can then run to completion.
    #[test]
    fn test_rewind_for_retry() {
        let mut cfg = ConfigBag::base();
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());
        let error = Error::doesnt_matter();

        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());

        context.enter_serialization_phase();
        let _ = context.take_input();
        context.set_request(
            http_02x::Request::builder()
                .header("test", "the-original-un-mutated-request")
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );
        context.enter_before_transmit_phase();
        context.save_checkpoint();
        // The first rewind happens before any attempt, so there is nothing to rewind yet.
        assert_eq!(context.rewind(&mut cfg), RewindResult::Unnecessary);
        // Modify the test header post-checkpoint to simulate modifying the request for signing or a mutating interceptor
        context.request_mut().unwrap().headers_mut().remove("test");
        context.request_mut().unwrap().headers_mut().insert(
            "test",
            HeaderValue::from_static("request-modified-after-signing"),
        );

        context.enter_transmit_phase();
        let request = context.take_request().unwrap();
        assert_eq!(
            "request-modified-after-signing",
            request.headers().get("test").unwrap()
        );
        context.set_response(
            http_02x::Response::builder()
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );

        // First attempt fails during deserialization.
        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Err(OrchestratorError::operation(error)));

        assert_eq!(context.rewind(&mut cfg), RewindResult::Occurred);

        // Now after rewinding, the test header should be its original value
        assert_eq!(
            "the-original-un-mutated-request",
            context.request().unwrap().headers().get("test").unwrap()
        );

        // Second attempt: the rewind put the context back into "before transmit".
        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(
            http_02x::Response::builder()
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );

        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Ok(output));

        context.enter_after_deserialization_phase();

        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    /// Checks that `HttpRequest::try_clone` carries over the URI, method, headers, and body.
    #[test]
    fn try_clone_clones_all_data() {
        let request: HttpRequest = http_02x::Request::builder()
            .uri(Uri::from_static("https://www.amazon.com"))
            .method("POST")
            .header(CONTENT_LENGTH, 456)
            .header(AUTHORIZATION, "Token: hello")
            .body(SdkBody::from("hello world!"))
            .expect("valid request")
            .try_into()
            .unwrap();
        let cloned = request.try_clone().expect("request is cloneable");

        assert_eq!(&Uri::from_static("https://www.amazon.com"), cloned.uri());
        assert_eq!("POST", cloned.method());
        assert_eq!(2, cloned.headers().len());
        assert_eq!("Token: hello", cloned.headers().get(AUTHORIZATION).unwrap(),);
        assert_eq!("456", cloned.headers().get(CONTENT_LENGTH).unwrap());
        assert_eq!("hello world!".as_bytes(), cloned.body().bytes().unwrap());
    }
}