// aws_runtime/retries/classifiers.rs
1use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
7use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
8use aws_smithy_runtime_api::client::retries::classifiers::{
9 ClassifyRetry, RetryAction, RetryClassifierPriority, RetryReason,
10};
11use aws_smithy_types::error::metadata::ProvideErrorMetadata;
12use aws_smithy_types::retry::ErrorKind;
13use std::borrow::Cow;
14use std::error::Error as StdError;
15use std::marker::PhantomData;
16
/// AWS error codes that the default [`AwsErrorCodeClassifier`] classifies as throttling errors.
pub const THROTTLING_ERRORS: &[&str] = &[
    "Throttling",
    "ThrottlingException",
    "ThrottledException",
    "RequestThrottledException",
    "TooManyRequestsException",
    "ProvisionedThroughputExceededException",
    "TransactionInProgressException",
    "RequestLimitExceeded",
    "BandwidthLimitExceeded",
    "LimitExceededException",
    "RequestThrottled",
    "SlowDown",
    "PriorRequestNotComplete",
    "EC2ThrottledException",
];
34
/// AWS error codes that the default [`AwsErrorCodeClassifier`] classifies as transient errors.
pub const TRANSIENT_ERRORS: &[&str] = &[
    "RequestTimeout",
    "RequestTimeoutException",
];
37
/// Retry classifier that looks up the AWS error code of an operation error of type `E`
/// in configurable lists of throttling and transient error codes.
#[derive(Debug)]
pub struct AwsErrorCodeClassifier<E> {
    /// Error codes classified as [`ErrorKind::ThrottlingError`].
    throttling_errors: Cow<'static, [&'static str]>,
    /// Error codes classified as [`ErrorKind::TransientError`].
    transient_errors: Cow<'static, [&'static str]>,
    /// Ties the classifier to the error type `E` it downcasts to; no `E` value is stored.
    _inner: PhantomData<E>,
}
45
46impl<E> Default for AwsErrorCodeClassifier<E> {
47 fn default() -> Self {
48 Self {
49 throttling_errors: THROTTLING_ERRORS.into(),
50 transient_errors: TRANSIENT_ERRORS.into(),
51 _inner: PhantomData,
52 }
53 }
54}
55
/// Builder for [`AwsErrorCodeClassifier`]; any error-code list left as `None` is replaced by
/// its default when the classifier is built.
#[derive(Debug)]
pub struct AwsErrorCodeClassifierBuilder<E> {
    /// Overrides the throttling error codes when `Some`.
    throttling_errors: Option<Cow<'static, [&'static str]>>,
    /// Overrides the transient error codes when `Some`.
    transient_errors: Option<Cow<'static, [&'static str]>>,
    /// Carries the error type `E` through to the built classifier.
    _inner: PhantomData<E>,
}
63
64impl<E> AwsErrorCodeClassifierBuilder<E> {
65 pub fn transient_errors(
67 mut self,
68 transient_errors: impl Into<Cow<'static, [&'static str]>>,
69 ) -> Self {
70 self.transient_errors = Some(transient_errors.into());
71 self
72 }
73
74 pub fn build(self) -> AwsErrorCodeClassifier<E> {
76 AwsErrorCodeClassifier {
77 throttling_errors: self.throttling_errors.unwrap_or(THROTTLING_ERRORS.into()),
78 transient_errors: self.transient_errors.unwrap_or(TRANSIENT_ERRORS.into()),
79 _inner: self._inner,
80 }
81 }
82}
83
84impl<E> AwsErrorCodeClassifier<E> {
85 pub fn new() -> Self {
87 Self::default()
88 }
89
90 pub fn builder() -> AwsErrorCodeClassifierBuilder<E> {
92 AwsErrorCodeClassifierBuilder {
93 throttling_errors: None,
94 transient_errors: None,
95 _inner: PhantomData,
96 }
97 }
98}
99
100impl<E> ClassifyRetry for AwsErrorCodeClassifier<E>
101where
102 E: StdError + ProvideErrorMetadata + Send + Sync + 'static,
103{
104 fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
105 let output_or_error = ctx.output_or_error();
107 let error = match output_or_error {
109 Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
110 Some(Err(err)) => err,
111 };
112
113 let retry_after = ctx
114 .response()
115 .and_then(|res| res.headers().get("x-amz-retry-after"))
116 .and_then(|header| header.parse::<u64>().ok())
117 .map(std::time::Duration::from_millis);
118
119 let error_code = OrchestratorError::as_operation_error(error)
120 .and_then(|err| err.downcast_ref::<E>())
121 .and_then(|err| err.code());
122
123 if let Some(error_code) = error_code {
124 if self.throttling_errors.contains(&error_code) {
125 return RetryAction::RetryIndicated(RetryReason::RetryableError {
126 kind: ErrorKind::ThrottlingError,
127 retry_after,
128 });
129 }
130 if self.transient_errors.contains(&error_code) {
131 return RetryAction::RetryIndicated(RetryReason::RetryableError {
132 kind: ErrorKind::TransientError,
133 retry_after,
134 });
135 }
136 };
137
138 debug_assert!(
139 retry_after.is_none(),
140 "retry_after should be None if the error wasn't an identifiable AWS error"
141 );
142
143 RetryAction::NoActionIndicated
144 }
145
146 fn name(&self) -> &'static str {
147 "AWS Error Code"
148 }
149
150 fn priority(&self) -> RetryClassifierPriority {
151 RetryClassifierPriority::run_before(
152 RetryClassifierPriority::modeled_as_retryable_classifier(),
153 )
154 }
155}
156
#[cfg(test)]
mod test {
    use crate::retries::classifiers::AwsErrorCodeClassifier;
    use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
    use aws_smithy_runtime_api::client::interceptors::context::{Error, Input};
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::error::metadata::ProvideErrorMetadata;
    use aws_smithy_types::error::ErrorMetadata;
    use aws_smithy_types::retry::ErrorKind;
    use std::fmt;
    use std::time::Duration;

    /// Test-only modeled error that exposes nothing but an error code through its metadata.
    #[derive(Debug)]
    struct CodedError {
        metadata: ErrorMetadata,
    }

    impl CodedError {
        fn new(code: &'static str) -> Self {
            let metadata = ErrorMetadata::builder().code(code).build();
            Self { metadata }
        }
    }

    impl fmt::Display for CodedError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("Coded Error")
        }
    }

    impl std::error::Error for CodedError {}

    impl ProvideErrorMetadata for CodedError {
        fn meta(&self) -> &ErrorMetadata {
            &self.metadata
        }
    }

    /// Records `error` as the failed operation result on `ctx`.
    fn fail_with(ctx: &mut InterceptorContext, error: Error) {
        ctx.set_output_or_error(Err(OrchestratorError::operation(error)));
    }

    #[test]
    fn classify_by_error_code() {
        let classifier = AwsErrorCodeClassifier::<CodedError>::new();
        let cases = [
            ("Throttling", RetryAction::throttling_error()),
            ("RequestTimeout", RetryAction::transient_error()),
        ];

        for (code, expected) in cases {
            let mut ctx = InterceptorContext::new(Input::doesnt_matter());
            fail_with(&mut ctx, Error::erase(CodedError::new(code)));
            assert_eq!(classifier.classify_retry(&ctx), expected);
        }
    }

    #[test]
    fn classify_generic() {
        let classifier = AwsErrorCodeClassifier::<ErrorMetadata>::new();
        let error = ErrorMetadata::builder().code("SlowDown").build();
        let response = http_02x::Response::new("OK").map(SdkBody::from);

        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_response(response.try_into().unwrap());
        fail_with(&mut ctx, Error::erase(error));

        assert_eq!(classifier.classify_retry(&ctx), RetryAction::throttling_error());
    }

    #[test]
    fn test_retry_after_header() {
        let classifier = AwsErrorCodeClassifier::<ErrorMetadata>::new();
        let error = ErrorMetadata::builder().code("SlowDown").build();
        let response = http_02x::Response::builder()
            .header("x-amz-retry-after", "5000")
            .body("retry later")
            .unwrap()
            .map(SdkBody::from);

        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_response(response.try_into().unwrap());
        fail_with(&mut ctx, Error::erase(error));

        let expected = RetryAction::retryable_error_with_explicit_delay(
            ErrorKind::ThrottlingError,
            Duration::from_secs(5),
        );
        assert_eq!(classifier.classify_retry(&ctx), expected);
    }
}