aws_smithy_checksums/body/
validate.rs
1use crate::http::HttpChecksum;
10
11use aws_smithy_types::body::SdkBody;
12
13use bytes::Bytes;
14use http::{HeaderMap, HeaderValue};
15use http_body::SizeHint;
16use pin_project_lite::pin_project;
17
18use std::fmt::Display;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21
22pin_project! {
23 pub struct ChecksumBody<InnerBody> {
26 #[pin]
27 inner: InnerBody,
28 checksum: Option<Box<dyn HttpChecksum>>,
29 precalculated_checksum: Bytes,
30 }
31}
32
33impl ChecksumBody<SdkBody> {
34 pub fn new(
37 body: SdkBody,
38 checksum: Box<dyn HttpChecksum>,
39 precalculated_checksum: Bytes,
40 ) -> Self {
41 Self {
42 inner: body,
43 checksum: Some(checksum),
44 precalculated_checksum,
45 }
46 }
47
48 fn poll_inner(
49 self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 ) -> Poll<Option<Result<Bytes, aws_smithy_types::body::Error>>> {
52 use http_body::Body;
53
54 let this = self.project();
55 let checksum = this.checksum;
56
57 match this.inner.poll_data(cx) {
58 Poll::Ready(Some(Ok(data))) => {
59 tracing::trace!(
60 "reading {} bytes from the body and updating the checksum calculation",
61 data.len()
62 );
63 let checksum = match checksum.as_mut() {
64 Some(checksum) => checksum,
65 None => {
66 unreachable!("The checksum must exist because it's only taken out once the inner body has been completely polled.");
67 }
68 };
69
70 checksum.update(&data);
71 Poll::Ready(Some(Ok(data)))
72 }
73 Poll::Ready(None) => {
76 tracing::trace!("finished reading from body, calculating final checksum");
77 let checksum = match checksum.take() {
78 Some(checksum) => checksum,
79 None => {
80 return Poll::Ready(None);
83 }
84 };
85
86 let actual_checksum = checksum.finalize();
87 if *this.precalculated_checksum == actual_checksum {
88 Poll::Ready(None)
89 } else {
90 Poll::Ready(Some(Err(Box::new(Error::ChecksumMismatch {
92 expected: this.precalculated_checksum.clone(),
93 actual: actual_checksum,
94 }))))
95 }
96 }
97 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
98 Poll::Pending => Poll::Pending,
99 }
100 }
101}
102
103#[derive(Debug, Eq, PartialEq)]
105#[non_exhaustive]
106pub enum Error {
107 ChecksumMismatch { expected: Bytes, actual: Bytes },
110}
111
112impl Display for Error {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
114 match self {
115 Error::ChecksumMismatch { expected, actual } => write!(
116 f,
117 "body checksum mismatch. expected body checksum to be {} but it was {}",
118 hex::encode(expected),
119 hex::encode(actual)
120 ),
121 }
122 }
123}
124
125impl std::error::Error for Error {}
126
127impl http_body::Body for ChecksumBody<SdkBody> {
128 type Data = Bytes;
129 type Error = aws_smithy_types::body::Error;
130
131 fn poll_data(
132 self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
135 self.poll_inner(cx)
136 }
137
138 fn poll_trailers(
139 self: Pin<&mut Self>,
140 cx: &mut Context<'_>,
141 ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
142 self.project().inner.poll_trailers(cx)
143 }
144
145 fn is_end_stream(&self) -> bool {
146 self.checksum.is_none()
147 }
148
149 fn size_hint(&self) -> SizeHint {
150 self.inner.size_hint()
151 }
152}
153
#[cfg(test)]
mod tests {
    use crate::body::validate::{ChecksumBody, Error};
    use crate::ChecksumAlgorithm;
    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::io::Read;

    fn calculate_crc32_checksum(input: &str) -> Bytes {
        let checksum = crc32fast::hash(input.as_bytes());
        Bytes::copy_from_slice(&checksum.to_be_bytes())
    }

    #[tokio::test]
    async fn test_checksum_validated_body_errors_on_mismatch() {
        let input_text = "This is some test text for an SdkBody";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let body = SdkBody::from(input_text);
        let non_matching_checksum = Bytes::copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let mut body = ChecksumBody::new(
            body,
            "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl(),
            non_matching_checksum.clone(),
        );

        while let Some(data) = body.data().await {
            match data {
                // Data chunks are expected before the mismatch is detected.
                Ok(_) => {}
                Err(e) => {
                    match e.downcast_ref::<Error>().unwrap() {
                        Error::ChecksumMismatch { expected, actual } => {
                            assert_eq!(expected, &non_matching_checksum);
                            assert_eq!(actual, &actual_checksum);
                        }
                    }

                    return;
                }
            }
        }

        panic!("didn't hit expected error condition");
    }

    #[tokio::test]
    async fn test_checksum_validated_body_succeeds_on_match() {
        let input_text = "This is some test text for an SdkBody";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let body = SdkBody::from(input_text);
        let http_checksum = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut body = ChecksumBody::new(body, http_checksum, actual_checksum);

        let mut output = SegmentedBuf::new();
        while let Some(buf) = body.data().await {
            output.push(buf.unwrap());
        }

        let mut output_text = String::new();
        output
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        assert_eq!(input_text, output_text);
    }

    #[tokio::test]
    async fn test_checksum_validated_body_is_end_stream_only_after_validation() {
        let input_text = "This is some test text for an SdkBody";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let body = SdkBody::from(input_text);
        let http_checksum = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut body = ChecksumBody::new(body, http_checksum, actual_checksum);

        // Nothing has been read or validated yet, so the stream cannot be over.
        assert!(!body.is_end_stream());

        while let Some(buf) = body.data().await {
            buf.expect("checksum matches, so no error is expected");
        }

        // All data was read and the checksum was validated.
        assert!(body.is_end_stream());
    }

    #[tokio::test]
    async fn test_checksum_validated_body_emits_mismatch_error_only_once() {
        let input_text = "This is some test text for an SdkBody";
        let body = SdkBody::from(input_text);
        let non_matching_checksum = Bytes::copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let http_checksum = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut body = ChecksumBody::new(body, http_checksum, non_matching_checksum);

        let mut error_count = 0;
        while let Some(data) = body.data().await {
            if data.is_err() {
                error_count += 1;
            }
        }
        assert_eq!(error_count, 1);

        // Polling an exhausted body again must not re-report the mismatch.
        assert!(body.data().await.is_none());
        assert!(body.is_end_stream());
    }
}