aws_runtime/retries/
classifiers.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
7use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
8use aws_smithy_runtime_api::client::retries::classifiers::{
9    ClassifyRetry, RetryAction, RetryClassifierPriority, RetryReason,
10};
11use aws_smithy_types::error::metadata::ProvideErrorMetadata;
12use aws_smithy_types::retry::ErrorKind;
13use std::borrow::Cow;
14use std::error::Error as StdError;
15use std::marker::PhantomData;
16
/// Error codes that AWS services return when a request has been throttled.
///
/// This is the default throttling list of [`AwsErrorCodeClassifier`]: an operation error whose
/// code appears here is classified as a retryable throttling error.
pub const THROTTLING_ERRORS: &[&str] = &[
    "Throttling",
    "ThrottlingException",
    "ThrottledException",
    "RequestThrottledException",
    "TooManyRequestsException",
    "ProvisionedThroughputExceededException",
    "TransactionInProgressException",
    "RequestLimitExceeded",
    "BandwidthLimitExceeded",
    "LimitExceededException",
    "RequestThrottled",
    "SlowDown",
    "PriorRequestNotComplete",
    "EC2ThrottledException",
];
34
/// Error codes that AWS services return for transient failures.
///
/// This is the default transient list of [`AwsErrorCodeClassifier`]: an operation error whose
/// code appears here is classified as a retryable transient error.
pub const TRANSIENT_ERRORS: &[&str] = &[
    "RequestTimeout",
    "RequestTimeoutException",
];
37
/// Retry classifier that inspects the error code of a failed AWS operation.
///
/// `E` is the concrete operation error type the classifier downcasts to in order to read the
/// error code; no value of `E` is stored.
#[derive(Debug)]
pub struct AwsErrorCodeClassifier<E> {
    /// Error codes treated as throttling errors.
    throttling_errors: Cow<'static, [&'static str]>,
    /// Error codes treated as transient errors.
    transient_errors: Cow<'static, [&'static str]>,
    /// Ties the classifier to the error type `E` without holding one.
    _inner: PhantomData<E>,
}
45
46impl<E> Default for AwsErrorCodeClassifier<E> {
47    fn default() -> Self {
48        Self {
49            throttling_errors: THROTTLING_ERRORS.into(),
50            transient_errors: TRANSIENT_ERRORS.into(),
51            _inner: PhantomData,
52        }
53    }
54}
55
/// Builder for [`AwsErrorCodeClassifier`]
///
/// Each list is optional; a list left unset is replaced by the corresponding built-in list when
/// the classifier is built.
#[derive(Debug)]
pub struct AwsErrorCodeClassifierBuilder<E> {
    /// Override for the throttling error codes, if any.
    throttling_errors: Option<Cow<'static, [&'static str]>>,
    /// Override for the transient error codes, if any.
    transient_errors: Option<Cow<'static, [&'static str]>>,
    /// Carries the error type `E` through to the built classifier.
    _inner: PhantomData<E>,
}
63
64impl<E> AwsErrorCodeClassifierBuilder<E> {
65    /// Set `transient_errors` for the builder
66    pub fn transient_errors(
67        mut self,
68        transient_errors: impl Into<Cow<'static, [&'static str]>>,
69    ) -> Self {
70        self.transient_errors = Some(transient_errors.into());
71        self
72    }
73
74    /// Build a new [`AwsErrorCodeClassifier`]
75    pub fn build(self) -> AwsErrorCodeClassifier<E> {
76        AwsErrorCodeClassifier {
77            throttling_errors: self.throttling_errors.unwrap_or(THROTTLING_ERRORS.into()),
78            transient_errors: self.transient_errors.unwrap_or(TRANSIENT_ERRORS.into()),
79            _inner: self._inner,
80        }
81    }
82}
83
84impl<E> AwsErrorCodeClassifier<E> {
85    /// Create a new [`AwsErrorCodeClassifier`]
86    pub fn new() -> Self {
87        Self::default()
88    }
89
90    /// Return a builder that can create a new [`AwsErrorCodeClassifier`]
91    pub fn builder() -> AwsErrorCodeClassifierBuilder<E> {
92        AwsErrorCodeClassifierBuilder {
93            throttling_errors: None,
94            transient_errors: None,
95            _inner: PhantomData,
96        }
97    }
98}
99
100impl<E> ClassifyRetry for AwsErrorCodeClassifier<E>
101where
102    E: StdError + ProvideErrorMetadata + Send + Sync + 'static,
103{
104    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
105        // Check for a result
106        let output_or_error = ctx.output_or_error();
107        // Check for an error
108        let error = match output_or_error {
109            Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
110            Some(Err(err)) => err,
111        };
112
113        let retry_after = ctx
114            .response()
115            .and_then(|res| res.headers().get("x-amz-retry-after"))
116            .and_then(|header| header.parse::<u64>().ok())
117            .map(std::time::Duration::from_millis);
118
119        let error_code = OrchestratorError::as_operation_error(error)
120            .and_then(|err| err.downcast_ref::<E>())
121            .and_then(|err| err.code());
122
123        if let Some(error_code) = error_code {
124            if self.throttling_errors.contains(&error_code) {
125                return RetryAction::RetryIndicated(RetryReason::RetryableError {
126                    kind: ErrorKind::ThrottlingError,
127                    retry_after,
128                });
129            }
130            if self.transient_errors.contains(&error_code) {
131                return RetryAction::RetryIndicated(RetryReason::RetryableError {
132                    kind: ErrorKind::TransientError,
133                    retry_after,
134                });
135            }
136        };
137
138        debug_assert!(
139            retry_after.is_none(),
140            "retry_after should be None if the error wasn't an identifiable AWS error"
141        );
142
143        RetryAction::NoActionIndicated
144    }
145
146    fn name(&self) -> &'static str {
147        "AWS Error Code"
148    }
149
150    fn priority(&self) -> RetryClassifierPriority {
151        RetryClassifierPriority::run_before(
152            RetryClassifierPriority::modeled_as_retryable_classifier(),
153        )
154    }
155}
156
#[cfg(test)]
mod test {
    use crate::retries::classifiers::AwsErrorCodeClassifier;
    use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
    use aws_smithy_runtime_api::client::interceptors::context::{Error, Input};
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::error::metadata::ProvideErrorMetadata;
    use aws_smithy_types::error::ErrorMetadata;
    use aws_smithy_types::retry::ErrorKind;
    use std::fmt;
    use std::time::Duration;

    /// Minimal error type whose only content is metadata carrying an error code.
    #[derive(Debug)]
    struct CodedError {
        metadata: ErrorMetadata,
    }

    impl CodedError {
        /// Creates an error whose metadata reports `code`.
        fn new(code: &'static str) -> Self {
            let metadata = ErrorMetadata::builder().code(code).build();
            Self { metadata }
        }
    }

    impl fmt::Display for CodedError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("Coded Error")
        }
    }

    impl std::error::Error for CodedError {}

    impl ProvideErrorMetadata for CodedError {
        fn meta(&self) -> &ErrorMetadata {
            &self.metadata
        }
    }

    /// Builds a context whose result is the given (already type-erased) operation error.
    fn failed_context(operation_error: Error) -> InterceptorContext {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::operation(operation_error)));
        ctx
    }

    #[test]
    fn classify_by_error_code() {
        let classifier = AwsErrorCodeClassifier::<CodedError>::new();

        let throttled = failed_context(Error::erase(CodedError::new("Throttling")));
        assert_eq!(
            classifier.classify_retry(&throttled),
            RetryAction::throttling_error()
        );

        let timed_out = failed_context(Error::erase(CodedError::new("RequestTimeout")));
        assert_eq!(
            classifier.classify_retry(&timed_out),
            RetryAction::transient_error()
        );
    }

    #[test]
    fn classify_generic() {
        let classifier = AwsErrorCodeClassifier::<ErrorMetadata>::new();
        let slow_down = ErrorMetadata::builder().code("SlowDown").build();
        let ok_response = http_02x::Response::new("OK").map(SdkBody::from);

        let mut ctx = failed_context(Error::erase(slow_down));
        ctx.set_response(ok_response.try_into().unwrap());

        assert_eq!(
            classifier.classify_retry(&ctx),
            RetryAction::throttling_error()
        );
    }

    #[test]
    fn test_retry_after_header() {
        let classifier = AwsErrorCodeClassifier::<ErrorMetadata>::new();
        let slow_down = ErrorMetadata::builder().code("SlowDown").build();
        let delayed_response = http_02x::Response::builder()
            .header("x-amz-retry-after", "5000")
            .body("retry later")
            .unwrap()
            .map(SdkBody::from);

        let mut ctx = failed_context(Error::erase(slow_down));
        ctx.set_response(delayed_response.try_into().unwrap());

        // The header value is in milliseconds, so "5000" is a five second delay.
        let expected = RetryAction::retryable_error_with_explicit_delay(
            ErrorKind::ThrottlingError,
            Duration::from_secs(5),
        );
        assert_eq!(classifier.classify_retry(&ctx), expected);
    }
}