// aws_smithy_runtime/client/stalled_stream_protection.rs
use crate::client::http::body::minimum_throughput::{
    options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody,
    UploadThroughput,
};
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::{
    BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextMut,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::config_bag::ConfigBag;
use std::mem;

/// Interceptor that adds stalled stream protection to request and/or response bodies.
///
/// The interceptor carries no state of its own. For each request it reads the
/// `StalledStreamProtectionConfig` stored in the config bag and wraps the request
/// body when uploads are enabled and the response body when downloads are enabled.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct StalledStreamProtectionInterceptor;

/// Names the bodies stalled stream protection could be requested for: the request
/// body, the response body, or both.
///
/// Kept only for backward compatibility. `StalledStreamProtectionInterceptor::new`
/// accepts a value of this type and ignores it.
#[deprecated(
    since = "1.2.0",
    note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag."
)]
pub enum StalledStreamProtectionInterceptorKind {
    /// Stalled stream protection for request bodies.
    RequestBody,
    /// Stalled stream protection for response bodies.
    ResponseBody,
    /// Stalled stream protection for both request and response bodies.
    RequestAndResponseBody,
}

43impl StalledStreamProtectionInterceptor {
44 #[deprecated(
46 since = "1.2.0",
47 note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default."
48 )]
49 #[allow(deprecated)]
50 pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self {
51 Default::default()
52 }
53}
54
55impl Intercept for StalledStreamProtectionInterceptor {
56 fn name(&self) -> &'static str {
57 "StalledStreamProtectionInterceptor"
58 }
59
60 fn modify_before_transmit(
61 &self,
62 context: &mut BeforeTransmitInterceptorContextMut<'_>,
63 runtime_components: &RuntimeComponents,
64 cfg: &mut ConfigBag,
65 ) -> Result<(), BoxError> {
66 if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
67 if sspcfg.upload_enabled() {
68 if let Some(0) = context.request().body().content_length() {
69 tracing::trace!(
70 "skipping stalled stream protection for zero length request body"
71 );
72 return Ok(());
73 }
74 let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
75 let now = time_source.now();
76
77 let options: MinimumThroughputBodyOptions = sspcfg.into();
78 let throughput = UploadThroughput::new(options.check_window(), now);
79 cfg.interceptor_state().store_put(throughput.clone());
80
81 tracing::trace!("adding stalled stream protection to request body");
82 let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken());
83 let it = it.map_preserve_contents(move |body| {
84 let time_source = time_source.clone();
85 SdkBody::from_body_0_4(ThroughputReadingBody::new(
86 time_source,
87 throughput.clone(),
88 body,
89 ))
90 });
91 let _ = mem::replace(context.request_mut().body_mut(), it);
92 }
93 }
94
95 Ok(())
96 }
97
98 fn modify_before_deserialization(
99 &self,
100 context: &mut BeforeDeserializationInterceptorContextMut<'_>,
101 runtime_components: &RuntimeComponents,
102 cfg: &mut ConfigBag,
103 ) -> Result<(), BoxError> {
104 if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>() {
105 if sspcfg.download_enabled() {
106 let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
107 tracing::trace!("adding stalled stream protection to response body");
108 let sspcfg = sspcfg.clone();
109 let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken());
110 let it = it.map_preserve_contents(move |body| {
111 let sspcfg = sspcfg.clone();
112 let async_sleep = async_sleep.clone();
113 let time_source = time_source.clone();
114 let mtb = MinimumThroughputDownloadBody::new(
115 time_source,
116 async_sleep,
117 body,
118 sspcfg.into(),
119 );
120 SdkBody::from_body_0_4(mtb)
121 });
122 let _ = mem::replace(context.response_mut().body_mut(), it);
123 }
124 }
125 Ok(())
126 }
127}
128
129fn get_runtime_component_deps(
130 runtime_components: &RuntimeComponents,
131) -> Result<(SharedAsyncSleep, SharedTimeSource), BoxError> {
132 let async_sleep = runtime_components.sleep_impl().ok_or(
133 "An async sleep implementation is required when stalled stream protection is enabled",
134 )?;
135 let time_source = runtime_components
136 .time_source()
137 .ok_or("A time source is required when stalled stream protection is enabled")?;
138 Ok((async_sleep, time_source))
139}