aws_sigv4/http_request/
canonical_request.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::date_time::{format_date, format_date_time};
7use crate::http_request::error::CanonicalRequestError;
8use crate::http_request::settings::SessionTokenMode;
9use crate::http_request::settings::UriPathNormalizationMode;
10use crate::http_request::sign::SignableRequest;
11use crate::http_request::uri_path_normalization::normalize_uri_path;
12use crate::http_request::url_escape::percent_encode_path;
13use crate::http_request::{PayloadChecksumKind, SignableBody, SignatureLocation, SigningParams};
14use crate::http_request::{PercentEncodingMode, SigningSettings};
15use crate::sign::v4::sha256_hex_string;
16use crate::SignatureVersion;
17use aws_smithy_http::query_writer::QueryWriter;
18use http0::header::{AsHeaderName, HeaderName, HOST};
19use http0::uri::{Port, Scheme};
20use http0::{HeaderMap, HeaderValue, Uri};
21use std::borrow::Cow;
22use std::cmp::Ordering;
23use std::fmt;
24use std::str::FromStr;
25use std::time::SystemTime;
26
27#[cfg(feature = "sigv4a")]
28pub(crate) mod sigv4a;
29
/// Lower-case header names that this module adds to, or checks against, the canonical
/// request's header map.
pub(crate) mod header {
    /// Carries the payload hash; inserted when `payload_checksum_kind` is `XAmzSha256` and the
    /// signature is placed in headers.
    pub(crate) const X_AMZ_CONTENT_SHA_256: &str = "x-amz-content-sha256";
    /// Carries the formatted signing time (header signing only).
    pub(crate) const X_AMZ_DATE: &str = "x-amz-date";
    /// Default session-token header name; `session_token_name_override` replaces it.
    pub(crate) const X_AMZ_SECURITY_TOKEN: &str = "x-amz-security-token";
    /// Left out of the signed headers when presigning (query-parameter signing).
    pub(crate) const X_AMZ_USER_AGENT: &str = "x-amz-user-agent";
    /// Left out of the signed headers when presigning (query-parameter signing).
    pub(crate) const X_AMZ_CHECKSUM_MODE: &str = "x-amz-checksum-mode";
}
37
/// Query-parameter names used when the signature is placed in the query string (presigning).
pub(crate) mod param {
    /// Signing algorithm identifier.
    pub(crate) const X_AMZ_ALGORITHM: &str = "X-Amz-Algorithm";
    /// `access-key/date/[region/]service/aws4_request` credential scope.
    pub(crate) const X_AMZ_CREDENTIAL: &str = "X-Amz-Credential";
    /// Formatted signing time.
    pub(crate) const X_AMZ_DATE: &str = "X-Amz-Date";
    /// Presigned URL lifetime, in whole seconds.
    pub(crate) const X_AMZ_EXPIRES: &str = "X-Amz-Expires";
    /// Default session-token parameter name; `session_token_name_override` replaces it.
    pub(crate) const X_AMZ_SECURITY_TOKEN: &str = "X-Amz-Security-Token";
    /// `;`-separated list of signed header names.
    pub(crate) const X_AMZ_SIGNED_HEADERS: &str = "X-Amz-SignedHeaders";
    /// Not used in this module; presumably appended by the signer once the signature is
    /// computed — confirm against callers.
    pub(crate) const X_AMZ_SIGNATURE: &str = "X-Amz-Signature";
}
47
/// Algorithm identifier written into SigV4 strings-to-sign.
pub(crate) const HMAC_256: &str = "AWS4-HMAC-SHA256";

/// Payload hash placeholder used for [`SignableBody::UnsignedPayload`].
const UNSIGNED_PAYLOAD: &str = "UNSIGNED-PAYLOAD";
/// Payload hash placeholder used for [`SignableBody::StreamingUnsignedPayloadTrailer`].
const STREAMING_UNSIGNED_PAYLOAD_TRAILER: &str = "STREAMING-UNSIGNED-PAYLOAD-TRAILER";
52
/// Signing values produced when the signature is placed in request headers.
#[derive(Debug, PartialEq)]
pub(crate) struct HeaderValues<'a> {
    /// Payload hash (hex SHA-256, a precomputed digest, or an unsigned-payload placeholder).
    pub(crate) content_sha256: Cow<'a, str>,
    /// Signing time rendered by `format_date_time`.
    pub(crate) date_time: String,
    /// Session token, present only when the credentials carry one and the session-token mode
    /// is `Include`.
    pub(crate) security_token: Option<&'a str>,
    /// Sorted names of the headers covered by the signature.
    pub(crate) signed_headers: SignedHeaders,
    /// SigV4a region set, taken from the signing params.
    #[cfg(feature = "sigv4a")]
    pub(crate) region_set: Option<&'a str>,
}
62
/// Signing values produced when the signature is placed in the query string (presigning).
#[derive(Debug, PartialEq)]
pub(crate) struct QueryParamValues<'a> {
    /// Algorithm identifier reported by the signing params.
    pub(crate) algorithm: &'static str,
    /// Payload hash (hex SHA-256, a precomputed digest, or an unsigned-payload placeholder).
    pub(crate) content_sha256: Cow<'a, str>,
    /// `access-key/date/region/service/aws4_request` (SigV4) or
    /// `access-key/date/service/aws4_request` (SigV4a).
    pub(crate) credential: String,
    /// Signing time rendered by `format_date_time`.
    pub(crate) date_time: String,
    /// `expires_in` as a decimal number of whole seconds.
    pub(crate) expires: String,
    /// Session token, present only when the credentials carry one and the session-token mode
    /// is `Include`.
    pub(crate) security_token: Option<&'a str>,
    /// Sorted names of the headers covered by the signature.
    pub(crate) signed_headers: SignedHeaders,
    /// SigV4a region set, taken from the signing params.
    #[cfg(feature = "sigv4a")]
    pub(crate) region_set: Option<&'a str>,
}
75
/// Signing values for a canonical request, shaped by where the signature will be placed.
#[derive(Debug, PartialEq)]
pub(crate) enum SignatureValues<'a> {
    /// Signature goes into request headers.
    Headers(HeaderValues<'a>),
    /// Signature goes into the query string (presigned request).
    QueryParams(QueryParamValues<'a>),
}
81
82impl<'a> SignatureValues<'a> {
83    pub(crate) fn signed_headers(&self) -> &SignedHeaders {
84        match self {
85            SignatureValues::Headers(values) => &values.signed_headers,
86            SignatureValues::QueryParams(values) => &values.signed_headers,
87        }
88    }
89
90    fn content_sha256(&self) -> &str {
91        match self {
92            SignatureValues::Headers(values) => &values.content_sha256,
93            SignatureValues::QueryParams(values) => &values.content_sha256,
94        }
95    }
96
97    pub(crate) fn as_headers(&self) -> Option<&HeaderValues<'_>> {
98        match self {
99            SignatureValues::Headers(values) => Some(values),
100            _ => None,
101        }
102    }
103
104    pub(crate) fn into_query_params(self) -> Result<QueryParamValues<'a>, Self> {
105        match self {
106            SignatureValues::QueryParams(values) => Ok(values),
107            _ => Err(self),
108        }
109    }
110}
111
/// A SigV4 canonical request together with the signing values that accompany it.
#[derive(Debug, PartialEq)]
pub(crate) struct CanonicalRequest<'a> {
    /// HTTP method, as given by the request.
    pub(crate) method: &'a str,
    /// Request path after optional normalization and optional double percent-encoding.
    pub(crate) path: Cow<'a, str>,
    /// Sorted canonical query string, or `None` when there are no query parameters.
    pub(crate) params: Option<String>,
    /// Normalized request headers plus those added for signing.
    pub(crate) headers: HeaderMap,
    /// Values that go along with the signature in the signed request.
    pub(crate) values: SignatureValues<'a>,
}
120
121impl<'a> CanonicalRequest<'a> {
122    /// Construct a CanonicalRequest from a [`SignableRequest`] and [`SigningParams`].
123    ///
124    /// The returned canonical request includes information required for signing as well
125    /// as query parameters or header values that go along with the signature in a request.
126    ///
127    /// ## Behavior
128    ///
129    /// There are several settings which alter signing behavior:
130    /// - If a `security_token` is provided as part of the credentials it will be included in the signed headers
131    /// - If `settings.percent_encoding_mode` specifies double encoding, `%` in the URL will be re-encoded as `%25`
132    /// - If `settings.payload_checksum_kind` is XAmzSha256, add a x-amz-content-sha256 with the body
133    ///   checksum. This is the same checksum used as the "payload_hash" in the canonical request
134    /// - If `settings.session_token_mode` specifies X-Amz-Security-Token to be
135    ///   included before calculating the signature, add it, otherwise omit it.
136    /// - `settings.signature_location` determines where the signature will be placed in a request,
137    ///   and also alters the kinds of signing values that go along with it in the request.
138    pub(crate) fn from<'b>(
139        req: &'b SignableRequest<'b>,
140        params: &'b SigningParams<'b>,
141    ) -> Result<CanonicalRequest<'b>, CanonicalRequestError> {
142        let creds = params
143            .credentials()
144            .map_err(|_| CanonicalRequestError::unsupported_identity_type())?;
145        // Path encoding: if specified, re-encode % as %25
146        // Set method and path into CanonicalRequest
147        let path = req.uri().path();
148        let path = match params.settings().uri_path_normalization_mode {
149            UriPathNormalizationMode::Enabled => normalize_uri_path(path),
150            UriPathNormalizationMode::Disabled => Cow::Borrowed(path),
151        };
152        let path = match params.settings().percent_encoding_mode {
153            // The string is already URI encoded, we don't need to encode everything again, just `%`
154            PercentEncodingMode::Double => Cow::Owned(percent_encode_path(&path)),
155            PercentEncodingMode::Single => path,
156        };
157        let payload_hash = Self::payload_hash(req.body());
158
159        let date_time = format_date_time(*params.time());
160        let (signed_headers, canonical_headers) =
161            Self::headers(req, params, &payload_hash, &date_time)?;
162        let signed_headers = SignedHeaders::new(signed_headers);
163
164        let security_token = match params.settings().session_token_mode {
165            SessionTokenMode::Include => creds.session_token(),
166            SessionTokenMode::Exclude => None,
167        };
168
169        let values = match params.settings().signature_location {
170            SignatureLocation::Headers => SignatureValues::Headers(HeaderValues {
171                content_sha256: payload_hash,
172                date_time,
173                security_token,
174                signed_headers,
175                #[cfg(feature = "sigv4a")]
176                region_set: params.region_set(),
177            }),
178            SignatureLocation::QueryParams => {
179                let credential = match params {
180                    SigningParams::V4(params) => {
181                        format!(
182                            "{}/{}/{}/{}/aws4_request",
183                            creds.access_key_id(),
184                            format_date(params.time),
185                            params.region,
186                            params.name,
187                        )
188                    }
189                    #[cfg(feature = "sigv4a")]
190                    SigningParams::V4a(params) => {
191                        format!(
192                            "{}/{}/{}/aws4_request",
193                            creds.access_key_id(),
194                            format_date(params.time),
195                            params.name,
196                        )
197                    }
198                };
199
200                SignatureValues::QueryParams(QueryParamValues {
201                    algorithm: params.algorithm(),
202                    content_sha256: payload_hash,
203                    credential,
204                    date_time,
205                    expires: params
206                        .settings()
207                        .expires_in
208                        .expect("presigning requires expires_in")
209                        .as_secs()
210                        .to_string(),
211                    security_token,
212                    signed_headers,
213                    #[cfg(feature = "sigv4a")]
214                    region_set: params.region_set(),
215                })
216            }
217        };
218
219        let creq = CanonicalRequest {
220            method: req.method(),
221            path,
222            params: Self::params(req.uri(), &values, params.settings()),
223            headers: canonical_headers,
224            values,
225        };
226        Ok(creq)
227    }
228
229    fn headers(
230        req: &SignableRequest<'_>,
231        params: &SigningParams<'_>,
232        payload_hash: &str,
233        date_time: &str,
234    ) -> Result<(Vec<CanonicalHeaderName>, HeaderMap), CanonicalRequestError> {
235        // Header computation:
236        // The canonical request will include headers not present in the input. We need to clone and
237        // normalize the headers from the original request and add:
238        // - host
239        // - x-amz-date
240        // - x-amz-security-token (if provided)
241        // - x-amz-content-sha256 (if requested by signing settings)
242        let mut canonical_headers = HeaderMap::with_capacity(req.headers().len());
243        for (name, value) in req.headers().iter() {
244            // Header names and values need to be normalized according to Step 4 of https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
245            // Using append instead of insert means this will not clobber headers that have the same lowercased name
246            canonical_headers.append(
247                HeaderName::from_str(&name.to_lowercase())?,
248                normalize_header_value(value)?,
249            );
250        }
251
252        Self::insert_host_header(&mut canonical_headers, req.uri());
253
254        let token_header_name = params
255            .settings()
256            .session_token_name_override
257            .unwrap_or(header::X_AMZ_SECURITY_TOKEN);
258
259        if params.settings().signature_location == SignatureLocation::Headers {
260            let creds = params
261                .credentials()
262                .map_err(|_| CanonicalRequestError::unsupported_identity_type())?;
263            Self::insert_date_header(&mut canonical_headers, date_time);
264
265            if let Some(security_token) = creds.session_token() {
266                let mut sec_header = HeaderValue::from_str(security_token)?;
267                sec_header.set_sensitive(true);
268                canonical_headers.insert(token_header_name, sec_header);
269            }
270
271            if params.settings().payload_checksum_kind == PayloadChecksumKind::XAmzSha256 {
272                let header = HeaderValue::from_str(payload_hash)?;
273                canonical_headers.insert(header::X_AMZ_CONTENT_SHA_256, header);
274            }
275
276            #[cfg(feature = "sigv4a")]
277            if let Some(region_set) = params.region_set() {
278                let header = HeaderValue::from_str(region_set)?;
279                canonical_headers.insert(sigv4a::header::X_AMZ_REGION_SET, header);
280            }
281        }
282
283        let mut signed_headers = Vec::with_capacity(canonical_headers.len());
284        for name in canonical_headers.keys() {
285            if let Some(excluded_headers) = params.settings().excluded_headers.as_ref() {
286                if excluded_headers.iter().any(|it| name.as_str() == it) {
287                    continue;
288                }
289            }
290
291            if params.settings().session_token_mode == SessionTokenMode::Exclude
292                && name == HeaderName::from_static(token_header_name)
293            {
294                continue;
295            }
296
297            if params.settings().signature_location == SignatureLocation::QueryParams {
298                // The X-Amz-User-Agent and x-amz-checksum-mode headers should not be signed if this is for a presigned URL
299                if name == HeaderName::from_static(header::X_AMZ_USER_AGENT)
300                    || name == HeaderName::from_static(header::X_AMZ_CHECKSUM_MODE)
301                {
302                    continue;
303                }
304            }
305            signed_headers.push(CanonicalHeaderName(name.clone()));
306        }
307
308        Ok((signed_headers, canonical_headers))
309    }
310
311    fn payload_hash<'b>(body: &'b SignableBody<'b>) -> Cow<'b, str> {
312        // Payload hash computation
313        //
314        // Based on the input body, set the payload_hash of the canonical request:
315        // Either:
316        // - compute a hash
317        // - use the precomputed hash
318        // - use `UnsignedPayload`
319        // - use `UnsignedPayload` for streaming requests
320        // - use `StreamingUnsignedPayloadTrailer` for streaming requests with trailers
321        match body {
322            SignableBody::Bytes(data) => Cow::Owned(sha256_hex_string(data)),
323            SignableBody::Precomputed(digest) => Cow::Borrowed(digest.as_str()),
324            SignableBody::UnsignedPayload => Cow::Borrowed(UNSIGNED_PAYLOAD),
325            SignableBody::StreamingUnsignedPayloadTrailer => {
326                Cow::Borrowed(STREAMING_UNSIGNED_PAYLOAD_TRAILER)
327            }
328        }
329    }
330
331    fn params(
332        uri: &Uri,
333        values: &SignatureValues<'_>,
334        settings: &SigningSettings,
335    ) -> Option<String> {
336        let mut params: Vec<(Cow<'_, str>, Cow<'_, str>)> =
337            form_urlencoded::parse(uri.query().unwrap_or_default().as_bytes()).collect();
338        fn add_param<'a>(params: &mut Vec<(Cow<'a, str>, Cow<'a, str>)>, k: &'a str, v: &'a str) {
339            params.push((Cow::Borrowed(k), Cow::Borrowed(v)));
340        }
341
342        if let SignatureValues::QueryParams(values) = values {
343            add_param(&mut params, param::X_AMZ_DATE, &values.date_time);
344            add_param(&mut params, param::X_AMZ_EXPIRES, &values.expires);
345
346            #[cfg(feature = "sigv4a")]
347            if let Some(regions) = values.region_set {
348                add_param(&mut params, sigv4a::param::X_AMZ_REGION_SET, regions);
349            }
350
351            add_param(&mut params, param::X_AMZ_ALGORITHM, values.algorithm);
352            add_param(&mut params, param::X_AMZ_CREDENTIAL, &values.credential);
353            add_param(
354                &mut params,
355                param::X_AMZ_SIGNED_HEADERS,
356                values.signed_headers.as_str(),
357            );
358
359            if let Some(security_token) = values.security_token {
360                add_param(
361                    &mut params,
362                    settings
363                        .session_token_name_override
364                        .unwrap_or(param::X_AMZ_SECURITY_TOKEN),
365                    security_token,
366                );
367            }
368        }
369        // Sort by param name, and then by param value
370        params.sort();
371
372        let mut query = QueryWriter::new(uri);
373        query.clear_params();
374        for (key, value) in params {
375            query.insert(&key, &value);
376        }
377
378        let query = query.build_query();
379        if query.is_empty() {
380            None
381        } else {
382            Some(query)
383        }
384    }
385
386    fn insert_host_header(
387        canonical_headers: &mut HeaderMap<HeaderValue>,
388        uri: &Uri,
389    ) -> HeaderValue {
390        match canonical_headers.get(&HOST) {
391            Some(header) => header.clone(),
392            None => {
393                let port = uri.port();
394                let scheme = uri.scheme();
395                let authority = uri
396                    .authority()
397                    .expect("request uri authority must be set for signing")
398                    .as_str();
399                let host = uri
400                    .host()
401                    .expect("request uri host must be set for signing");
402
403                // Check if port is default (80 for HTTP, 443 for HTTPS) and if so exclude it from the
404                // Host header when signing since RFC 2616 indicates that the default port should not be
405                // sent in the Host header (and Hyper strips default ports if they are present)
406                // https://datatracker.ietf.org/doc/html/rfc2616#section-14.23
407                // https://github.com/awslabs/aws-sdk-rust/issues/1244
408                let header_value = if is_port_scheme_default(scheme, port) {
409                    host
410                } else {
411                    authority
412                };
413
414                let header = HeaderValue::try_from(header_value)
415                    .expect("endpoint must contain valid header characters");
416                canonical_headers.insert(HOST, header.clone());
417                header
418            }
419        }
420    }
421
422    fn insert_date_header(
423        canonical_headers: &mut HeaderMap<HeaderValue>,
424        date_time: &str,
425    ) -> HeaderValue {
426        let x_amz_date = HeaderName::from_static(header::X_AMZ_DATE);
427        let date_header = HeaderValue::try_from(date_time).expect("date is valid header value");
428        canonical_headers.insert(x_amz_date, date_header.clone());
429        date_header
430    }
431
432    fn header_values_for(&self, key: impl AsHeaderName) -> String {
433        let values: Vec<&str> = self
434            .headers
435            .get_all(key)
436            .into_iter()
437            .map(|value| {
438                std::str::from_utf8(value.as_bytes())
439                    .expect("SDK request header values are valid UTF-8")
440            })
441            .collect();
442        values.join(",")
443    }
444}
445
446impl<'a> fmt::Display for CanonicalRequest<'a> {
447    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
448        writeln!(f, "{}", self.method)?;
449        writeln!(f, "{}", self.path)?;
450        writeln!(f, "{}", self.params.as_deref().unwrap_or(""))?;
451        // write out _all_ the headers
452        for header in &self.values.signed_headers().headers {
453            write!(f, "{}:", header.0.as_str())?;
454            writeln!(f, "{}", self.header_values_for(&header.0))?;
455        }
456        writeln!(f)?;
457        // write out the signed headers
458        writeln!(f, "{}", self.values.signed_headers().as_str())?;
459        write!(f, "{}", self.values.content_sha256())?;
460        Ok(())
461    }
462}
463
/// Strips leading and trailing spaces and collapses each run of interior spaces into one,
/// e.g. "  Some  example   text  " -> "Some example text".
///
/// Only the ASCII space character is affected; tabs, non-breaking spaces and other whitespace
/// are preserved. The input is borrowed unless a collapse is actually needed.
fn trim_all(text: &str) -> Cow<'_, str> {
    // `str::trim` would also remove non-breaking spaces and other whitespace, but S3 only
    // trims plain spaces, so only ' ' is matched here.
    let trimmed = text.trim_matches(' ');
    if !trimmed.contains("  ") {
        return Cow::Borrowed(trimmed);
    }

    let mut collapsed = String::with_capacity(trimmed.len());
    let mut previous_was_space = false;
    for ch in trimmed.chars() {
        let is_space = ch == ' ';
        // Keep the first space of a run and drop the rest.
        if !(is_space && previous_was_space) {
            collapsed.push(ch);
        }
        previous_was_space = is_space;
    }
    Cow::Owned(collapsed)
}
489
490/// Works just like [trim_all] but acts on HeaderValues instead of bytes.
491/// Will ensure that the underlying bytes are valid UTF-8.
492fn normalize_header_value(header_value: &str) -> Result<HeaderValue, CanonicalRequestError> {
493    let trimmed_value = trim_all(header_value);
494    HeaderValue::from_str(&trimmed_value).map_err(CanonicalRequestError::from)
495}
496
497#[inline]
498fn is_port_scheme_default(scheme: Option<&Scheme>, port: Option<Port<&str>>) -> bool {
499    if let (Some(scheme), Some(port)) = (scheme, port) {
500        return [("http", "80"), ("https", "443")].contains(&(scheme.as_str(), port.as_str()));
501    }
502
503    false
504}
505
/// The sorted list of signed header names, plus its cached `;`-joined rendering.
#[derive(Debug, PartialEq, Default)]
pub(crate) struct SignedHeaders {
    /// Header names, sorted by [`CanonicalHeaderName`]'s ordering.
    headers: Vec<CanonicalHeaderName>,
    /// `headers` joined with `;`, computed once in `SignedHeaders::new`.
    formatted: String,
}
511
512impl SignedHeaders {
513    fn new(mut headers: Vec<CanonicalHeaderName>) -> Self {
514        headers.sort();
515        let formatted = Self::fmt(&headers);
516        SignedHeaders { headers, formatted }
517    }
518
519    fn fmt(headers: &[CanonicalHeaderName]) -> String {
520        let mut value = String::new();
521        let mut iter = headers.iter().peekable();
522        while let Some(next) = iter.next() {
523            value += next.0.as_str();
524            if iter.peek().is_some() {
525                value.push(';');
526            }
527        }
528        value
529    }
530
531    pub(crate) fn as_str(&self) -> &str {
532        &self.formatted
533    }
534}
535
536impl fmt::Display for SignedHeaders {
537    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
538        write!(f, "{}", self.formatted)
539    }
540}
541
542#[derive(Debug, PartialEq, Eq, Clone)]
543struct CanonicalHeaderName(HeaderName);
544
545impl PartialOrd for CanonicalHeaderName {
546    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
547        Some(self.cmp(other))
548    }
549}
550
551impl Ord for CanonicalHeaderName {
552    fn cmp(&self, other: &Self) -> Ordering {
553        self.0.as_str().cmp(other.0.as_str())
554    }
555}
556
/// Credential scope components: date, region (or SigV4a region set) and service.
#[derive(PartialEq, Debug, Clone)]
pub(crate) struct SigningScope<'a> {
    /// Signing time; only its date portion (via `format_date`) appears in the scope.
    pub(crate) time: SystemTime,
    /// Region for SigV4; `StringToSign::new_v4a` stores the region set here.
    pub(crate) region: &'a str,
    /// Signing service name.
    pub(crate) service: &'a str,
}
563
564impl<'a> SigningScope<'a> {
565    pub(crate) fn v4a_display(&self) -> String {
566        format!("{}/{}/aws4_request", format_date(self.time), self.service)
567    }
568}
569
570impl<'a> fmt::Display for SigningScope<'a> {
571    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
572        write!(
573            f,
574            "{}/{}/{}/aws4_request",
575            format_date(self.time),
576            self.region,
577            self.service
578        )
579    }
580}
581
/// The inputs of a SigV4/SigV4a string-to-sign.
#[derive(PartialEq, Debug, Clone)]
pub(crate) struct StringToSign<'a> {
    /// Algorithm identifier (`HMAC_256` for SigV4, `ECDSA_256` for SigV4a).
    pub(crate) algorithm: &'static str,
    /// Credential scope built from `time`, `region` and `service`.
    pub(crate) scope: SigningScope<'a>,
    /// Signing time.
    pub(crate) time: SystemTime,
    /// Region for SigV4, or the region set for SigV4a.
    pub(crate) region: &'a str,
    /// Signing service name.
    pub(crate) service: &'a str,
    /// Hex hash of the canonical request — presumably SHA-256; confirm against callers.
    pub(crate) hashed_creq: &'a str,
    /// Selects which scope rendering `Display` uses.
    signature_version: SignatureVersion,
}
592
593impl<'a> StringToSign<'a> {
594    pub(crate) fn new_v4(
595        time: SystemTime,
596        region: &'a str,
597        service: &'a str,
598        hashed_creq: &'a str,
599    ) -> Self {
600        let scope = SigningScope {
601            time,
602            region,
603            service,
604        };
605        Self {
606            algorithm: HMAC_256,
607            scope,
608            time,
609            region,
610            service,
611            hashed_creq,
612            signature_version: SignatureVersion::V4,
613        }
614    }
615
616    #[cfg(feature = "sigv4a")]
617    pub(crate) fn new_v4a(
618        time: SystemTime,
619        region_set: &'a str,
620        service: &'a str,
621        hashed_creq: &'a str,
622    ) -> Self {
623        use crate::sign::v4a::ECDSA_256;
624
625        let scope = SigningScope {
626            time,
627            region: region_set,
628            service,
629        };
630        Self {
631            algorithm: ECDSA_256,
632            scope,
633            time,
634            region: region_set,
635            service,
636            hashed_creq,
637            signature_version: SignatureVersion::V4a,
638        }
639    }
640}
641
642impl<'a> fmt::Display for StringToSign<'a> {
643    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
644        write!(
645            f,
646            "{}\n{}\n{}\n{}",
647            self.algorithm,
648            format_date_time(self.time),
649            match self.signature_version {
650                SignatureVersion::V4 => self.scope.to_string(),
651                SignatureVersion::V4a => self.scope.v4a_display(),
652            },
653            self.hashed_creq
654        )
655    }
656}
657
658#[cfg(test)]
659mod tests {
660    use crate::date_time::test_parsers::parse_date_time;
661    use crate::http_request::canonical_request::{
662        normalize_header_value, trim_all, CanonicalRequest, SigningScope, StringToSign,
663    };
664    use crate::http_request::test;
665    use crate::http_request::{
666        PayloadChecksumKind, SessionTokenMode, SignableBody, SignableRequest, SignatureLocation,
667        SigningParams, SigningSettings,
668    };
669    use crate::sign::v4;
670    use crate::sign::v4::sha256_hex_string;
671    use aws_credential_types::Credentials;
672    use aws_smithy_http::query_writer::QueryWriter;
673    use aws_smithy_runtime_api::client::identity::Identity;
674    use http0::{HeaderValue, Uri};
675    use pretty_assertions::assert_eq;
676    use proptest::{prelude::*, proptest};
677    use std::borrow::Cow;
678    use std::time::Duration;
679
680    fn signing_params(identity: &Identity, settings: SigningSettings) -> SigningParams<'_> {
681        v4::signing_params::Builder::default()
682            .identity(identity)
683            .region("test-region")
684            .name("testservicename")
685            .time(parse_date_time("20210511T154045Z").unwrap())
686            .settings(settings)
687            .build()
688            .unwrap()
689            .into()
690    }
691
692    #[test]
693    fn test_repeated_header() {
694        let mut req = test::v4::test_request("get-vanilla-query-order-key-case");
695        req.headers.push((
696            "x-amz-object-attributes".to_string(),
697            "Checksum".to_string(),
698        ));
699        req.headers.push((
700            "x-amz-object-attributes".to_string(),
701            "ObjectSize".to_string(),
702        ));
703        let req = SignableRequest::from(&req);
704        let settings = SigningSettings {
705            payload_checksum_kind: PayloadChecksumKind::XAmzSha256,
706            session_token_mode: SessionTokenMode::Exclude,
707            ..Default::default()
708        };
709        let identity = Credentials::for_tests().into();
710        let signing_params = signing_params(&identity, settings);
711        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
712
713        assert_eq!(
714            creq.values.signed_headers().to_string(),
715            "host;x-amz-content-sha256;x-amz-date;x-amz-object-attributes"
716        );
717        assert_eq!(
718            creq.header_values_for("x-amz-object-attributes"),
719            "Checksum,ObjectSize",
720        );
721    }
722
723    #[test]
724    fn test_host_header_properly_handles_ports() {
725        fn host_header_test_setup(endpoint: String) -> String {
726            let mut req = test::v4::test_request("get-vanilla");
727            req.uri = endpoint;
728            let req = SignableRequest::from(&req);
729            let settings = SigningSettings {
730                payload_checksum_kind: PayloadChecksumKind::XAmzSha256,
731                session_token_mode: SessionTokenMode::Exclude,
732                ..Default::default()
733            };
734            let identity = Credentials::for_tests().into();
735            let signing_params = signing_params(&identity, settings);
736            let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
737            creq.header_values_for("host")
738        }
739
740        // HTTP request with 80 port should not be signed with that port
741        let http_80_host_header = host_header_test_setup("http://localhost:80".into());
742        assert_eq!(http_80_host_header, "localhost",);
743
744        // HTTP request with non-80 port should be signed with that port
745        let http_1234_host_header = host_header_test_setup("http://localhost:1234".into());
746        assert_eq!(http_1234_host_header, "localhost:1234",);
747
748        // HTTPS request with 443 port should not be signed with that port
749        let https_443_host_header = host_header_test_setup("https://localhost:443".into());
750        assert_eq!(https_443_host_header, "localhost",);
751
752        // HTTPS request with non-443 port should be signed with that port
753        let https_1234_host_header = host_header_test_setup("https://localhost:1234".into());
754        assert_eq!(https_1234_host_header, "localhost:1234",);
755    }
756
757    #[test]
758    fn test_set_xamz_sha_256() {
759        let req = test::v4::test_request("get-vanilla-query-order-key-case");
760        let req = SignableRequest::from(&req);
761        let settings = SigningSettings {
762            payload_checksum_kind: PayloadChecksumKind::XAmzSha256,
763            session_token_mode: SessionTokenMode::Exclude,
764            ..Default::default()
765        };
766        let identity = Credentials::for_tests().into();
767        let mut signing_params = signing_params(&identity, settings);
768        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
769        assert_eq!(
770            creq.values.content_sha256(),
771            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
772        );
773        // assert that the sha256 header was added
774        assert_eq!(
775            creq.values.signed_headers().as_str(),
776            "host;x-amz-content-sha256;x-amz-date"
777        );
778
779        signing_params.set_payload_checksum_kind(PayloadChecksumKind::NoHeader);
780        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
781        assert_eq!(creq.values.signed_headers().as_str(), "host;x-amz-date");
782    }
783
784    #[test]
785    fn test_unsigned_payload() {
786        let mut req = test::v4::test_request("get-vanilla-query-order-key-case");
787        req.set_body(SignableBody::UnsignedPayload);
788        let req: SignableRequest<'_> = SignableRequest::from(&req);
789
790        let settings = SigningSettings {
791            payload_checksum_kind: PayloadChecksumKind::XAmzSha256,
792            ..Default::default()
793        };
794        let identity = Credentials::for_tests().into();
795        let signing_params = signing_params(&identity, settings);
796        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
797        assert_eq!(creq.values.content_sha256(), "UNSIGNED-PAYLOAD");
798        assert!(creq.to_string().ends_with("UNSIGNED-PAYLOAD"));
799    }
800
801    #[test]
802    fn test_precomputed_payload() {
803        let payload_hash = "44ce7dd67c959e0d3524ffac1771dfbba87d2b6b4b4e99e42034a8b803f8b072";
804        let mut req = test::v4::test_request("get-vanilla-query-order-key-case");
805        req.set_body(SignableBody::Precomputed(String::from(payload_hash)));
806        let req = SignableRequest::from(&req);
807        let settings = SigningSettings {
808            payload_checksum_kind: PayloadChecksumKind::XAmzSha256,
809            ..Default::default()
810        };
811        let identity = Credentials::for_tests().into();
812        let signing_params = signing_params(&identity, settings);
813        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
814        assert_eq!(creq.values.content_sha256(), payload_hash);
815        assert!(creq.to_string().ends_with(payload_hash));
816    }
817
818    #[test]
819    fn test_generate_scope() {
820        let expected = "20150830/us-east-1/iam/aws4_request\n";
821        let scope = SigningScope {
822            time: parse_date_time("20150830T123600Z").unwrap(),
823            region: "us-east-1",
824            service: "iam",
825        };
826        assert_eq!(format!("{}\n", scope), expected);
827    }
828
829    #[test]
830    fn test_string_to_sign() {
831        let time = parse_date_time("20150830T123600Z").unwrap();
832        let creq = test::v4::test_canonical_request("get-vanilla-query-order-key-case");
833        let expected_sts = test::v4::test_sts("get-vanilla-query-order-key-case");
834        let encoded = sha256_hex_string(creq.as_bytes());
835
836        let actual = StringToSign::new_v4(time, "us-east-1", "service", &encoded);
837        assert_eq!(expected_sts, actual.to_string());
838    }
839
840    #[test]
841    fn test_digest_of_canonical_request() {
842        let creq = test::v4::test_canonical_request("get-vanilla-query-order-key-case");
843        let expected = "816cd5b414d056048ba4f7c5386d6e0533120fb1fcfa93762cf0fc39e2cf19e0";
844        let actual = sha256_hex_string(creq.as_bytes());
845        assert_eq!(expected, actual);
846    }
847
848    #[test]
849    fn test_double_url_encode_path() {
850        let req = test::v4::test_request("double-encode-path");
851        let req = SignableRequest::from(&req);
852        let identity = Credentials::for_tests().into();
853        let signing_params = signing_params(&identity, SigningSettings::default());
854        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
855
856        let expected = test::v4::test_canonical_request("double-encode-path");
857        let actual = format!("{}", creq);
858        assert_eq!(actual, expected);
859    }
860
861    #[test]
862    fn test_double_url_encode() {
863        let req = test::v4::test_request("double-url-encode");
864        let req = SignableRequest::from(&req);
865        let identity = Credentials::for_tests().into();
866        let signing_params = signing_params(&identity, SigningSettings::default());
867        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
868
869        let expected = test::v4::test_canonical_request("double-url-encode");
870        let actual = format!("{}", creq);
871        assert_eq!(actual, expected);
872    }
873
874    #[test]
875    fn test_tilde_in_uri() {
876        let req = http0::Request::builder()
877            .uri("https://s3.us-east-1.amazonaws.com/my-bucket?list-type=2&prefix=~objprefix&single&k=&unreserved=-_.~").body("").unwrap().into();
878        let req = SignableRequest::from(&req);
879        let identity = Credentials::for_tests().into();
880        let signing_params = signing_params(&identity, SigningSettings::default());
881        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
882        assert_eq!(
883            Some("k=&list-type=2&prefix=~objprefix&single=&unreserved=-_.~"),
884            creq.params.as_deref(),
885        );
886    }
887
888    #[test]
889    fn test_signing_urls_with_percent_encoded_query_strings() {
890        let all_printable_ascii_chars: String = (32u8..127).map(char::from).collect();
891        let uri = Uri::from_static("https://s3.us-east-1.amazonaws.com/my-bucket");
892
893        let mut query_writer = QueryWriter::new(&uri);
894        query_writer.insert("list-type", "2");
895        query_writer.insert("prefix", &all_printable_ascii_chars);
896
897        let req = http0::Request::builder()
898            .uri(query_writer.build_uri())
899            .body("")
900            .unwrap()
901            .into();
902        let req = SignableRequest::from(&req);
903        let identity = Credentials::for_tests().into();
904        let signing_params = signing_params(&identity, SigningSettings::default());
905        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
906
907        let expected = "list-type=2&prefix=%20%21%22%23%24%25%26%27%28%29%2A%2B%2C-.%2F0123456789%3A%3B%3C%3D%3E%3F%40ABCDEFGHIJKLMNOPQRSTUVWXYZ%5B%5C%5D%5E_%60abcdefghijklmnopqrstuvwxyz%7B%7C%7D~";
908        let actual = creq.params.unwrap();
909        assert_eq!(expected, actual);
910    }
911
912    #[test]
913    fn test_omit_session_token() {
914        let req = test::v4::test_request("get-vanilla-query-order-key-case");
915        let req = SignableRequest::from(&req);
916        let settings = SigningSettings {
917            session_token_mode: SessionTokenMode::Include,
918            ..Default::default()
919        };
920        let identity = Credentials::for_tests_with_session_token().into();
921        let mut signing_params = signing_params(&identity, settings);
922
923        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
924        assert_eq!(
925            creq.values.signed_headers().as_str(),
926            "host;x-amz-date;x-amz-security-token"
927        );
928        assert_eq!(
929            creq.headers.get("x-amz-security-token").unwrap(),
930            "notarealsessiontoken"
931        );
932
933        signing_params.set_session_token_mode(SessionTokenMode::Exclude);
934        let creq = CanonicalRequest::from(&req, &signing_params).unwrap();
935        assert_eq!(
936            creq.headers.get("x-amz-security-token").unwrap(),
937            "notarealsessiontoken"
938        );
939        assert_eq!(creq.values.signed_headers().as_str(), "host;x-amz-date");
940    }
941
942    // It should exclude authorization, user-agent, x-amzn-trace-id, and transfer-encoding headers from presigning
943    #[test]
944    fn non_presigning_header_exclusion() {
945        let request = http0::Request::builder()
946            .uri("https://some-endpoint.some-region.amazonaws.com")
947            .header("authorization", "test-authorization")
948            .header("content-type", "application/xml")
949            .header("content-length", "0")
950            .header("user-agent", "test-user-agent")
951            .header("x-amzn-trace-id", "test-trace-id")
952            .header("x-amz-user-agent", "test-user-agent")
953            .header("transfer-encoding", "chunked")
954            .body("")
955            .unwrap()
956            .into();
957        let request = SignableRequest::from(&request);
958
959        let settings = SigningSettings {
960            signature_location: SignatureLocation::Headers,
961            ..Default::default()
962        };
963
964        let identity = Credentials::for_tests().into();
965        let signing_params = signing_params(&identity, settings);
966        let canonical = CanonicalRequest::from(&request, &signing_params).unwrap();
967
968        let values = canonical.values.as_headers().unwrap();
969        assert_eq!(
970            "content-length;content-type;host;x-amz-date;x-amz-user-agent",
971            values.signed_headers.as_str()
972        );
973    }
974
975    // It should exclude authorization, user-agent, x-amz-user-agent, x-amzn-trace-id, and transfer-encoding headers from presigning
976    #[test]
977    fn presigning_header_exclusion() {
978        let request = http0::Request::builder()
979            .uri("https://some-endpoint.some-region.amazonaws.com")
980            .header("authorization", "test-authorization")
981            .header("content-type", "application/xml")
982            .header("content-length", "0")
983            .header("user-agent", "test-user-agent")
984            .header("x-amzn-trace-id", "test-trace-id")
985            .header("x-amz-user-agent", "test-user-agent")
986            .header("transfer-encoding", "chunked")
987            .body("")
988            .unwrap()
989            .into();
990        let request = SignableRequest::from(&request);
991
992        let settings = SigningSettings {
993            signature_location: SignatureLocation::QueryParams,
994            expires_in: Some(Duration::from_secs(30)),
995            ..Default::default()
996        };
997
998        let identity = Credentials::for_tests().into();
999        let signing_params = signing_params(&identity, settings);
1000        let canonical = CanonicalRequest::from(&request, &signing_params).unwrap();
1001
1002        let values = canonical.values.into_query_params().unwrap();
1003        assert_eq!(
1004            "content-length;content-type;host",
1005            values.signed_headers.as_str()
1006        );
1007    }
1008
1009    #[allow(clippy::ptr_arg)] // The proptest macro requires this arg to be a Vec instead of a slice.
1010    fn valid_input(input: &Vec<String>) -> bool {
1011        [
1012            "content-length".to_owned(),
1013            "content-type".to_owned(),
1014            "host".to_owned(),
1015        ]
1016        .iter()
1017        .all(|element| !input.contains(element))
1018    }
1019
1020    proptest! {
1021        #[test]
1022        fn presigning_header_exclusion_with_explicit_exclusion_list_specified(
1023            excluded_headers in prop::collection::vec("[a-z]{1,20}", 1..10).prop_filter(
1024                "`excluded_headers` should pass the `valid_input` check",
1025                valid_input,
1026            )
1027        ) {
1028            let mut request_builder = http0::Request::builder()
1029                .uri("https://some-endpoint.some-region.amazonaws.com")
1030                .header("content-type", "application/xml")
1031                .header("content-length", "0");
1032            for key in &excluded_headers {
1033                request_builder = request_builder.header(key, "value");
1034            }
1035            let request = request_builder.body("").unwrap().into();
1036
1037            let request = SignableRequest::from(&request);
1038
1039            let settings = SigningSettings {
1040                signature_location: SignatureLocation::QueryParams,
1041                expires_in: Some(Duration::from_secs(30)),
1042                excluded_headers: Some(
1043                    excluded_headers
1044                        .into_iter()
1045                        .map(std::borrow::Cow::Owned)
1046                        .collect(),
1047                ),
1048                ..Default::default()
1049            };
1050
1051        let identity = Credentials::for_tests().into();
1052        let signing_params = signing_params(&identity, settings);
1053            let canonical = CanonicalRequest::from(&request, &signing_params).unwrap();
1054
1055            let values = canonical.values.into_query_params().unwrap();
1056            assert_eq!(
1057                "content-length;content-type;host",
1058                values.signed_headers.as_str()
1059            );
1060        }
1061    }
1062
1063    #[test]
1064    fn test_trim_all_handles_spaces_correctly() {
1065        assert_eq!(Cow::Borrowed("don't touch me"), trim_all("don't touch me"));
1066        assert_eq!("trim left", trim_all("   trim left"));
1067        assert_eq!("trim right", trim_all("trim right "));
1068        assert_eq!("trim both", trim_all("   trim both  "));
1069        assert_eq!("", trim_all(" "));
1070        assert_eq!("", trim_all("  "));
1071        assert_eq!("a b", trim_all(" a   b "));
1072        assert_eq!("Some example text", trim_all("  Some  example   text  "));
1073    }
1074
1075    #[test]
1076    fn test_trim_all_ignores_other_forms_of_whitespace() {
1077        // \xA0 is a non-breaking space character
1078        assert_eq!(
1079            "\t\u{A0}Some\u{A0} example \u{A0}text\u{A0}\n",
1080            trim_all("\t\u{A0}Some\u{A0}     example   \u{A0}text\u{A0}\n")
1081        );
1082    }
1083
1084    #[test]
1085    fn trim_spaces_works_on_single_characters() {
1086        assert_eq!(trim_all("2").as_ref(), "2");
1087    }
1088
1089    proptest! {
1090        #[test]
1091        fn test_trim_all_doesnt_elongate_strings(s in ".*") {
1092            assert!(trim_all(&s).len() <= s.len())
1093        }
1094
1095        #[test]
1096        fn test_normalize_header_value_works_on_valid_header_value(v in (".*")) {
1097            assert_eq!(normalize_header_value(&v).is_ok(), HeaderValue::from_str(&v).is_ok());
1098        }
1099
1100        #[test]
1101        fn test_trim_all_does_nothing_when_there_are_no_spaces(s in "[^ ]*") {
1102            assert_eq!(trim_all(&s).as_ref(), s);
1103        }
1104    }
1105}