aws_smithy_checksums/body/calculate.rs
1use crate::http::HttpChecksum;
9
10use aws_smithy_http::header::append_merge_header_maps;
11use aws_smithy_types::body::SdkBody;
12
13use http::HeaderMap;
14use http_body::SizeHint;
15use pin_project_lite::pin_project;
16
17use std::pin::Pin;
18use std::task::{Context, Poll};
19
20pin_project! {
21 pub struct ChecksumBody<InnerBody> {
23 #[pin]
24 body: InnerBody,
25 checksum: Option<Box<dyn HttpChecksum>>,
26 }
27}
28
29impl ChecksumBody<SdkBody> {
30 pub fn new(body: SdkBody, checksum: Box<dyn HttpChecksum>) -> Self {
32 Self {
33 body,
34 checksum: Some(checksum),
35 }
36 }
37}
38
39impl http_body::Body for ChecksumBody<SdkBody> {
40 type Data = bytes::Bytes;
41 type Error = aws_smithy_types::body::Error;
42
43 fn poll_data(
44 self: Pin<&mut Self>,
45 cx: &mut Context<'_>,
46 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
47 let this = self.project();
48 match this.checksum {
49 Some(checksum) => {
50 let poll_res = this.body.poll_data(cx);
51 if let Poll::Ready(Some(Ok(data))) = &poll_res {
52 checksum.update(data);
53 }
54
55 poll_res
56 }
57 None => unreachable!("This can only fail if poll_data is called again after poll_trailers, which is invalid"),
58 }
59 }
60
61 fn poll_trailers(
62 self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
65 let this = self.project();
66 let poll_res = this.body.poll_trailers(cx);
67
68 if let Poll::Ready(Ok(maybe_inner_trailers)) = poll_res {
69 let checksum_headers = if let Some(checksum) = this.checksum.take() {
70 checksum.headers()
71 } else {
72 return Poll::Ready(Ok(None));
73 };
74
75 return match maybe_inner_trailers {
76 Some(inner_trailers) => Poll::Ready(Ok(Some(append_merge_header_maps(
77 inner_trailers,
78 checksum_headers,
79 )))),
80 None => Poll::Ready(Ok(Some(checksum_headers))),
81 };
82 }
83
84 poll_res
85 }
86
87 fn is_end_stream(&self) -> bool {
88 self.body.is_end_stream() && self.checksum.is_none()
91 }
92
93 fn size_hint(&self) -> SizeHint {
94 self.body.size_hint()
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::ChecksumBody;
    use crate::{http::CRC_32_HEADER_NAME, ChecksumAlgorithm, CRC_32_NAME};
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use bytes::Buf;
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::fmt::Write;
    use std::io::Read;

    /// Base64-decodes a checksum header value and renders the bytes as an uppercase,
    /// `0x`-prefixed hex string.
    fn header_value_as_checksum_string(header_value: &http::HeaderValue) -> String {
        let encoded = header_value.to_str().unwrap();
        let raw_checksum = base64::decode(encoded).unwrap();

        let mut rendered = String::from("0x");
        for byte in raw_checksum {
            write!(rendered, "{byte:02X?}").expect("string will always be writeable");
        }
        rendered
    }

    #[tokio::test]
    async fn test_checksum_body() {
        let input_text = "This is some test text for an SdkBody";
        let crc32 = CRC_32_NAME
            .parse::<ChecksumAlgorithm>()
            .unwrap()
            .into_impl();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), crc32);

        // Drain every data chunk so the checksum sees the full payload.
        let mut collected = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            collected.push(chunk.unwrap());
        }

        let mut output_text = String::new();
        collected
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        assert_eq!(input_text, output_text);

        // The checksum only becomes available once the trailers are polled.
        let trailers = body
            .trailers()
            .await
            .expect("checksum generation was without error")
            .expect("trailers were set");
        let crc32_header = trailers
            .get(CRC_32_HEADER_NAME)
            .expect("trailers contain crc32 checksum");

        assert_eq!("0x99B01F72", header_value_as_checksum_string(crc32_header));
    }
}