// aws_smithy_runtime/client/http/body/content_length_enforcement.rs
1use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::interceptors::context::{
10 BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextRef,
11};
12use aws_smithy_runtime_api::client::interceptors::Intercept;
13use aws_smithy_runtime_api::client::runtime_components::{
14 RuntimeComponents, RuntimeComponentsBuilder,
15};
16use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
17use aws_smithy_runtime_api::http::Response;
18use aws_smithy_types::body::SdkBody;
19use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
20use bytes::Buf;
21use http_body_1x::{Frame, SizeHint};
22use pin_project_lite::pin_project;
23use std::borrow::Cow;
24use std::error::Error;
25use std::fmt::{Display, Formatter};
26use std::pin::Pin;
27use std::task::{ready, Context, Poll};
28pin_project! {
29 struct ContentLengthEnforcingBody<InnerBody> {
31 #[pin]
32 body: InnerBody,
33 expected_length: u64,
34 bytes_received: u64,
35 }
36}
37
/// Error yielded by the enforcing body when the number of bytes read does not
/// match the `content-length` header.
#[derive(Debug)]
pub struct ContentLengthError {
    /// Length announced by the `content-length` header.
    expected: u64,
    /// Number of data bytes actually read from the body.
    received: u64,
}

impl Display for ContentLengthError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let ContentLengthError { expected, received } = self;
        write!(
            f,
            "Invalid Content-Length: Expected {} bytes but {} bytes were received",
            expected, received
        )
    }
}

impl Error for ContentLengthError {}
56
57impl ContentLengthEnforcingBody<SdkBody> {
58 fn wrap(body: SdkBody, content_length: u64) -> SdkBody {
60 body.map_preserve_contents(move |b| {
61 SdkBody::from_body_1_x(ContentLengthEnforcingBody {
62 body: b,
63 expected_length: content_length,
64 bytes_received: 0,
65 })
66 })
67 }
68}
69
70impl<
71 E: Into<aws_smithy_types::body::Error>,
72 Data: Buf,
73 InnerBody: http_body_1x::Body<Error = E, Data = Data>,
74 > http_body_1x::Body for ContentLengthEnforcingBody<InnerBody>
75{
76 type Data = Data;
77 type Error = aws_smithy_types::body::Error;
78
79 fn poll_frame(
80 mut self: Pin<&mut Self>,
81 cx: &mut Context<'_>,
82 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
83 let this = self.as_mut().project();
84 match ready!(this.body.poll_frame(cx)) {
85 None => {
86 if *this.expected_length == *this.bytes_received {
87 Poll::Ready(None)
88 } else {
89 Poll::Ready(Some(Err(ContentLengthError {
90 expected: *this.expected_length,
91 received: *this.bytes_received,
92 }
93 .into())))
94 }
95 }
96 Some(Err(e)) => Poll::Ready(Some(Err(e.into()))),
97 Some(Ok(frame)) => {
98 if let Some(data) = frame.data_ref() {
99 *this.bytes_received += data.remaining() as u64;
100 }
101 Poll::Ready(Some(Ok(frame)))
102 }
103 }
104 }
105
106 fn is_end_stream(&self) -> bool {
107 self.body.is_end_stream()
108 }
109
110 fn size_hint(&self) -> SizeHint {
111 self.body.size_hint()
112 }
113}
114
/// Interceptor that opts `GET` requests into response-body length checking
/// against the `content-length` header.
#[derive(Default, Debug)]
struct EnforceContentLengthInterceptor {}
117
118#[derive(Debug)]
119struct EnableContentLengthEnforcement;
120impl Storable for EnableContentLengthEnforcement {
121 type Storer = StoreReplace<EnableContentLengthEnforcement>;
122}
123
124impl Intercept for EnforceContentLengthInterceptor {
125 fn name(&self) -> &'static str {
126 "EnforceContentLength"
127 }
128
129 fn read_before_transmit(
130 &self,
131 context: &BeforeTransmitInterceptorContextRef<'_>,
132 _runtime_components: &RuntimeComponents,
133 cfg: &mut ConfigBag,
134 ) -> Result<(), BoxError> {
135 if context.request().method() == "GET" {
136 cfg.interceptor_state()
137 .store_put(EnableContentLengthEnforcement);
138 }
139 Ok(())
140 }
141 fn modify_before_deserialization(
142 &self,
143 context: &mut BeforeDeserializationInterceptorContextMut<'_>,
144 _runtime_components: &RuntimeComponents,
145 cfg: &mut ConfigBag,
146 ) -> Result<(), BoxError> {
147 if cfg.load::<EnableContentLengthEnforcement>().is_none() {
149 return Ok(());
150 }
151 let content_length = match extract_content_length(context.response()) {
152 Err(err) => {
153 tracing::warn!(err = ?err, "could not parse content length from content-length header. This header will be ignored");
154 return Ok(());
155 }
156 Ok(Some(content_length)) => content_length,
157 Ok(None) => return Ok(()),
158 };
159
160 tracing::trace!(
161 expected_length = content_length,
162 "Wrapping response body in content-length enforcement."
163 );
164
165 let body = context.response_mut().take_body();
166 let wrapped = body.map_preserve_contents(move |body| {
167 ContentLengthEnforcingBody::wrap(body, content_length)
168 });
169 *context.response_mut().body_mut() = wrapped;
170 Ok(())
171 }
172}
173
174fn extract_content_length<B>(response: &Response<B>) -> Result<Option<u64>, BoxError> {
175 let Some(content_length) = response.headers().get("content-length") else {
176 tracing::trace!("No content length header was set. Will not validate content length");
177 return Ok(None);
178 };
179 if response.headers().get_all("content-length").count() != 1 {
180 return Err("Found multiple content length headers. This is invalid".into());
181 }
182
183 Ok(Some(content_length.parse::<u64>()?))
184}
185
/// Runtime plugin that registers the content-length enforcement interceptor.
#[derive(Debug, Default)]
pub struct EnforceContentLengthRuntimePlugin {}

impl EnforceContentLengthRuntimePlugin {
    /// Creates the plugin; it carries no configuration.
    pub fn new() -> Self {
        Self::default()
    }
}
196
197impl RuntimePlugin for EnforceContentLengthRuntimePlugin {
198 fn runtime_components(
199 &self,
200 _current_components: &RuntimeComponentsBuilder,
201 ) -> Cow<'_, RuntimeComponentsBuilder> {
202 Cow::Owned(
203 RuntimeComponentsBuilder::new("EnforceContentLength")
204 .with_interceptor(EnforceContentLengthInterceptor {}),
205 )
206 }
207}
208
#[cfg(all(feature = "test-util", test))]
mod test {
    use crate::assert_str_contains;
    use crate::client::http::body::content_length_enforcement::{
        extract_content_length, ContentLengthEnforcingBody,
    };
    use aws_smithy_runtime_api::http::Response;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;
    use http_02x::header::CONTENT_LENGTH;
    use http_body_04x::Body;
    use http_body_1x::Frame;
    use std::error::Error;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Test body that yields its input one byte per data frame, in order.
    struct ManyFrameBody {
        bytes: std::vec::IntoIter<u8>,
    }

    impl ManyFrameBody {
        // Returns the body already wrapped in an `SdkBody` for convenience.
        #[allow(clippy::new_ret_no_self)]
        fn new(input: impl Into<String>) -> SdkBody {
            let bytes = input.into().into_bytes().into_iter();
            SdkBody::from_body_1_x(Self { bytes })
        }
    }

    impl http_body_1x::Body for ManyFrameBody {
        type Data = Bytes;
        type Error = <SdkBody as Body>::Error;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            let next_frame = self
                .bytes
                .next()
                .map(|byte| Ok(Frame::data(Bytes::from(vec![byte]))));
            Poll::Ready(next_frame)
        }
    }

    /// Drains `body` and returns the error it is expected to produce.
    async fn expect_body_error(body: SdkBody) -> impl Error {
        ByteStream::new(body)
            .collect()
            .await
            .expect_err("body should have failed")
    }

    #[tokio::test]
    async fn stream_too_short() {
        let enforced = ContentLengthEnforcingBody::wrap(ManyFrameBody::new("123"), 10);
        let rendered = format!("{}", DisplayErrorContext(expect_body_error(enforced).await));
        assert_str_contains!(rendered, "Expected 10 bytes but 3 bytes were received");
    }

    #[tokio::test]
    async fn stream_too_long() {
        let enforced = ContentLengthEnforcingBody::wrap(ManyFrameBody::new("abcdefghijk"), 5);
        let rendered = format!("{}", DisplayErrorContext(expect_body_error(enforced).await));
        assert_str_contains!(rendered, "Expected 5 bytes but 11 bytes were received");
    }

    #[tokio::test]
    async fn stream_just_right() {
        let enforced = ContentLengthEnforcingBody::wrap(ManyFrameBody::new("abcdefghijk"), 11);
        let collected = enforced.collect().await.unwrap().to_bytes();
        assert_eq!(b"abcdefghijk", collected.as_ref());
    }

    #[test]
    fn extract_header() {
        let mut response = Response::new(200.try_into().unwrap(), ());

        // A single, well-formed header is parsed.
        response.headers_mut().insert(CONTENT_LENGTH, "123");
        assert_eq!(extract_content_length(&response).unwrap(), Some(123));

        // A second header value is rejected.
        response.headers_mut().append(CONTENT_LENGTH, "124");
        extract_content_length(&response).expect_err("duplicate headers");

        // Non-integer values are rejected.
        response.headers_mut().insert(CONTENT_LENGTH, "-123.5");
        extract_content_length(&response).expect_err("not an integer");

        // Empty values are rejected.
        response.headers_mut().insert(CONTENT_LENGTH, "");
        extract_content_length(&response).expect_err("empty");

        // No header at all means nothing to enforce.
        response.headers_mut().remove(CONTENT_LENGTH);
        assert_eq!(extract_content_length(&response).unwrap(), None);
    }
}