aws_sdk_s3/
http_response_checksum.rs
1#![allow(dead_code)]
8
9use aws_smithy_checksums::ChecksumAlgorithm;
12use aws_smithy_runtime::client::sdk_feature::SmithySdkFeature;
13use aws_smithy_runtime_api::box_error::BoxError;
14use aws_smithy_runtime_api::client::interceptors::context::{
15 BeforeDeserializationInterceptorContextMut, BeforeSerializationInterceptorContextMut, Input,
16};
17use aws_smithy_runtime_api::client::interceptors::Intercept;
18use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
19use aws_smithy_runtime_api::http::Headers;
20use aws_smithy_types::body::SdkBody;
21use aws_smithy_types::checksum_config::ResponseChecksumValidation;
22use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
23use std::{fmt, mem};
24
25#[derive(Debug)]
26struct ResponseChecksumInterceptorState {
27 validation_enabled: bool,
28}
29impl Storable for ResponseChecksumInterceptorState {
30 type Storer = StoreReplace<Self>;
31}
32
/// Interceptor that wraps response bodies in a checksum validator when the server
/// sent a checksum header for one of the modeled response algorithms.
///
/// * `VE` — callback deciding, from the operation input, whether validation was requested.
/// * `CM` — callback that may mutate the operation input (with access to the config bag)
///   before serialization.
pub(crate) struct ResponseChecksumInterceptor<VE, CM> {
    /// Checksum algorithm names eligible for response validation.
    response_algorithms: &'static [&'static str],
    /// Decides from the input whether response validation was requested.
    validation_enabled: VE,
    /// Mutates the input before serialization.
    checksum_mutator: CM,
}

impl<VE, CM> fmt::Debug for ResponseChecksumInterceptor<VE, CM> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `VE` and `CM` carry no `Debug` bound, so the callbacks cannot be printed;
        // `finish_non_exhaustive` renders them as `..` instead of pretending the
        // struct only has one field.
        f.debug_struct("ResponseChecksumInterceptor")
            .field("response_algorithms", &self.response_algorithms)
            .finish_non_exhaustive()
    }
}

impl<VE, CM> ResponseChecksumInterceptor<VE, CM> {
    /// Creates an interceptor from the modeled response algorithm names, the
    /// "validation requested" callback and the input mutator.
    pub(crate) fn new(
        response_algorithms: &'static [&'static str],
        validation_enabled: VE,
        checksum_mutator: CM,
    ) -> Self {
        Self {
            response_algorithms,
            validation_enabled,
            checksum_mutator,
        }
    }
}
56
57impl<VE, CM> Intercept for ResponseChecksumInterceptor<VE, CM>
58where
59 VE: Fn(&Input) -> bool + Send + Sync,
60 CM: Fn(&mut Input, &ConfigBag) -> Result<(), BoxError> + Send + Sync,
61{
62 fn name(&self) -> &'static str {
63 "ResponseChecksumInterceptor"
64 }
65
66 fn modify_before_serialization(
67 &self,
68 context: &mut BeforeSerializationInterceptorContextMut<'_>,
69 _runtime_components: &RuntimeComponents,
70 cfg: &mut ConfigBag,
71 ) -> Result<(), BoxError> {
72 (self.checksum_mutator)(context.input_mut(), cfg)?;
73 let validation_enabled = (self.validation_enabled)(context.input());
74
75 let mut layer = Layer::new("ResponseChecksumInterceptor");
76 layer.store_put(ResponseChecksumInterceptorState { validation_enabled });
77 cfg.push_layer(layer);
78
79 let response_checksum_validation = cfg
80 .load::<ResponseChecksumValidation>()
81 .unwrap_or(&ResponseChecksumValidation::WhenSupported);
82
83 match response_checksum_validation {
85 ResponseChecksumValidation::WhenSupported => {
86 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsResWhenSupported);
87 }
88 ResponseChecksumValidation::WhenRequired => {
89 cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsResWhenRequired);
90 }
91 unsupported => tracing::warn!(
92 more_info = "Unsupported value of ResponseChecksumValidation when setting user-agent metrics",
93 unsupported = ?unsupported),
94 };
95
96 Ok(())
97 }
98
99 fn modify_before_deserialization(
100 &self,
101 context: &mut BeforeDeserializationInterceptorContextMut<'_>,
102 _runtime_components: &RuntimeComponents,
103 cfg: &mut ConfigBag,
104 ) -> Result<(), BoxError> {
105 let state = cfg
106 .load::<ResponseChecksumInterceptorState>()
107 .expect("set in `read_before_serialization`");
108
109 let response_checksum_validation = cfg
112 .load::<ResponseChecksumValidation>()
113 .unwrap_or(&ResponseChecksumValidation::WhenSupported);
114
115 let validation_enabled = if !state.validation_enabled {
120 match response_checksum_validation {
121 ResponseChecksumValidation::WhenRequired => false,
122 ResponseChecksumValidation::WhenSupported => true,
123 _ => true,
124 }
125 } else {
126 true
127 };
128
129 if validation_enabled {
130 let response = context.response_mut();
131 let maybe_checksum_headers = check_headers_for_precalculated_checksum(response.headers(), self.response_algorithms);
132
133 if let Some((checksum_algorithm, precalculated_checksum)) = maybe_checksum_headers {
134 let mut body = SdkBody::taken();
135 mem::swap(&mut body, response.body_mut());
136
137 let mut body = wrap_body_with_checksum_validator(body, checksum_algorithm, precalculated_checksum);
138 mem::swap(&mut body, response.body_mut());
139 }
140 }
141
142 Ok(())
143 }
144}
145
146pub(crate) fn wrap_body_with_checksum_validator(
150 body: SdkBody,
151 checksum_algorithm: ChecksumAlgorithm,
152 precalculated_checksum: bytes::Bytes,
153) -> SdkBody {
154 use aws_smithy_checksums::body::validate;
155
156 body.map(move |body| {
157 SdkBody::from_body_0_4(validate::ChecksumBody::new(
158 body,
159 checksum_algorithm.into_impl(),
160 precalculated_checksum.clone(),
161 ))
162 })
163}
164
165pub(crate) fn check_headers_for_precalculated_checksum(headers: &Headers, response_algorithms: &[&str]) -> Option<(ChecksumAlgorithm, bytes::Bytes)> {
169 let checksum_algorithms_to_check = aws_smithy_checksums::http::CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER
170 .into_iter()
171 .flat_map(|algo| {
174 for res_algo in response_algorithms {
176 if algo.eq_ignore_ascii_case(res_algo) {
177 return Some(algo);
178 }
179 }
180
181 None
182 });
183
184 for checksum_algorithm in checksum_algorithms_to_check {
185 let checksum_algorithm: ChecksumAlgorithm = checksum_algorithm
186 .parse()
187 .expect("CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER only contains valid checksum algorithm names");
188 if let Some(base64_encoded_precalculated_checksum) = headers.get(checksum_algorithm.into_impl().header_name()) {
189 if is_part_level_checksum(base64_encoded_precalculated_checksum) {
191 tracing::warn!(
192 more_info = "See https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html#large-object-checksums for more information.",
193 "This checksum is a part-level checksum which can't be validated by the Rust SDK. Disable checksum validation for this request to fix this warning.",
194 );
195
196 return None;
197 }
198
199 let precalculated_checksum = match aws_smithy_types::base64::decode(base64_encoded_precalculated_checksum) {
200 Ok(decoded_checksum) => decoded_checksum.into(),
201 Err(_) => {
202 tracing::error!("Checksum received from server could not be base64 decoded. No checksum validation will be performed.");
203 return None;
204 }
205 };
206
207 return Some((checksum_algorithm, precalculated_checksum));
208 }
209 }
210
211 None
212}
213
/// Returns `true` when `checksum` has the shape of a part-level checksum. The value must
/// end in `-` followed by one or more ASCII digits. In addition, the text before that
/// dash, once any trailing digits are skipped, must not be another `-`. So values such
/// as `"...-4-14"` or `"...--11"` are rejected.
fn is_part_level_checksum(checksum: &str) -> bool {
    match checksum.rsplit_once('-') {
        Some((head, part_count)) => {
            let part_count_is_numeric =
                !part_count.is_empty() && part_count.bytes().all(|b| b.is_ascii_digit());
            part_count_is_numeric
                && !head
                    .trim_end_matches(|c: char| c.is_ascii_digit())
                    .ends_with('-')
        }
        None => false,
    }
}
241
#[cfg(test)]
mod tests {
    use super::{is_part_level_checksum, wrap_body_with_checksum_validator};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;

    /// Streaming a body through the CRC32 validator with the correct expected
    /// checksum must succeed and yield the original bytes unchanged.
    #[tokio::test]
    async fn test_build_checksum_validated_body_works() {
        let input_text = "Hello world";
        let checksum_algorithm = "crc32".parse().unwrap();
        // Expected CRC32 checksum of `input_text`.
        let precalculated_checksum = Bytes::from_static(&[0x8b, 0xd6, 0x9e, 0x52]);

        let body = ByteStream::new(SdkBody::from(input_text)).map(move |sdk_body| {
            wrap_body_with_checksum_validator(
                sdk_body,
                checksum_algorithm,
                precalculated_checksum.clone(),
            )
        });

        let mut validated_body = Vec::new();
        let copied = tokio::io::copy(&mut body.into_async_read(), &mut validated_body).await;
        if let Err(e) = copied {
            tracing::error!("{}", DisplayErrorContext(&e));
            panic!("checksum validation has failed");
        }

        let round_tripped = std::str::from_utf8(&validated_body).unwrap();
        assert_eq!(input_text, round_tripped);
    }

    #[test]
    fn test_is_multipart_object_checksum() {
        let part_level = ["abcd-1", "abcd=-12", "abcd12-134", "abcd==-10000"];
        let not_part_level = [
            "abcd",
            "abcd=",
            "abcd==",
            "1234",
            "1234=",
            "1234==",
            "",
            "Spaces? In my header values?",
            "abcd==-134!#{!#",
            "abcd==-",
            "abcd==--11",
            "abcd==-AA",
        ];

        for checksum in part_level.iter() {
            assert!(
                is_part_level_checksum(checksum),
                "expected part-level: {:?}",
                checksum
            );
        }
        for checksum in not_part_level.iter() {
            assert!(
                !is_part_level_checksum(checksum),
                "expected not part-level: {:?}",
                checksum
            );
        }
    }

    #[test]
    fn part_level_checksum_detection_works() {
        // A realistic part-level checksum: hex digest followed by `-<part count>`.
        assert!(is_part_level_checksum("C9A5A6878D97B48CC965C1E41859F034-14"));
        // Trailing dash without a part count.
        assert!(!is_part_level_checksum("a4-"));
        // Part count placed before the digest.
        assert!(!is_part_level_checksum("14-C9A5A6878D97B48CC965C1E41859F034"));
        // Two dash-separated numeric groups at the end.
        assert!(!is_part_level_checksum("C9A5A6878D97B48CC965C1E41859F03-4-14"));
    }
}