aws_sdk_s3/
http_request_checksum.rs

// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT.
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
6
7#![allow(dead_code)]
8
9//! Interceptor for handling Smithy `@httpChecksum` request checksumming with AWS SigV4
10
11use aws_runtime::auth::PayloadSigningOverride;
12use aws_runtime::content_encoding::header_value::AWS_CHUNKED;
13use aws_runtime::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
14use aws_smithy_checksums::ChecksumAlgorithm;
15use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
16use aws_smithy_runtime::client::sdk_feature::SmithySdkFeature;
17use aws_smithy_runtime_api::box_error::BoxError;
18use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Input};
19use aws_smithy_runtime_api::client::interceptors::Intercept;
20use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
21use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
22use aws_smithy_runtime_api::http::Request;
23use aws_smithy_types::body::SdkBody;
24use aws_smithy_types::checksum_config::RequestChecksumCalculation;
25use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
26use aws_smithy_types::error::operation::BuildError;
27use http::HeaderValue;
28use http_body::Body;
29use std::str::FromStr;
30use std::{fmt, mem};
31
32use crate::presigning::PresigningMarker;
33
/// Failures that can occur while building a checksum-validated HTTP request
#[derive(Debug)]
pub(crate) enum Error {
    /// The request body does not report an exact size, so it cannot be checksum validated
    UnsizedRequestBody,
    /// A checksum header was requested for a streaming body; streaming bodies
    /// must carry their checksum as a trailer instead
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the static message first, then emit it with a single write.
        let message = match self {
            Self::UnsizedRequestBody => "Only request bodies with a known size can be checksum validated.",
            Self::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                 To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(message)
    }
}

// No `source()` override: these errors never wrap an underlying cause.
impl std::error::Error for Error {}
56
57#[derive(Debug, Clone)]
58struct RequestChecksumInterceptorState {
59    /// The checksum algorithm to calculate
60    checksum_algorithm: Option<String>,
61    /// This value is set in the model on the `httpChecksum` trait
62    request_checksum_required: bool,
63}
64impl Storable for RequestChecksumInterceptorState {
65    type Storer = StoreReplace<Self>;
66}
67
68type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;
69
70pub(crate) struct DefaultRequestChecksumOverride {
71    custom_default: CustomDefaultFn,
72}
73impl fmt::Debug for DefaultRequestChecksumOverride {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        f.debug_struct("DefaultRequestChecksumOverride").finish()
76    }
77}
78impl Storable for DefaultRequestChecksumOverride {
79    type Storer = StoreReplace<Self>;
80}
81impl DefaultRequestChecksumOverride {
82    pub(crate) fn new<F>(custom_default: F) -> Self
83    where
84        F: Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static,
85    {
86        Self {
87            custom_default: Box::new(custom_default),
88        }
89    }
90    pub(crate) fn custom_default(&self, original: Option<ChecksumAlgorithm>, config_bag: &ConfigBag) -> Option<ChecksumAlgorithm> {
91        (self.custom_default)(original, config_bag)
92    }
93}
94
/// Interceptor that applies request checksums.
///
/// * `AP` — the algorithm provider, called with the operation input.
/// * `CM` — the checksum mutator, called with the outgoing request and config bag.
pub(crate) struct RequestChecksumInterceptor<AP, CM> {
    algorithm_provider: AP,
    checksum_mutator: CM,
}

impl<AP, CM> fmt::Debug for RequestChecksumInterceptor<AP, CM> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `AP` and `CM` carry no `Debug` bound, so only the type name is printed.
        f.write_str("RequestChecksumInterceptor")
    }
}

impl<AP, CM> RequestChecksumInterceptor<AP, CM> {
    /// Builds an interceptor from the given algorithm provider and checksum mutator.
    pub(crate) fn new(algorithm_provider: AP, checksum_mutator: CM) -> Self {
        RequestChecksumInterceptor {
            algorithm_provider,
            checksum_mutator,
        }
    }
}
114
115impl<AP, CM> Intercept for RequestChecksumInterceptor<AP, CM>
116where
117    AP: Fn(&Input) -> (Option<String>, bool) + Send + Sync,
118    CM: Fn(&mut Request, &ConfigBag) -> Result<bool, BoxError> + Send + Sync,
119{
120    fn name(&self) -> &'static str {
121        "RequestChecksumInterceptor"
122    }
123
124    fn modify_before_serialization(
125        &self,
126        context: &mut BeforeSerializationInterceptorContextMut<'_>,
127        _runtime_components: &RuntimeComponents,
128        cfg: &mut ConfigBag,
129    ) -> Result<(), BoxError> {
130        let (checksum_algorithm, request_checksum_required) = (self.algorithm_provider)(context.input());
131
132        let mut layer = Layer::new("RequestChecksumInterceptor");
133        layer.store_put(RequestChecksumInterceptorState {
134            checksum_algorithm,
135            request_checksum_required,
136        });
137        cfg.push_layer(layer);
138
139        Ok(())
140    }
141
142    /// Calculate a checksum and modify the request to include the checksum as a header
143    /// (for in-memory request bodies) or a trailer (for streaming request bodies).
144    /// Streaming bodies must be sized or this will return an error.
145    fn modify_before_retry_loop(
146        &self,
147        context: &mut BeforeTransmitInterceptorContextMut<'_>,
148        _runtime_components: &RuntimeComponents,
149        cfg: &mut ConfigBag,
150    ) -> Result<(), BoxError> {
151        let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
152
153        let user_set_checksum_value = (self.checksum_mutator)(context.request_mut(), cfg).expect("Checksum header mutation should not fail");
154
155        // If the user manually set a checksum header we short circuit
156        if user_set_checksum_value {
157            return Ok(());
158        }
159
160        // This value is from the trait, but is needed for runtime logic
161        let request_checksum_required = state.request_checksum_required;
162
163        // If the algorithm fails to parse it is not one we support and we error
164        let checksum_algorithm = state
165            .checksum_algorithm
166            .clone()
167            .map(|s| ChecksumAlgorithm::from_str(s.as_str()))
168            .transpose()?;
169
170        // This value is set by the user on the SdkConfig to indicate their preference
171        // We provide a default here for users that use a client config instead of the SdkConfig
172        let request_checksum_calculation = cfg
173            .load::<RequestChecksumCalculation>()
174            .unwrap_or(&RequestChecksumCalculation::WhenSupported);
175
176        // Need to know if this is a presigned req because we do not calculate checksums for those.
177        let is_presigned_req = cfg.load::<PresigningMarker>().is_some();
178
179        // Determine if we actually calculate the checksum. If this is a presigned request we do not
180        // If the user setting is WhenSupported (the default) we always calculate it (because this interceptor
181        // isn't added if it isn't supported). If it is WhenRequired we only calculate it if the checksum
182        // is marked required on the trait.
183        let calculate_checksum = match (request_checksum_calculation, is_presigned_req) {
184            (_, true) => false,
185            (RequestChecksumCalculation::WhenRequired, false) => request_checksum_required,
186            (RequestChecksumCalculation::WhenSupported, false) => true,
187            _ => true,
188        };
189
190        // Calculate the checksum if necessary
191        if calculate_checksum {
192            // If a checksum override is set in the ConfigBag we use that instead (currently only used by S3Express)
193            // If we have made it this far without a checksum being set we set the default (currently Crc32)
194            let checksum_algorithm = incorporate_custom_default(checksum_algorithm, cfg).unwrap_or_default();
195
196            // Set the user-agent metric for the selected checksum algorithm
197            match checksum_algorithm {
198                ChecksumAlgorithm::Crc32 => {
199                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32);
200                }
201                ChecksumAlgorithm::Crc32c => {
202                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32c);
203                }
204                ChecksumAlgorithm::Crc64Nvme => {
205                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc64);
206                }
207                #[allow(deprecated)]
208                ChecksumAlgorithm::Md5 => {
209                    tracing::warn!(more_info = "Unsupported ChecksumAlgorithm MD5 set");
210                }
211                ChecksumAlgorithm::Sha1 => {
212                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha1);
213                }
214                ChecksumAlgorithm::Sha256 => {
215                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha256);
216                }
217                unsupported => tracing::warn!(
218                    more_info = "Unsupported value of ChecksumAlgorithm detected when setting user-agent metrics",
219                    unsupported = ?unsupported),
220            }
221
222            let request = context.request_mut();
223            add_checksum_for_request_body(request, checksum_algorithm, cfg)?;
224        }
225
226        Ok(())
227    }
228
229    /// Set the user-agent metrics for `RequestChecksumCalculation` here to avoid ownership issues
230    /// with the mutable borrow of cfg in `modify_before_signing`
231    fn read_after_serialization(
232        &self,
233        _context: &aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextRef<'_>,
234        _runtime_components: &RuntimeComponents,
235        cfg: &mut ConfigBag,
236    ) -> Result<(), BoxError> {
237        let request_checksum_calculation = cfg
238            .load::<RequestChecksumCalculation>()
239            .unwrap_or(&RequestChecksumCalculation::WhenSupported);
240
241        match request_checksum_calculation {
242            RequestChecksumCalculation::WhenSupported => {
243                cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenSupported);
244            }
245            RequestChecksumCalculation::WhenRequired => {
246                cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenRequired);
247            }
248            unsupported => tracing::warn!(
249                    more_info = "Unsupported value of RequestChecksumCalculation when setting user-agent metrics",
250                    unsupported = ?unsupported),
251        };
252
253        Ok(())
254    }
255}
256
257fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
258    match cfg.load::<DefaultRequestChecksumOverride>() {
259        Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
260        None => checksum,
261    }
262}
263
264fn add_checksum_for_request_body(request: &mut HttpRequest, checksum_algorithm: ChecksumAlgorithm, cfg: &mut ConfigBag) -> Result<(), BoxError> {
265    match request.body().bytes() {
266        // Body is in-memory: read it and insert the checksum as a header.
267        Some(data) => {
268            let mut checksum = checksum_algorithm.into_impl();
269
270            // If the header has not already been set we set it. If it was already set by the user
271            // we do nothing and maintain their set value.
272            if request.headers().get(checksum.header_name()).is_none() {
273                tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
274                checksum.update(data);
275
276                request.headers_mut().insert(checksum.header_name(), checksum.header_value());
277            }
278        }
279        // Body is streaming: wrap the body so it will emit a checksum as a trailer.
280        None => {
281            tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
282            cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
283            wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm)?;
284        }
285    }
286    Ok(())
287}
288
289fn wrap_streaming_request_body_in_checksum_calculating_body(
290    request: &mut HttpRequest,
291    checksum_algorithm: ChecksumAlgorithm,
292) -> Result<(), BuildError> {
293    let checksum = checksum_algorithm.into_impl();
294
295    // If the user already set the header value then do nothing and return early
296    if request.headers().get(checksum.header_name()).is_some() {
297        return Ok(());
298    }
299
300    let original_body_size = request
301        .body()
302        .size_hint()
303        .exact()
304        .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
305
306    let mut body = {
307        let body = mem::replace(request.body_mut(), SdkBody::taken());
308
309        body.map(move |body| {
310            let checksum = checksum_algorithm.into_impl();
311            let trailer_len = HttpChecksum::size(checksum.as_ref());
312            let body = calculate::ChecksumBody::new(body, checksum);
313            let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
314
315            let body = AwsChunkedBody::new(body, aws_chunked_body_options);
316
317            SdkBody::from_body_0_4(body)
318        })
319    };
320
321    let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
322
323    let headers = request.headers_mut();
324
325    headers.insert(http::header::HeaderName::from_static("x-amz-trailer"), checksum.header_name());
326
327    headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
328    headers.insert(
329        http::header::HeaderName::from_static("x-amz-decoded-content-length"),
330        HeaderValue::from(original_body_size),
331    );
332    headers.insert(
333        http::header::CONTENT_ENCODING,
334        HeaderValue::from_str(AWS_CHUNKED)
335            .map_err(BuildError::other)
336            .expect("\"aws-chunked\" will always be a valid HeaderValue"),
337    );
338
339    mem::swap(request.body_mut(), &mut body);
340
341    Ok(())
342}
343
#[cfg(test)]
mod tests {
    use crate::http_request_checksum::wrap_streaming_request_body_in_checksum_calculating_body;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Reads `body` until it yields no more data and returns all chunks concatenated.
    async fn drain(mut body: SdkBody) -> BytesMut {
        let mut collected = BytesMut::new();
        while let Some(chunk) = body.data().await {
            collected.extend_from_slice(&chunk.unwrap())
        }
        collected
    }

    /// An in-memory retryable body stays cloneable after wrapping, and the clone
    /// produces the expected chunked encoding with a crc32 trailer.
    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        let input_text = "Hello world";
        let chunk_len_hex = format!("{:X}", input_text.len());
        let mut request: HttpRequest = http::Request::builder()
            .body(SdkBody::retryable(move || SdkBody::from(input_text)))
            .unwrap()
            .try_into()
            .unwrap();

        // ensure original SdkBody is retryable
        assert!(request.body().try_clone().is_some());

        let checksum_algorithm: ChecksumAlgorithm = "crc32".parse().unwrap();
        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        // ensure wrapped SdkBody is retryable
        let wrapped = request.body().try_clone().expect("body is retryable");

        let raw = drain(wrapped).await;
        let body = std::str::from_utf8(&raw).unwrap();
        assert_eq!(
            format!("{chunk_len_hex}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n"),
            body
        );
    }

    /// A file-backed body stays cloneable after wrapping, and the clone ends with the
    /// last line of the file followed by the crc32c trailer computed over the contents.
    #[tokio::test]
    async fn test_checksum_body_from_file_is_retryable() {
        use std::io::Write;
        let mut file = NamedTempFile::new().unwrap();
        let checksum_algorithm: ChecksumAlgorithm = "crc32c".parse().unwrap();

        // Feed the same bytes to the file and to a reference checksum.
        let mut reference_checksum = checksum_algorithm.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {}", i);
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
            reference_checksum.update(line.as_bytes());
        }
        let reference_checksum = reference_checksum.finalize();

        let byte_stream = ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap();
        let mut request = HttpRequest::new(byte_stream.into_inner());

        // ensure original SdkBody is retryable
        assert!(request.body().try_clone().is_some());

        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        // ensure wrapped SdkBody is retryable
        let wrapped = request.body().try_clone().expect("body is retryable");

        let raw = drain(wrapped).await;
        let body = std::str::from_utf8(&raw).unwrap();
        let expected_checksum = base64::encode(&reference_checksum);
        let expected = format!("This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n");
        assert!(body.ends_with(&expected), "expected {body} to end with '{expected}'");
    }
}