aws_smithy_checksums/body/
calculate.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Functionality for calculating the checksum of an HTTP body and emitting it as trailers.

8use crate::http::HttpChecksum;
9
10use aws_smithy_http::header::append_merge_header_maps;
11use aws_smithy_types::body::SdkBody;
12
13use http::HeaderMap;
14use http_body::SizeHint;
15use pin_project_lite::pin_project;
16
17use std::pin::Pin;
18use std::task::{Context, Poll};
19
20pin_project! {
21    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit it as a trailer.
22    pub struct ChecksumBody<InnerBody> {
23            #[pin]
24            body: InnerBody,
25            checksum: Option<Box<dyn HttpChecksum>>,
26    }
27}
28
29impl ChecksumBody<SdkBody> {
30    /// Given an `SdkBody` and a `Box<dyn HttpChecksum>`, create a new `ChecksumBody<SdkBody>`.
31    pub fn new(body: SdkBody, checksum: Box<dyn HttpChecksum>) -> Self {
32        Self {
33            body,
34            checksum: Some(checksum),
35        }
36    }
37}
38
39impl http_body::Body for ChecksumBody<SdkBody> {
40    type Data = bytes::Bytes;
41    type Error = aws_smithy_types::body::Error;
42
43    fn poll_data(
44        self: Pin<&mut Self>,
45        cx: &mut Context<'_>,
46    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
47        let this = self.project();
48        match this.checksum {
49            Some(checksum) => {
50                let poll_res = this.body.poll_data(cx);
51                if let Poll::Ready(Some(Ok(data))) = &poll_res {
52                    checksum.update(data);
53                }
54
55                poll_res
56            }
57            None => unreachable!("This can only fail if poll_data is called again after poll_trailers, which is invalid"),
58        }
59    }
60
61    fn poll_trailers(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
65        let this = self.project();
66        let poll_res = this.body.poll_trailers(cx);
67
68        if let Poll::Ready(Ok(maybe_inner_trailers)) = poll_res {
69            let checksum_headers = if let Some(checksum) = this.checksum.take() {
70                checksum.headers()
71            } else {
72                return Poll::Ready(Ok(None));
73            };
74
75            return match maybe_inner_trailers {
76                Some(inner_trailers) => Poll::Ready(Ok(Some(append_merge_header_maps(
77                    inner_trailers,
78                    checksum_headers,
79                )))),
80                None => Poll::Ready(Ok(Some(checksum_headers))),
81            };
82        }
83
84        poll_res
85    }
86
87    fn is_end_stream(&self) -> bool {
88        // If inner body is finished and we've already consumed the checksum then we must be
89        // at the end of the stream.
90        self.body.is_end_stream() && self.checksum.is_none()
91    }
92
93    fn size_hint(&self) -> SizeHint {
94        self.body.size_hint()
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::ChecksumBody;
    use crate::{http::CRC_32_HEADER_NAME, ChecksumAlgorithm, CRC_32_NAME};
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use bytes::Buf;
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::fmt::Write;
    use std::io::Read;

    /// Base64-decodes a checksum header value and renders the raw bytes as `0x`-prefixed
    /// uppercase hex, with two digits per byte.
    fn header_value_as_checksum_string(header_value: &http::HeaderValue) -> String {
        let raw_bytes = base64::decode(header_value.to_str().unwrap()).unwrap();
        let mut rendered = String::from("0x");
        for byte in raw_bytes {
            write!(rendered, "{byte:02X?}").expect("string will always be writeable");
        }
        rendered
    }

    #[tokio::test]
    async fn test_checksum_body() {
        let input_text = "This is some test text for an SdkBody";
        let checksum = CRC_32_NAME
            .parse::<ChecksumAlgorithm>()
            .unwrap()
            .into_impl();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), checksum);

        // Drain every data frame so the checksum observes the whole payload.
        let mut collected = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            collected.push(chunk.unwrap());
        }

        // The wrapper must hand the data through complete and unaltered.
        let mut output_text = String::new();
        collected
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        assert_eq!(input_text, output_text);

        // With the data exhausted, the trailers should carry the CRC32 checksum.
        let trailers = body
            .trailers()
            .await
            .expect("checksum generation was without error")
            .expect("trailers were set");
        let checksum_trailer = header_value_as_checksum_string(
            trailers
                .get(CRC_32_HEADER_NAME)
                .expect("trailers contain crc32 checksum"),
        );

        // Known correct checksum for the input "This is some test text for an SdkBody"
        assert_eq!("0x99B01F72", checksum_trailer);
    }
}