1#![allow(dead_code)]
8
9use aws_runtime::auth::PayloadSigningOverride;
12use aws_runtime::content_encoding::header_value::AWS_CHUNKED;
13use aws_runtime::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
14use aws_smithy_checksums::ChecksumAlgorithm;
15use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
16use aws_smithy_runtime::client::sdk_feature::SmithySdkFeature;
17use aws_smithy_runtime_api::box_error::BoxError;
18use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Input};
19use aws_smithy_runtime_api::client::interceptors::Intercept;
20use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
21use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
22use aws_smithy_runtime_api::http::Request;
23use aws_smithy_types::body::SdkBody;
24use aws_smithy_types::checksum_config::RequestChecksumCalculation;
25use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
26use aws_smithy_types::error::operation::BuildError;
27use http::HeaderValue;
28use http_body::Body;
29use std::str::FromStr;
30use std::{fmt, mem};
31
32use crate::presigning::PresigningMarker;
33
/// Errors raised while preparing a request body for checksum calculation.
#[derive(Debug)]
pub(crate) enum Error {
    /// The request body does not report an exact size, which is needed to compute the
    /// encoded and decoded content lengths of a checksum-wrapped body.
    UnsizedRequestBody,
    /// A checksum header was requested for a streaming body; per the display message,
    /// streaming bodies must carry their checksum as a trailer instead.
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::UnsizedRequestBody => {
                "Only request bodies with a known size can be checksum validated."
            }
            Error::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                 To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
56
/// Per-request checksum settings captured in `modify_before_serialization` and read
/// back in `modify_before_retry_loop`.
#[derive(Debug, Clone)]
struct RequestChecksumInterceptorState {
    /// Algorithm name chosen by the algorithm provider for this input, if any.
    /// It is parsed into a `ChecksumAlgorithm` later; an unparseable name is an error.
    checksum_algorithm: Option<String>,
    /// Whether the operation marks a request checksum as required. This matters
    /// when the user selected `RequestChecksumCalculation::WhenRequired`.
    request_checksum_required: bool,
}
64impl Storable for RequestChecksumInterceptorState {
65 type Storer = StoreReplace<Self>;
66}
67
68type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;
69
70pub(crate) struct DefaultRequestChecksumOverride {
71 custom_default: CustomDefaultFn,
72}
73impl fmt::Debug for DefaultRequestChecksumOverride {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 f.debug_struct("DefaultRequestChecksumOverride").finish()
76 }
77}
78impl Storable for DefaultRequestChecksumOverride {
79 type Storer = StoreReplace<Self>;
80}
81impl DefaultRequestChecksumOverride {
82 pub(crate) fn new<F>(custom_default: F) -> Self
83 where
84 F: Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static,
85 {
86 Self {
87 custom_default: Box::new(custom_default),
88 }
89 }
90 pub(crate) fn custom_default(&self, original: Option<ChecksumAlgorithm>, config_bag: &ConfigBag) -> Option<ChecksumAlgorithm> {
91 (self.custom_default)(original, config_bag)
92 }
93}
94
/// Interceptor that adds request-body checksums.
///
/// `AP` maps the operation input to an optional algorithm name and a "checksum required"
/// flag. `CM` inspects or mutates the outgoing request and reports whether the user
/// already supplied a checksum value.
pub(crate) struct RequestChecksumInterceptor<AP, CM> {
    algorithm_provider: AP,
    checksum_mutator: CM,
}

impl<AP, CM> fmt::Debug for RequestChecksumInterceptor<AP, CM> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The closures are opaque, so only the type name is printed.
        f.write_str("RequestChecksumInterceptor")
    }
}

impl<AP, CM> RequestChecksumInterceptor<AP, CM> {
    /// Builds an interceptor from its algorithm provider and checksum mutator.
    pub(crate) fn new(algorithm_provider: AP, checksum_mutator: CM) -> Self {
        RequestChecksumInterceptor {
            algorithm_provider,
            checksum_mutator,
        }
    }
}
114
115impl<AP, CM> Intercept for RequestChecksumInterceptor<AP, CM>
116where
117 AP: Fn(&Input) -> (Option<String>, bool) + Send + Sync,
118 CM: Fn(&mut Request, &ConfigBag) -> Result<bool, BoxError> + Send + Sync,
119{
120 fn name(&self) -> &'static str {
121 "RequestChecksumInterceptor"
122 }
123
124 fn modify_before_serialization(
125 &self,
126 context: &mut BeforeSerializationInterceptorContextMut<'_>,
127 _runtime_components: &RuntimeComponents,
128 cfg: &mut ConfigBag,
129 ) -> Result<(), BoxError> {
130 let (checksum_algorithm, request_checksum_required) = (self.algorithm_provider)(context.input());
131
132 let mut layer = Layer::new("RequestChecksumInterceptor");
133 layer.store_put(RequestChecksumInterceptorState {
134 checksum_algorithm,
135 request_checksum_required,
136 });
137 cfg.push_layer(layer);
138
139 Ok(())
140 }
141
142 fn modify_before_retry_loop(
146 &self,
147 context: &mut BeforeTransmitInterceptorContextMut<'_>,
148 _runtime_components: &RuntimeComponents,
149 cfg: &mut ConfigBag,
150 ) -> Result<(), BoxError> {
151 let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
152
153 let user_set_checksum_value = (self.checksum_mutator)(context.request_mut(), cfg).expect("Checksum header mutation should not fail");
154
155 if user_set_checksum_value {
157 return Ok(());
158 }
159
160 let request_checksum_required = state.request_checksum_required;
162
163 let checksum_algorithm = state
165 .checksum_algorithm
166 .clone()
167 .map(|s| ChecksumAlgorithm::from_str(s.as_str()))
168 .transpose()?;
169
170 let request_checksum_calculation = cfg
173 .load::<RequestChecksumCalculation>()
174 .unwrap_or(&RequestChecksumCalculation::WhenSupported);
175
176 let is_presigned_req = cfg.load::<PresigningMarker>().is_some();
178
179 let calculate_checksum = match (request_checksum_calculation, is_presigned_req) {
184 (_, true) => false,
185 (RequestChecksumCalculation::WhenRequired, false) => request_checksum_required,
186 (RequestChecksumCalculation::WhenSupported, false) => true,
187 _ => true,
188 };
189
190 if calculate_checksum {
192 let checksum_algorithm = incorporate_custom_default(checksum_algorithm, cfg).unwrap_or_default();
195
196 match checksum_algorithm {
198 ChecksumAlgorithm::Crc32 => {
199 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32);
200 }
201 ChecksumAlgorithm::Crc32c => {
202 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32c);
203 }
204 ChecksumAlgorithm::Crc64Nvme => {
205 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc64);
206 }
207 #[allow(deprecated)]
208 ChecksumAlgorithm::Md5 => {
209 tracing::warn!(more_info = "Unsupported ChecksumAlgorithm MD5 set");
210 }
211 ChecksumAlgorithm::Sha1 => {
212 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha1);
213 }
214 ChecksumAlgorithm::Sha256 => {
215 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha256);
216 }
217 unsupported => tracing::warn!(
218 more_info = "Unsupported value of ChecksumAlgorithm detected when setting user-agent metrics",
219 unsupported = ?unsupported),
220 }
221
222 let request = context.request_mut();
223 add_checksum_for_request_body(request, checksum_algorithm, cfg)?;
224 }
225
226 Ok(())
227 }
228
229 fn read_after_serialization(
232 &self,
233 _context: &aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextRef<'_>,
234 _runtime_components: &RuntimeComponents,
235 cfg: &mut ConfigBag,
236 ) -> Result<(), BoxError> {
237 let request_checksum_calculation = cfg
238 .load::<RequestChecksumCalculation>()
239 .unwrap_or(&RequestChecksumCalculation::WhenSupported);
240
241 match request_checksum_calculation {
242 RequestChecksumCalculation::WhenSupported => {
243 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenSupported);
244 }
245 RequestChecksumCalculation::WhenRequired => {
246 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenRequired);
247 }
248 unsupported => tracing::warn!(
249 more_info = "Unsupported value of RequestChecksumCalculation when setting user-agent metrics",
250 unsupported = ?unsupported),
251 };
252
253 Ok(())
254 }
255}
256
257fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
258 match cfg.load::<DefaultRequestChecksumOverride>() {
259 Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
260 None => checksum,
261 }
262}
263
264fn add_checksum_for_request_body(request: &mut HttpRequest, checksum_algorithm: ChecksumAlgorithm, cfg: &mut ConfigBag) -> Result<(), BoxError> {
265 match request.body().bytes() {
266 Some(data) => {
268 let mut checksum = checksum_algorithm.into_impl();
269
270 if request.headers().get(checksum.header_name()).is_none() {
273 tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
274 checksum.update(data);
275
276 request.headers_mut().insert(checksum.header_name(), checksum.header_value());
277 }
278 }
279 None => {
281 tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
282 cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
283 wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm)?;
284 }
285 }
286 Ok(())
287}
288
289fn wrap_streaming_request_body_in_checksum_calculating_body(
290 request: &mut HttpRequest,
291 checksum_algorithm: ChecksumAlgorithm,
292) -> Result<(), BuildError> {
293 let checksum = checksum_algorithm.into_impl();
294
295 if request.headers().get(checksum.header_name()).is_some() {
297 return Ok(());
298 }
299
300 let original_body_size = request
301 .body()
302 .size_hint()
303 .exact()
304 .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
305
306 let mut body = {
307 let body = mem::replace(request.body_mut(), SdkBody::taken());
308
309 body.map(move |body| {
310 let checksum = checksum_algorithm.into_impl();
311 let trailer_len = HttpChecksum::size(checksum.as_ref());
312 let body = calculate::ChecksumBody::new(body, checksum);
313 let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
314
315 let body = AwsChunkedBody::new(body, aws_chunked_body_options);
316
317 SdkBody::from_body_0_4(body)
318 })
319 };
320
321 let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
322
323 let headers = request.headers_mut();
324
325 headers.insert(http::header::HeaderName::from_static("x-amz-trailer"), checksum.header_name());
326
327 headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
328 headers.insert(
329 http::header::HeaderName::from_static("x-amz-decoded-content-length"),
330 HeaderValue::from(original_body_size),
331 );
332 headers.insert(
333 http::header::CONTENT_ENCODING,
334 HeaderValue::from_str(AWS_CHUNKED)
335 .map_err(BuildError::other)
336 .expect("\"aws-chunked\" will always be a valid HeaderValue"),
337 );
338
339 mem::swap(request.body_mut(), &mut body);
340
341 Ok(())
342}
343
#[cfg(test)]
mod tests {
    use crate::http_request_checksum::wrap_streaming_request_body_in_checksum_calculating_body;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Reads every data frame of `body` and returns the concatenated bytes as UTF-8 text.
    async fn drain_to_string(mut body: SdkBody) -> String {
        let mut collected = BytesMut::new();
        while let Some(frame) = body.data().await {
            collected.extend_from_slice(&frame.unwrap());
        }
        std::str::from_utf8(&collected).unwrap().to_owned()
    }

    /// A retryable in-memory body must remain cloneable after wrapping. A clone must
    /// yield the aws-chunked encoding followed by the CRC32 trailer.
    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        let input_text = "Hello world";
        let mut request: HttpRequest = http::Request::builder()
            .body(SdkBody::retryable(move || SdkBody::from(input_text)))
            .unwrap()
            .try_into()
            .unwrap();

        // Precondition: the body is cloneable before it is wrapped.
        assert!(request.body().try_clone().is_some());

        let algorithm: ChecksumAlgorithm = "crc32".parse().unwrap();
        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, algorithm).unwrap();

        let retry_copy = request.body().try_clone().expect("body is retryable");
        let encoded = drain_to_string(retry_copy).await;

        let chunk_len_hex = format!("{:X}", input_text.len());
        let expected =
            format!("{chunk_len_hex}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n");
        assert_eq!(expected, encoded);
    }

    /// A file-backed body must remain cloneable after wrapping. Its encoded form must end
    /// with the last chunk followed by the CRC32C trailer of the whole file.
    #[tokio::test]
    async fn test_checksum_body_from_file_is_retryable() {
        use std::io::Write;

        let mut file = NamedTempFile::new().unwrap();
        let algorithm: ChecksumAlgorithm = "crc32c".parse().unwrap();

        let mut hasher = algorithm.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {i}");
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
            hasher.update(line.as_bytes());
        }
        let file_checksum = hasher.finalize();

        let mut request = HttpRequest::new(
            ByteStream::read_from()
                .path(&file)
                .buffer_size(1024)
                .build()
                .await
                .unwrap()
                .into_inner(),
        );

        // Precondition: the body is cloneable before it is wrapped.
        assert!(request.body().try_clone().is_some());

        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, algorithm).unwrap();

        let retry_copy = request.body().try_clone().expect("body is retryable");
        let encoded = drain_to_string(retry_copy).await;

        let expected_checksum = base64::encode(&file_checksum);
        let expected = format!(
            "This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n"
        );
        assert!(
            encoded.ends_with(&expected),
            "expected {encoded} to end with '{expected}'"
        );
    }
}