aws_smithy_runtime/client/http/body/content_length_enforcement.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! RuntimePlugin to ensure that the amount of data received matches the `Content-Length` header
7
8use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::interceptors::context::{
10    BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextRef,
11};
12use aws_smithy_runtime_api::client::interceptors::Intercept;
13use aws_smithy_runtime_api::client::runtime_components::{
14    RuntimeComponents, RuntimeComponentsBuilder,
15};
16use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
17use aws_smithy_runtime_api::http::Response;
18use aws_smithy_types::body::SdkBody;
19use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
20use bytes::Buf;
21use http_body_1x::{Frame, SizeHint};
22use pin_project_lite::pin_project;
23use std::borrow::Cow;
24use std::error::Error;
25use std::fmt::{Display, Formatter};
26use std::pin::Pin;
27use std::task::{ready, Context, Poll};
28pin_project! {
29    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit it as a trailer.
30    struct ContentLengthEnforcingBody<InnerBody> {
31            #[pin]
32            body: InnerBody,
33            expected_length: u64,
34            bytes_received: u64,
35    }
36}
37
/// Error raised when the number of body bytes actually received differs from the
/// number announced up front.
#[derive(Debug)]
pub struct ContentLengthError {
    /// Number of bytes the body was expected to contain.
    expected: u64,
    /// Number of bytes that were actually received before the body ended.
    received: u64,
}

impl Display for ContentLengthError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { expected, received } = self;
        write!(
            f,
            "Invalid Content-Length: Expected {} bytes but {} bytes were received",
            expected, received
        )
    }
}

impl Error for ContentLengthError {}
56
57impl ContentLengthEnforcingBody<SdkBody> {
58    /// Wraps an existing [`SdkBody`] in a content-length enforcement layer
59    fn wrap(body: SdkBody, content_length: u64) -> SdkBody {
60        body.map_preserve_contents(move |b| {
61            SdkBody::from_body_1_x(ContentLengthEnforcingBody {
62                body: b,
63                expected_length: content_length,
64                bytes_received: 0,
65            })
66        })
67    }
68}
69
70impl<
71        E: Into<aws_smithy_types::body::Error>,
72        Data: Buf,
73        InnerBody: http_body_1x::Body<Error = E, Data = Data>,
74    > http_body_1x::Body for ContentLengthEnforcingBody<InnerBody>
75{
76    type Data = Data;
77    type Error = aws_smithy_types::body::Error;
78
79    fn poll_frame(
80        mut self: Pin<&mut Self>,
81        cx: &mut Context<'_>,
82    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
83        let this = self.as_mut().project();
84        match ready!(this.body.poll_frame(cx)) {
85            None => {
86                if *this.expected_length == *this.bytes_received {
87                    Poll::Ready(None)
88                } else {
89                    Poll::Ready(Some(Err(ContentLengthError {
90                        expected: *this.expected_length,
91                        received: *this.bytes_received,
92                    }
93                    .into())))
94                }
95            }
96            Some(Err(e)) => Poll::Ready(Some(Err(e.into()))),
97            Some(Ok(frame)) => {
98                if let Some(data) = frame.data_ref() {
99                    *this.bytes_received += data.remaining() as u64;
100                }
101                Poll::Ready(Some(Ok(frame)))
102            }
103        }
104    }
105
106    fn is_end_stream(&self) -> bool {
107        self.body.is_end_stream()
108    }
109
110    fn size_hint(&self) -> SizeHint {
111        self.body.size_hint()
112    }
113}
114
/// Interceptor that opts `GET` requests into content-length enforcement and wraps the
/// matching response bodies in a byte-counting layer.
#[derive(Default, Debug)]
struct EnforceContentLengthInterceptor {}
117
118#[derive(Debug)]
119struct EnableContentLengthEnforcement;
120impl Storable for EnableContentLengthEnforcement {
121    type Storer = StoreReplace<EnableContentLengthEnforcement>;
122}
123
124impl Intercept for EnforceContentLengthInterceptor {
125    fn name(&self) -> &'static str {
126        "EnforceContentLength"
127    }
128
129    fn read_before_transmit(
130        &self,
131        context: &BeforeTransmitInterceptorContextRef<'_>,
132        _runtime_components: &RuntimeComponents,
133        cfg: &mut ConfigBag,
134    ) -> Result<(), BoxError> {
135        if context.request().method() == "GET" {
136            cfg.interceptor_state()
137                .store_put(EnableContentLengthEnforcement);
138        }
139        Ok(())
140    }
141    fn modify_before_deserialization(
142        &self,
143        context: &mut BeforeDeserializationInterceptorContextMut<'_>,
144        _runtime_components: &RuntimeComponents,
145        cfg: &mut ConfigBag,
146    ) -> Result<(), BoxError> {
147        // if we didn't enable it for this request, bail out
148        if cfg.load::<EnableContentLengthEnforcement>().is_none() {
149            return Ok(());
150        }
151        let content_length = match extract_content_length(context.response()) {
152            Err(err) => {
153                tracing::warn!(err = ?err, "could not parse content length from content-length header. This header will be ignored");
154                return Ok(());
155            }
156            Ok(Some(content_length)) => content_length,
157            Ok(None) => return Ok(()),
158        };
159
160        tracing::trace!(
161            expected_length = content_length,
162            "Wrapping response body in content-length enforcement."
163        );
164
165        let body = context.response_mut().take_body();
166        let wrapped = body.map_preserve_contents(move |body| {
167            ContentLengthEnforcingBody::wrap(body, content_length)
168        });
169        *context.response_mut().body_mut() = wrapped;
170        Ok(())
171    }
172}
173
174fn extract_content_length<B>(response: &Response<B>) -> Result<Option<u64>, BoxError> {
175    let Some(content_length) = response.headers().get("content-length") else {
176        tracing::trace!("No content length header was set. Will not validate content length");
177        return Ok(None);
178    };
179    if response.headers().get_all("content-length").count() != 1 {
180        return Err("Found multiple content length headers. This is invalid".into());
181    }
182
183    Ok(Some(content_length.parse::<u64>()?))
184}
185
/// Runtime plugin that installs an interceptor checking that `GET` response bodies contain
/// exactly as many bytes as their `Content-Length` header declares.
#[derive(Default, Debug)]
pub struct EnforceContentLengthRuntimePlugin {}

impl EnforceContentLengthRuntimePlugin {
    /// Builds the plugin; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
196
197impl RuntimePlugin for EnforceContentLengthRuntimePlugin {
198    fn runtime_components(
199        &self,
200        _current_components: &RuntimeComponentsBuilder,
201    ) -> Cow<'_, RuntimeComponentsBuilder> {
202        Cow::Owned(
203            RuntimeComponentsBuilder::new("EnforceContentLength")
204                .with_interceptor(EnforceContentLengthInterceptor {}),
205        )
206    }
207}
208
#[cfg(all(feature = "test-util", test))]
mod test {
    use crate::assert_str_contains;
    use crate::client::http::body::content_length_enforcement::{
        extract_content_length, ContentLengthEnforcingBody,
    };
    use aws_smithy_runtime_api::http::Response;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;
    use http_02x::header::CONTENT_LENGTH;
    use http_body_04x::Body;
    use http_body_1x::Frame;
    use std::error::Error;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Body for tests so we ensure our code works on a body split across multiple frames.
    /// Every `poll_frame` call yields exactly one byte.
    struct ManyFrameBody {
        // Stored in reverse so that `Vec::pop` yields bytes in their original order.
        data: Vec<u8>,
    }

    impl ManyFrameBody {
        // Deliberately returns an `SdkBody` rather than `Self` so tests can pass the result
        // straight to `ContentLengthEnforcingBody::wrap`.
        #[allow(clippy::new_ret_no_self)]
        fn new(input: impl Into<String>) -> SdkBody {
            let mut data = input.into().into_bytes();
            data.reverse();
            SdkBody::from_body_1_x(Self { data })
        }
    }

    impl http_body_1x::Body for ManyFrameBody {
        type Data = Bytes;
        type Error = <SdkBody as Body>::Error;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            match self.data.pop() {
                Some(next) => Poll::Ready(Some(Ok(Frame::data(Bytes::from(vec![next]))))),
                None => Poll::Ready(None),
            }
        }
    }

    #[tokio::test]
    async fn stream_too_short() {
        let body = ManyFrameBody::new("123");
        let enforced = ContentLengthEnforcingBody::wrap(body, 10);
        let err = expect_body_error(enforced).await;
        assert_str_contains!(
            format!("{}", DisplayErrorContext(err)),
            "Expected 10 bytes but 3 bytes were received"
        );
    }

    #[tokio::test]
    async fn stream_too_long() {
        let body = ManyFrameBody::new("abcdefghijk");
        let enforced = ContentLengthEnforcingBody::wrap(body, 5);
        let err = expect_body_error(enforced).await;
        assert_str_contains!(
            format!("{}", DisplayErrorContext(err)),
            "Expected 5 bytes but 11 bytes were received"
        );
    }

    #[tokio::test]
    async fn stream_just_right() {
        let body = ManyFrameBody::new("abcdefghijk");
        let enforced = ContentLengthEnforcingBody::wrap(body, 11);
        let data = enforced.collect().await.unwrap().to_bytes();
        assert_eq!(b"abcdefghijk", data.as_ref());
    }

    /// Drains `body` and returns the error it must produce.
    async fn expect_body_error(body: SdkBody) -> impl Error {
        ByteStream::new(body)
            .collect()
            .await
            .expect_err("body should have failed")
    }

    #[test]
    fn extract_header() {
        let mut resp1 = Response::new(200.try_into().unwrap(), ());
        resp1.headers_mut().insert(CONTENT_LENGTH, "123");
        assert_eq!(extract_content_length(&resp1).unwrap(), Some(123));
        resp1.headers_mut().append(CONTENT_LENGTH, "124");
        // duplicate content length header
        extract_content_length(&resp1).expect_err("duplicate headers");

        // not an integer
        resp1.headers_mut().insert(CONTENT_LENGTH, "-123.5");
        extract_content_length(&resp1).expect_err("not an integer");

        // empty value
        resp1.headers_mut().insert(CONTENT_LENGTH, "");
        extract_content_length(&resp1).expect_err("empty");

        resp1.headers_mut().remove(CONTENT_LENGTH);
        assert_eq!(extract_content_length(&resp1).unwrap(), None);
    }
}