aws_sdk_s3/
http_response_checksum.rs

1// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT.
2/*
3 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7#![allow(dead_code)]
8
9//! Interceptor for handling Smithy `@httpChecksum` response checksumming
10
11use aws_smithy_checksums::ChecksumAlgorithm;
12use aws_smithy_runtime::client::sdk_feature::SmithySdkFeature;
13use aws_smithy_runtime_api::box_error::BoxError;
14use aws_smithy_runtime_api::client::interceptors::context::{
15    BeforeDeserializationInterceptorContextMut, BeforeSerializationInterceptorContextMut, Input,
16};
17use aws_smithy_runtime_api::client::interceptors::Intercept;
18use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
19use aws_smithy_runtime_api::http::Headers;
20use aws_smithy_types::body::SdkBody;
21use aws_smithy_types::checksum_config::ResponseChecksumValidation;
22use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
23use std::{fmt, mem};
24
25#[derive(Debug)]
26struct ResponseChecksumInterceptorState {
27    validation_enabled: bool,
28}
29impl Storable for ResponseChecksumInterceptorState {
30    type Storer = StoreReplace<Self>;
31}
32
/// Interceptor that validates response checksums for operations modeled with `@httpChecksum`.
///
/// * `VE` is a callback that inspects the operation input and reports whether the caller
///   enabled response checksum validation.
/// * `CM` is a callback that may mutate the operation input before it is serialized.
pub(crate) struct ResponseChecksumInterceptor<VE, CM> {
    /// Names of the checksum algorithms that may appear on the response.
    response_algorithms: &'static [&'static str],
    /// Callback deciding, from the input, whether validation was requested.
    validation_enabled: VE,
    /// Callback applied to the input before serialization.
    checksum_mutator: CM,
}

// Written by hand so that `VE` and `CM` (typically closures) are not required to be `Debug`;
// only the algorithm list is printed.
impl<VE, CM> fmt::Debug for ResponseChecksumInterceptor<VE, CM> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut debug_struct = f.debug_struct("ResponseChecksumInterceptor");
        debug_struct.field("response_algorithms", &self.response_algorithms);
        debug_struct.finish()
    }
}

impl<VE, CM> ResponseChecksumInterceptor<VE, CM> {
    /// Creates an interceptor from the list of response checksum algorithms and the two
    /// callbacks described on the type.
    pub(crate) fn new(response_algorithms: &'static [&'static str], validation_enabled: VE, checksum_mutator: CM) -> Self {
        ResponseChecksumInterceptor {
            response_algorithms,
            validation_enabled,
            checksum_mutator,
        }
    }
}
56
57impl<VE, CM> Intercept for ResponseChecksumInterceptor<VE, CM>
58where
59    VE: Fn(&Input) -> bool + Send + Sync,
60    CM: Fn(&mut Input, &ConfigBag) -> Result<(), BoxError> + Send + Sync,
61{
62    fn name(&self) -> &'static str {
63        "ResponseChecksumInterceptor"
64    }
65
66    fn modify_before_serialization(
67        &self,
68        context: &mut BeforeSerializationInterceptorContextMut<'_>,
69        _runtime_components: &RuntimeComponents,
70        cfg: &mut ConfigBag,
71    ) -> Result<(), BoxError> {
72        (self.checksum_mutator)(context.input_mut(), cfg)?;
73        let validation_enabled = (self.validation_enabled)(context.input());
74
75        let mut layer = Layer::new("ResponseChecksumInterceptor");
76        layer.store_put(ResponseChecksumInterceptorState { validation_enabled });
77        cfg.push_layer(layer);
78
79        let response_checksum_validation = cfg
80            .load::<ResponseChecksumValidation>()
81            .unwrap_or(&ResponseChecksumValidation::WhenSupported);
82
83        // Set the user-agent feature metric for the response checksum config
84        match response_checksum_validation {
85            ResponseChecksumValidation::WhenSupported => {
86                cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsResWhenSupported);
87            }
88            ResponseChecksumValidation::WhenRequired => {
89                cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsResWhenRequired);
90            }
91            unsupported => tracing::warn!(
92                more_info = "Unsupported value of ResponseChecksumValidation when setting user-agent metrics",
93                unsupported = ?unsupported),
94        };
95
96        Ok(())
97    }
98
99    fn modify_before_deserialization(
100        &self,
101        context: &mut BeforeDeserializationInterceptorContextMut<'_>,
102        _runtime_components: &RuntimeComponents,
103        cfg: &mut ConfigBag,
104    ) -> Result<(), BoxError> {
105        let state = cfg
106            .load::<ResponseChecksumInterceptorState>()
107            .expect("set in `read_before_serialization`");
108
109        // This value is set by the user on the SdkConfig to indicate their preference
110        // We provide a default here for users that use a client config instead of the SdkConfig
111        let response_checksum_validation = cfg
112            .load::<ResponseChecksumValidation>()
113            .unwrap_or(&ResponseChecksumValidation::WhenSupported);
114
115        // If validation has not been explicitly enabled we check the ResponseChecksumValidation
116        // from the SdkConfig. If it is WhenSupported (or unknown) we enable validation and if it
117        // is WhenRequired we leave it disabled since there is no way to indicate that a response
118        // checksum is required.
119        let validation_enabled = if !state.validation_enabled {
120            match response_checksum_validation {
121                ResponseChecksumValidation::WhenRequired => false,
122                ResponseChecksumValidation::WhenSupported => true,
123                _ => true,
124            }
125        } else {
126            true
127        };
128
129        if validation_enabled {
130            let response = context.response_mut();
131            let maybe_checksum_headers = check_headers_for_precalculated_checksum(response.headers(), self.response_algorithms);
132
133            if let Some((checksum_algorithm, precalculated_checksum)) = maybe_checksum_headers {
134                let mut body = SdkBody::taken();
135                mem::swap(&mut body, response.body_mut());
136
137                let mut body = wrap_body_with_checksum_validator(body, checksum_algorithm, precalculated_checksum);
138                mem::swap(&mut body, response.body_mut());
139            }
140        }
141
142        Ok(())
143    }
144}
145
146/// Given an `SdkBody`, a `aws_smithy_checksums::ChecksumAlgorithm`, and a pre-calculated checksum,
147/// return an `SdkBody` where the body will processed with the checksum algorithm and checked
148/// against the pre-calculated checksum.
149pub(crate) fn wrap_body_with_checksum_validator(
150    body: SdkBody,
151    checksum_algorithm: ChecksumAlgorithm,
152    precalculated_checksum: bytes::Bytes,
153) -> SdkBody {
154    use aws_smithy_checksums::body::validate;
155
156    body.map(move |body| {
157        SdkBody::from_body_0_4(validate::ChecksumBody::new(
158            body,
159            checksum_algorithm.into_impl(),
160            precalculated_checksum.clone(),
161        ))
162    })
163}
164
165/// Given a `HeaderMap`, extract any checksum included in the headers as `Some(Bytes)`.
166/// If no checksum header is set, return `None`. If multiple checksum headers are set, the one that
167/// is fastest to compute will be chosen.
168pub(crate) fn check_headers_for_precalculated_checksum(headers: &Headers, response_algorithms: &[&str]) -> Option<(ChecksumAlgorithm, bytes::Bytes)> {
169    let checksum_algorithms_to_check = aws_smithy_checksums::http::CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER
170        .into_iter()
171        // Process list of algorithms, from fastest to slowest, that may have been used to checksum
172        // the response body, ignoring any that aren't marked as supported algorithms by the model.
173        .flat_map(|algo| {
174            // For loop is necessary b/c the compiler doesn't infer the correct lifetimes for iter().find()
175            for res_algo in response_algorithms {
176                if algo.eq_ignore_ascii_case(res_algo) {
177                    return Some(algo);
178                }
179            }
180
181            None
182        });
183
184    for checksum_algorithm in checksum_algorithms_to_check {
185        let checksum_algorithm: ChecksumAlgorithm = checksum_algorithm
186            .parse()
187            .expect("CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER only contains valid checksum algorithm names");
188        if let Some(base64_encoded_precalculated_checksum) = headers.get(checksum_algorithm.into_impl().header_name()) {
189            // S3 needs special handling for checksums of objects uploaded with `MultiPartUpload`.
190            if is_part_level_checksum(base64_encoded_precalculated_checksum) {
191                tracing::warn!(
192                      more_info = "See https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html#large-object-checksums for more information.",
193                      "This checksum is a part-level checksum which can't be validated by the Rust SDK. Disable checksum validation for this request to fix this warning.",
194                  );
195
196                return None;
197            }
198
199            let precalculated_checksum = match aws_smithy_types::base64::decode(base64_encoded_precalculated_checksum) {
200                Ok(decoded_checksum) => decoded_checksum.into(),
201                Err(_) => {
202                    tracing::error!("Checksum received from server could not be base64 decoded. No checksum validation will be performed.");
203                    return None;
204                }
205            };
206
207            return Some((checksum_algorithm, precalculated_checksum));
208        }
209    }
210
211    None
212}
213
/// Reports whether `checksum` looks like a part-level checksum, i.e. it ends in
/// `-<digits>` (for example `"abcd==-12"`).
///
/// The value must end in at least one ASCII digit preceded by a dash. If the text before that
/// dash itself ends in a dash, optionally followed by more digits (as in `"abc-4-14"` or
/// `"abc--11"`), the value is rejected.
fn is_part_level_checksum(checksum: &str) -> bool {
    let is_digit = |c: char| c.is_ascii_digit();

    // Peel the trailing part count off the end; there has to be at least one digit.
    let without_part_count = checksum.trim_end_matches(is_digit);
    if without_part_count.len() == checksum.len() {
        return false;
    }

    // The part count must be introduced by a dash.
    let before_dash = match without_part_count.strip_suffix('-') {
        Some(rest) => rest,
        None => return false,
    };

    // A second dash directly before the first one (ignoring any digits in between) means this
    // isn't a part-level checksum.
    !before_dash.trim_end_matches(is_digit).ends_with('-')
}
241
#[cfg(test)]
mod tests {
    use super::{is_part_level_checksum, wrap_body_with_checksum_validator};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;

    /// Streaming a body through the validator with the matching checksum must succeed and
    /// yield the original bytes.
    #[tokio::test]
    async fn test_build_checksum_validated_body_works() {
        let input_text = "Hello world";
        let checksum_algorithm = "crc32".parse().unwrap();
        // Checksum the validator is expected to compute for `input_text`.
        let precalculated_checksum = Bytes::from_static(&[0x8b, 0xd6, 0x9e, 0x52]);

        let stream = ByteStream::new(SdkBody::from(input_text))
            .map(move |sdk_body| wrap_body_with_checksum_validator(sdk_body, checksum_algorithm, precalculated_checksum.clone()));

        let mut reader = stream.into_async_read();
        let mut validated_body = Vec::new();
        let copy_result = tokio::io::copy(&mut reader, &mut validated_body).await;
        if let Err(e) = copy_result {
            tracing::error!("{}", DisplayErrorContext(&e));
            panic!("checksum validation has failed");
        }

        let body = std::str::from_utf8(&validated_body).unwrap();
        assert_eq!(input_text, body);
    }

    /// Table of header values and whether each should be treated as a part-level checksum.
    #[test]
    fn test_is_multipart_object_checksum() {
        let cases = [
            // These ARE NOT part-level checksums
            ("abcd", false),
            ("abcd=", false),
            ("abcd==", false),
            ("1234", false),
            ("1234=", false),
            ("1234==", false),
            // These ARE part-level checksums
            ("abcd-1", true),
            ("abcd=-12", true),
            ("abcd12-134", true),
            ("abcd==-10000", true),
            // These are gibberish and shouldn't be regarded as a part-level checksum
            ("", false),
            ("Spaces? In my header values?", false),
            ("abcd==-134!#{!#", false),
            ("abcd==-", false),
            ("abcd==--11", false),
            ("abcd==-AA", false),
        ];

        for (header_value, expected) in cases {
            assert_eq!(is_part_level_checksum(header_value), expected, "{:?}", header_value);
        }
    }

    /// Realistic-looking values: a genuine part-level checksum and three near misses.
    #[test]
    fn part_level_checksum_detection_works() {
        let cases = [
            // a real checksum
            ("C9A5A6878D97B48CC965C1E41859F034-14", true),
            // close but not quite
            ("a4-", false),
            // backwards
            ("14-C9A5A6878D97B48CC965C1E41859F034", false),
            // double dash
            ("C9A5A6878D97B48CC965C1E41859F03-4-14", false),
        ];

        for (header_value, expected) in cases {
            assert_eq!(is_part_level_checksum(header_value), expected, "{:?}", header_value);
        }
    }
}