// aws_smithy_checksums/body/validate.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
//! Functionality for validating an HTTP body against a given precalculated checksum and emitting an
//! error if it doesn't match.
8
9use crate::http::HttpChecksum;
10
11use aws_smithy_types::body::SdkBody;
12
13use bytes::Bytes;
14use http::{HeaderMap, HeaderValue};
15use http_body::SizeHint;
16use pin_project_lite::pin_project;
17
18use std::fmt::Display;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21
22pin_project! {
23    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit an error if it
24    /// doesn't match the precalculated checksum.
25    pub struct ChecksumBody<InnerBody> {
26        #[pin]
27        inner: InnerBody,
28        checksum: Option<Box<dyn HttpChecksum>>,
29        precalculated_checksum: Bytes,
30    }
31}
32
33impl ChecksumBody<SdkBody> {
34    /// Given an `SdkBody`, a `Box<dyn HttpChecksum>`, and a precalculated checksum represented
35    /// as `Bytes`, create a new `ChecksumBody<SdkBody>`.
36    pub fn new(
37        body: SdkBody,
38        checksum: Box<dyn HttpChecksum>,
39        precalculated_checksum: Bytes,
40    ) -> Self {
41        Self {
42            inner: body,
43            checksum: Some(checksum),
44            precalculated_checksum,
45        }
46    }
47
48    fn poll_inner(
49        self: Pin<&mut Self>,
50        cx: &mut Context<'_>,
51    ) -> Poll<Option<Result<Bytes, aws_smithy_types::body::Error>>> {
52        use http_body::Body;
53
54        let this = self.project();
55        let checksum = this.checksum;
56
57        match this.inner.poll_data(cx) {
58            Poll::Ready(Some(Ok(data))) => {
59                tracing::trace!(
60                    "reading {} bytes from the body and updating the checksum calculation",
61                    data.len()
62                );
63                let checksum = match checksum.as_mut() {
64                    Some(checksum) => checksum,
65                    None => {
66                        unreachable!("The checksum must exist because it's only taken out once the inner body has been completely polled.");
67                    }
68                };
69
70                checksum.update(&data);
71                Poll::Ready(Some(Ok(data)))
72            }
73            // Once the inner body has stopped returning data, check the checksum
74            // and return an error if it doesn't match.
75            Poll::Ready(None) => {
76                tracing::trace!("finished reading from body, calculating final checksum");
77                let checksum = match checksum.take() {
78                    Some(checksum) => checksum,
79                    None => {
80                        // If the checksum was already taken and this was polled again anyways,
81                        // then return nothing
82                        return Poll::Ready(None);
83                    }
84                };
85
86                let actual_checksum = checksum.finalize();
87                if *this.precalculated_checksum == actual_checksum {
88                    Poll::Ready(None)
89                } else {
90                    // So many parens it's starting to look like LISP
91                    Poll::Ready(Some(Err(Box::new(Error::ChecksumMismatch {
92                        expected: this.precalculated_checksum.clone(),
93                        actual: actual_checksum,
94                    }))))
95                }
96            }
97            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
98            Poll::Pending => Poll::Pending,
99        }
100    }
101}
102
103/// Errors related to checksum calculation and validation
104#[derive(Debug, Eq, PartialEq)]
105#[non_exhaustive]
106pub enum Error {
107    /// The actual checksum didn't match the expected checksum. The checksummed data has been
108    /// altered since the expected checksum was calculated.
109    ChecksumMismatch { expected: Bytes, actual: Bytes },
110}
111
112impl Display for Error {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
114        match self {
115            Error::ChecksumMismatch { expected, actual } => write!(
116                f,
117                "body checksum mismatch. expected body checksum to be {} but it was {}",
118                hex::encode(expected),
119                hex::encode(actual)
120            ),
121        }
122    }
123}
124
125impl std::error::Error for Error {}
126
127impl http_body::Body for ChecksumBody<SdkBody> {
128    type Data = Bytes;
129    type Error = aws_smithy_types::body::Error;
130
131    fn poll_data(
132        self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
135        self.poll_inner(cx)
136    }
137
138    fn poll_trailers(
139        self: Pin<&mut Self>,
140        cx: &mut Context<'_>,
141    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
142        self.project().inner.poll_trailers(cx)
143    }
144
145    fn is_end_stream(&self) -> bool {
146        self.checksum.is_none()
147    }
148
149    fn size_hint(&self) -> SizeHint {
150        self.inner.size_hint()
151    }
152}
153
#[cfg(test)]
mod tests {
    use crate::body::validate::{ChecksumBody, Error};
    use crate::ChecksumAlgorithm;
    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::io::Read;

    /// Computes the big-endian CRC32 of `input` as `Bytes`, for use as a precalculated checksum.
    fn calculate_crc32_checksum(input: &str) -> Bytes {
        let checksum = crc32fast::hash(input.as_bytes());
        Bytes::copy_from_slice(&checksum.to_be_bytes())
    }

    #[tokio::test]
    async fn test_checksum_validated_body_errors_on_mismatch() {
        let input_text = "This is some test text for an SdkBody";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let body = SdkBody::from(input_text);
        let non_matching_checksum = Bytes::copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let mut body = ChecksumBody::new(
            body,
            "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl(),
            non_matching_checksum.clone(),
        );

        while let Some(data) = body.data().await {
            match data {
                Ok(_) => { /* Do nothing */ }
                Err(e) => {
                    match e.downcast_ref::<Error>().unwrap() {
                        Error::ChecksumMismatch { expected, actual } => {
                            assert_eq!(expected, &non_matching_checksum);
                            assert_eq!(actual, &actual_checksum);
                        }
                    }

                    return;
                }
            }
        }

        panic!("didn't hit expected error condition");
    }

    #[tokio::test]
    async fn test_checksum_validated_body_succeeds_on_match() {
        let input_text = "This is some test text for an SdkBody";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let body = SdkBody::from(input_text);
        let http_checksum = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut body = ChecksumBody::new(body, http_checksum, actual_checksum);

        let mut output = SegmentedBuf::new();
        while let Some(buf) = body.data().await {
            output.push(buf.unwrap());
        }

        let mut output_text = String::new();
        output
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        // Verify data is complete and unaltered
        assert_eq!(input_text, output_text);
    }

    #[tokio::test]
    async fn test_checksum_validated_empty_body_succeeds_on_match() {
        let input_text = "";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let body = SdkBody::from(input_text);
        let http_checksum = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut body = ChecksumBody::new(body, http_checksum, actual_checksum);

        // Validation hasn't happened yet, so the body must not claim to be finished.
        assert!(!body.is_end_stream());
        while let Some(buf) = body.data().await {
            let buf = buf.expect("an empty body with a matching checksum must not error");
            assert!(buf.is_empty());
        }
        // The checksum has now been finalized and validated.
        assert!(body.is_end_stream());
    }

    #[tokio::test]
    async fn test_checksum_validated_empty_body_errors_on_mismatch() {
        let input_text = "";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let non_matching_checksum = Bytes::copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        // Guard the test's own premise: the chosen value must really differ from the CRC32.
        assert_ne!(actual_checksum, non_matching_checksum);
        let mut body = ChecksumBody::new(
            SdkBody::from(input_text),
            "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl(),
            non_matching_checksum.clone(),
        );

        let err = loop {
            match body.data().await {
                Some(Ok(_)) => { /* Do nothing */ }
                Some(Err(e)) => break e,
                None => panic!("didn't hit expected error condition"),
            }
        };
        match err.downcast_ref::<Error>().unwrap() {
            Error::ChecksumMismatch { expected, actual } => {
                assert_eq!(expected, &non_matching_checksum);
                assert_eq!(actual, &actual_checksum);
            }
        }
        // The checksum was consumed while producing the error, so the stream is now finished.
        assert!(body.is_end_stream());
    }
}