aws_smithy_runtime/client/http/connection_poisoning.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
use crate::client::retries::classifiers::run_classifiers_on_ctx;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::connection::ConnectionMetadata;
use aws_smithy_runtime_api::client::interceptors::context::{
    AfterDeserializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::retries::classifiers::{RetryAction, RetryReason};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::retry::{ErrorKind, ReconnectMode, RetryConfig};
use std::fmt;
use std::sync::{Arc, Mutex};
use tracing::{debug, error};
20
/// An interceptor that poisons connections in response to certain events.
///
/// Paired with a compatible connection, this interceptor lets a connection be *poisoned* when
/// something goes wrong, such as receiving a transient error. This helps avoid sending further
/// requests to a server that isn't responding. The trade-off is that more connections are opened
/// overall, which can increase the load on the server.
///
/// **This interceptor only has an effect when** the configured connection interacts with the
/// "connection retriever" stored in each HTTP request's `extensions` map; see
/// [`HyperConnector`] for an example. When a connection is made available to the retriever, this
/// interceptor calls `.poison()` on it to signal that the connection should be dropped. Honoring
/// that signal is the responsibility of the connection implementation.
///
/// [`HyperConnector`]: https://github.com/smithy-lang/smithy-rs/blob/26a914ece072bba2dd9b5b49003204b70e7666ac/rust-runtime/aws-smithy-runtime/src/client/http/hyper_014.rs#L347
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct ConnectionPoisoningInterceptor {}
38
39impl ConnectionPoisoningInterceptor {
40    /// Create a new `ConnectionPoisoningInterceptor`.
41    pub fn new() -> Self {
42        Self::default()
43    }
44}
45
46impl Intercept for ConnectionPoisoningInterceptor {
47    fn name(&self) -> &'static str {
48        "ConnectionPoisoningInterceptor"
49    }
50
51    fn modify_before_transmit(
52        &self,
53        context: &mut BeforeTransmitInterceptorContextMut<'_>,
54        _runtime_components: &RuntimeComponents,
55        cfg: &mut ConfigBag,
56    ) -> Result<(), BoxError> {
57        let capture_smithy_connection = CaptureSmithyConnection::new();
58        context
59            .request_mut()
60            .add_extension(capture_smithy_connection.clone());
61        cfg.interceptor_state().store_put(capture_smithy_connection);
62
63        Ok(())
64    }
65
66    fn read_after_deserialization(
67        &self,
68        context: &AfterDeserializationInterceptorContextRef<'_>,
69        runtime_components: &RuntimeComponents,
70        cfg: &mut ConfigBag,
71    ) -> Result<(), BoxError> {
72        let reconnect_mode = cfg
73            .load::<RetryConfig>()
74            .map(RetryConfig::reconnect_mode)
75            .unwrap_or(ReconnectMode::ReconnectOnTransientError);
76        let captured_connection = cfg.load::<CaptureSmithyConnection>().cloned();
77        let retry_classifier_result =
78            run_classifiers_on_ctx(runtime_components.retry_classifiers(), context.inner());
79        let error_is_transient = retry_classifier_result == RetryAction::transient_error();
80        let connection_poisoning_is_enabled =
81            reconnect_mode == ReconnectMode::ReconnectOnTransientError;
82
83        if error_is_transient && connection_poisoning_is_enabled {
84            debug!("received a transient error, marking the connection for closure...");
85
86            if let Some(captured_connection) = captured_connection.and_then(|conn| conn.get()) {
87                captured_connection.poison();
88                debug!("the connection was marked for closure")
89            } else {
90                error!(
91                    "unable to mark the connection for closure because no connection was found! The underlying HTTP connector never set a connection."
92                );
93            }
94        }
95
96        Ok(())
97    }
98}
99
100type LoaderFn = dyn Fn() -> Option<ConnectionMetadata> + Send + Sync;
101
102/// State for a middleware that will monitor and manage connections.
103#[derive(Clone, Default)]
104pub struct CaptureSmithyConnection {
105    loader: Arc<Mutex<Option<Box<LoaderFn>>>>,
106}
107
108impl CaptureSmithyConnection {
109    /// Create a new connection monitor.
110    pub fn new() -> Self {
111        Self {
112            loader: Default::default(),
113        }
114    }
115
116    /// Set the retriever that will capture the `hyper` connection.
117    pub fn set_connection_retriever<F>(&self, f: F)
118    where
119        F: Fn() -> Option<ConnectionMetadata> + Send + Sync + 'static,
120    {
121        *self.loader.lock().unwrap() = Some(Box::new(f));
122    }
123
124    /// Get the associated connection metadata.
125    pub fn get(&self) -> Option<ConnectionMetadata> {
126        match self.loader.lock().unwrap().as_ref() {
127            Some(loader) => loader(),
128            None => {
129                tracing::debug!("no loader was set on the CaptureSmithyConnection");
130                None
131            }
132        }
133    }
134}
135
136impl fmt::Debug for CaptureSmithyConnection {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        write!(f, "CaptureSmithyConnection")
139    }
140}
141
142impl Storable for CaptureSmithyConnection {
143    type Storer = StoreReplace<Self>;
144}
145
#[cfg(test)]
mod test {
    use super::*;

    /// A retriever installed on a `CaptureSmithyConnection` is visible through the original
    /// handle and through a clone made before the retriever was installed.
    #[test]
    // The clone is intentional: the test checks that clones share the retriever slot.
    #[allow(clippy::redundant_clone)]
    fn retrieve_connection_metadata() {
        let capture = CaptureSmithyConnection::new();
        let shared = capture.clone();
        assert!(capture.get().is_none());

        let make_metadata = || {
            ConnectionMetadata::builder()
                .proxied(true)
                .poison_fn(|| {})
                .build()
        };
        capture.set_connection_retriever(move || Some(make_metadata()));

        assert!(capture.get().is_some());
        assert!(shared.get().is_some());
    }
}