// aws_smithy_runtime/client/stalled_stream_protection.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

6use crate::client::http::body::minimum_throughput::{
7    options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody,
8    UploadThroughput,
9};
10use aws_smithy_async::rt::sleep::SharedAsyncSleep;
11use aws_smithy_async::time::SharedTimeSource;
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::{
14    BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextMut,
15};
16use aws_smithy_runtime_api::client::interceptors::Intercept;
17use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
18use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::ConfigBag;
21use std::mem;
22
/// Interceptor that adds stalled stream protection to outgoing request bodies and/or
/// incoming response bodies.
///
/// Which directions are protected is decided per request by the
/// `StalledStreamProtectionConfig` found in the config bag, not by this type itself.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct StalledStreamProtectionInterceptor;
27
/// Stalled stream protection can be enabled for request bodies, response bodies,
/// or both.
///
/// This selector is no longer consulted: `StalledStreamProtectionInterceptor::new` ignores
/// it, and the interceptor instead reads `StalledStreamProtectionConfig` from the config bag.
#[deprecated(
    since = "1.2.0",
    note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag."
)]
pub enum StalledStreamProtectionInterceptorKind {
    /// Enable stalled stream protection for request bodies.
    RequestBody,
    /// Enable stalled stream protection for response bodies.
    ResponseBody,
    /// Enable stalled stream protection for both request and response bodies.
    RequestAndResponseBody,
}
42
43impl StalledStreamProtectionInterceptor {
44    /// Create a new stalled stream protection interceptor.
45    #[deprecated(
46        since = "1.2.0",
47        note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default."
48    )]
49    #[allow(deprecated)]
50    pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self {
51        Default::default()
52    }
53}
54
55impl Intercept for StalledStreamProtectionInterceptor {
56    fn name(&self) -> &'static str {
57        "StalledStreamProtectionInterceptor"
58    }
59
60    fn modify_before_transmit(
61        &self,
62        context: &mut BeforeTransmitInterceptorContextMut<'_>,
63        runtime_components: &RuntimeComponents,
64        cfg: &mut ConfigBag,
65    ) -> Result<(), BoxError> {
66        if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
67            if sspcfg.upload_enabled() {
68                if let Some(0) = context.request().body().content_length() {
69                    tracing::trace!(
70                        "skipping stalled stream protection for zero length request body"
71                    );
72                    return Ok(());
73                }
74                let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
75                let now = time_source.now();
76
77                let options: MinimumThroughputBodyOptions = sspcfg.into();
78                let throughput = UploadThroughput::new(options.check_window(), now);
79                cfg.interceptor_state().store_put(throughput.clone());
80
81                tracing::trace!("adding stalled stream protection to request body");
82                let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken());
83                let it = it.map_preserve_contents(move |body| {
84                    let time_source = time_source.clone();
85                    SdkBody::from_body_0_4(ThroughputReadingBody::new(
86                        time_source,
87                        throughput.clone(),
88                        body,
89                    ))
90                });
91                let _ = mem::replace(context.request_mut().body_mut(), it);
92            }
93        }
94
95        Ok(())
96    }
97
98    fn modify_before_deserialization(
99        &self,
100        context: &mut BeforeDeserializationInterceptorContextMut<'_>,
101        runtime_components: &RuntimeComponents,
102        cfg: &mut ConfigBag,
103    ) -> Result<(), BoxError> {
104        if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>() {
105            if sspcfg.download_enabled() {
106                let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
107                tracing::trace!("adding stalled stream protection to response body");
108                let sspcfg = sspcfg.clone();
109                let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken());
110                let it = it.map_preserve_contents(move |body| {
111                    let sspcfg = sspcfg.clone();
112                    let async_sleep = async_sleep.clone();
113                    let time_source = time_source.clone();
114                    let mtb = MinimumThroughputDownloadBody::new(
115                        time_source,
116                        async_sleep,
117                        body,
118                        sspcfg.into(),
119                    );
120                    SdkBody::from_body_0_4(mtb)
121                });
122                let _ = mem::replace(context.response_mut().body_mut(), it);
123            }
124        }
125        Ok(())
126    }
127}
128
129fn get_runtime_component_deps(
130    runtime_components: &RuntimeComponents,
131) -> Result<(SharedAsyncSleep, SharedTimeSource), BoxError> {
132    let async_sleep = runtime_components.sleep_impl().ok_or(
133        "An async sleep implementation is required when stalled stream protection is enabled",
134    )?;
135    let time_source = runtime_components
136        .time_source()
137        .ok_or("A time source is required when stalled stream protection is enabled")?;
138    Ok((async_sleep, time_source))
139}