// aws_smithy_runtime/client/http/connection_poisoning.rs
1use crate::client::retries::classifiers::run_classifiers_on_ctx;
7use aws_smithy_runtime_api::box_error::BoxError;
8use aws_smithy_runtime_api::client::connection::ConnectionMetadata;
9use aws_smithy_runtime_api::client::interceptors::context::{
10 AfterDeserializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
11};
12use aws_smithy_runtime_api::client::interceptors::Intercept;
13use aws_smithy_runtime_api::client::retries::classifiers::RetryAction;
14use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
15use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
16use aws_smithy_types::retry::{ReconnectMode, RetryConfig};
17use std::fmt;
18use std::sync::{Arc, Mutex};
19use tracing::{debug, error};
20
/// An interceptor that marks ("poisons") the HTTP connection used for a request for
/// closure when that request fails with a transient error.
///
/// Before transmit it attaches a [`CaptureSmithyConnection`] to the request so that the
/// HTTP connector can register a connection retriever on it. After deserialization it
/// runs the configured retry classifiers. If they report a transient error and the
/// [`ReconnectMode`] is [`ReconnectMode::ReconnectOnTransientError`], it poisons the
/// captured connection.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct ConnectionPoisoningInterceptor {}

impl ConnectionPoisoningInterceptor {
    /// Creates a new `ConnectionPoisoningInterceptor`.
    pub fn new() -> Self {
        Self {}
    }
}
45
46impl Intercept for ConnectionPoisoningInterceptor {
47 fn name(&self) -> &'static str {
48 "ConnectionPoisoningInterceptor"
49 }
50
51 fn modify_before_transmit(
52 &self,
53 context: &mut BeforeTransmitInterceptorContextMut<'_>,
54 _runtime_components: &RuntimeComponents,
55 cfg: &mut ConfigBag,
56 ) -> Result<(), BoxError> {
57 let capture_smithy_connection = CaptureSmithyConnection::new();
58 context
59 .request_mut()
60 .add_extension(capture_smithy_connection.clone());
61 cfg.interceptor_state().store_put(capture_smithy_connection);
62
63 Ok(())
64 }
65
66 fn read_after_deserialization(
67 &self,
68 context: &AfterDeserializationInterceptorContextRef<'_>,
69 runtime_components: &RuntimeComponents,
70 cfg: &mut ConfigBag,
71 ) -> Result<(), BoxError> {
72 let reconnect_mode = cfg
73 .load::<RetryConfig>()
74 .map(RetryConfig::reconnect_mode)
75 .unwrap_or(ReconnectMode::ReconnectOnTransientError);
76 let captured_connection = cfg.load::<CaptureSmithyConnection>().cloned();
77 let retry_classifier_result =
78 run_classifiers_on_ctx(runtime_components.retry_classifiers(), context.inner());
79 let error_is_transient = retry_classifier_result == RetryAction::transient_error();
80 let connection_poisoning_is_enabled =
81 reconnect_mode == ReconnectMode::ReconnectOnTransientError;
82
83 if error_is_transient && connection_poisoning_is_enabled {
84 debug!("received a transient error, marking the connection for closure...");
85
86 if let Some(captured_connection) = captured_connection.and_then(|conn| conn.get()) {
87 captured_connection.poison();
88 debug!("the connection was marked for closure")
89 } else {
90 error!(
91 "unable to mark the connection for closure because no connection was found! The underlying HTTP connector never set a connection."
92 );
93 }
94 }
95
96 Ok(())
97 }
98}
99
100type LoaderFn = dyn Fn() -> Option<ConnectionMetadata> + Send + Sync;
101
102#[derive(Clone, Default)]
104pub struct CaptureSmithyConnection {
105 loader: Arc<Mutex<Option<Box<LoaderFn>>>>,
106}
107
108impl CaptureSmithyConnection {
109 pub fn new() -> Self {
111 Self {
112 loader: Default::default(),
113 }
114 }
115
116 pub fn set_connection_retriever<F>(&self, f: F)
118 where
119 F: Fn() -> Option<ConnectionMetadata> + Send + Sync + 'static,
120 {
121 *self.loader.lock().unwrap() = Some(Box::new(f));
122 }
123
124 pub fn get(&self) -> Option<ConnectionMetadata> {
126 match self.loader.lock().unwrap().as_ref() {
127 Some(loader) => loader(),
128 None => {
129 tracing::debug!("no loader was set on the CaptureSmithyConnection");
130 None
131 }
132 }
133 }
134}
135
136impl fmt::Debug for CaptureSmithyConnection {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 write!(f, "CaptureSmithyConnection")
139 }
140}
141
142impl Storable for CaptureSmithyConnection {
143 type Storer = StoreReplace<Self>;
144}
145
#[cfg(test)]
mod test {
    use super::*;

    /// Registers a retriever on one handle and checks that it is visible both through that
    /// handle and through a clone taken before registration.
    #[test]
    // The clone is intentional: it verifies that clones share the same loader slot.
    #[allow(clippy::redundant_clone)]
    fn retrieve_connection_metadata() {
        let capture = CaptureSmithyConnection::new();
        let shared_handle = capture.clone();

        // Nothing is registered yet, so there is nothing to retrieve.
        assert!(capture.get().is_none());

        capture.set_connection_retriever(|| {
            let metadata = ConnectionMetadata::builder()
                .proxied(true)
                .poison_fn(|| {})
                .build();
            Some(metadata)
        });

        for handle in [&capture, &shared_handle] {
            assert!(handle.get().is_some());
        }
    }
}