aws_config/imds/client/
token.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
//! IMDS Token Middleware
//!
//! Requests to IMDS are made in two parts:
//! 1. A PUT request is made to the token API to obtain a session token.
//! 2. A GET request is made to the requested API, with the token attached as a header.
//!
//! This module implements a runtime plugin that will:
//! - Load a token via the token API
//! - Cache the token according to its TTL
//! - Retry token loading when it fails
//! - Attach the token to the request in the `x-aws-ec2-metadata-token` header
16
17use crate::identity::IdentityCache;
18use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind};
19use aws_smithy_async::time::SharedTimeSource;
20use aws_smithy_runtime::client::orchestrator::operation::Operation;
21use aws_smithy_runtime::expiring_cache::ExpiringCache;
22use aws_smithy_runtime_api::box_error::BoxError;
23use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
24use aws_smithy_runtime_api::client::auth::{
25    AuthScheme, AuthSchemeEndpointConfig, AuthSchemeId, Sign,
26};
27use aws_smithy_runtime_api::client::identity::{
28    Identity, IdentityFuture, ResolveIdentity, SharedIdentityResolver,
29};
30use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
31use aws_smithy_runtime_api::client::runtime_components::{
32    GetIdentityResolver, RuntimeComponents, RuntimeComponentsBuilder,
33};
34use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
35use aws_smithy_types::body::SdkBody;
36use aws_smithy_types::config_bag::ConfigBag;
37use http::{HeaderValue, Uri};
38use std::borrow::Cow;
39use std::fmt;
40use std::sync::Arc;
41use std::time::{Duration, SystemTime};
42
43/// Token Refresh Buffer
44///
45/// Tokens are cached to remove the need to reload the token between subsequent requests. To ensure
46/// that a request never fails with a 401 (expired token), a buffer window exists during which the token
47/// may not be expired, but will still be refreshed.
48const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(120);
49
50const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl-seconds";
51const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token";
52const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN);
53
54#[derive(Debug)]
55struct TtlToken {
56    value: HeaderValue,
57    ttl: Duration,
58}
59
60/// IMDS Token
61#[derive(Clone)]
62struct Token {
63    value: HeaderValue,
64    expiry: SystemTime,
65}
66impl fmt::Debug for Token {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        f.debug_struct("Token")
69            .field("value", &"** redacted **")
70            .field("expiry", &self.expiry)
71            .finish()
72    }
73}
74
75/// Token Runtime Plugin
76///
77/// This runtime plugin wires up the necessary components to load/cache a token
78/// when required and handle caching/expiry. This token will get attached to the
79/// request to IMDS on the `x-aws-ec2-metadata-token` header.
80#[derive(Debug)]
81pub(super) struct TokenRuntimePlugin {
82    components: RuntimeComponentsBuilder,
83}
84
85impl TokenRuntimePlugin {
86    pub(super) fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
87        Self {
88            components: RuntimeComponentsBuilder::new("TokenRuntimePlugin")
89                .with_auth_scheme(TokenAuthScheme::new())
90                .with_auth_scheme_option_resolver(Some(StaticAuthSchemeOptionResolver::new(vec![
91                    IMDS_TOKEN_AUTH_SCHEME,
92                ])))
93                // The TokenResolver has a cache of its own, so don't use identity caching
94                .with_identity_cache(Some(IdentityCache::no_cache()))
95                .with_identity_resolver(
96                    IMDS_TOKEN_AUTH_SCHEME,
97                    TokenResolver::new(common_plugin, token_ttl),
98                ),
99        }
100    }
101}
102
103impl RuntimePlugin for TokenRuntimePlugin {
104    fn runtime_components(
105        &self,
106        _current_components: &RuntimeComponentsBuilder,
107    ) -> Cow<'_, RuntimeComponentsBuilder> {
108        Cow::Borrowed(&self.components)
109    }
110}
111
112#[derive(Debug)]
113struct TokenResolverInner {
114    cache: ExpiringCache<Token, ImdsError>,
115    refresh: Operation<(), TtlToken, TokenError>,
116}
117
118#[derive(Clone, Debug)]
119struct TokenResolver {
120    inner: Arc<TokenResolverInner>,
121}
122
123impl TokenResolver {
124    fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
125        Self {
126            inner: Arc::new(TokenResolverInner {
127                cache: ExpiringCache::new(TOKEN_REFRESH_BUFFER),
128                refresh: Operation::builder()
129                    .service_name("imds")
130                    .operation_name("get-token")
131                    .runtime_plugin(common_plugin)
132                    .no_auth()
133                    .with_connection_poisoning()
134                    .serializer(move |_| {
135                        Ok(http::Request::builder()
136                            .method("PUT")
137                            .uri(Uri::from_static("/latest/api/token"))
138                            .header(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, token_ttl.as_secs())
139                            .body(SdkBody::empty())
140                            .expect("valid HTTP request")
141                            .try_into()
142                            .unwrap())
143                    })
144                    .deserializer(move |response| {
145                        parse_token_response(response).map_err(OrchestratorError::operation)
146                    })
147                    .build(),
148            }),
149        }
150    }
151
152    async fn get_token(
153        &self,
154        time_source: SharedTimeSource,
155    ) -> Result<(Token, SystemTime), ImdsError> {
156        let result = self.inner.refresh.invoke(()).await;
157        let now = time_source.now();
158        result
159            .map(|token| {
160                let token = Token {
161                    value: token.value,
162                    expiry: now + token.ttl,
163                };
164                let expiry = token.expiry;
165                (token, expiry)
166            })
167            .map_err(ImdsError::failed_to_load_token)
168    }
169}
170
171fn parse_token_response(response: &HttpResponse) -> Result<TtlToken, TokenError> {
172    match response.status().as_u16() {
173        400 => return Err(TokenErrorKind::InvalidParameters.into()),
174        403 => return Err(TokenErrorKind::Forbidden.into()),
175        _ => {}
176    }
177    let mut value =
178        HeaderValue::from_bytes(response.body().bytes().expect("non-streaming response"))
179            .map_err(|_| TokenErrorKind::InvalidToken)?;
180    value.set_sensitive(true);
181
182    let ttl: u64 = response
183        .headers()
184        .get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS)
185        .ok_or(TokenErrorKind::NoTtl)?
186        .parse()
187        .map_err(|_parse_error| TokenErrorKind::InvalidTtl)?;
188    Ok(TtlToken {
189        value,
190        ttl: Duration::from_secs(ttl),
191    })
192}
193
194impl ResolveIdentity for TokenResolver {
195    fn resolve_identity<'a>(
196        &'a self,
197        components: &'a RuntimeComponents,
198        _config_bag: &'a ConfigBag,
199    ) -> IdentityFuture<'a> {
200        let time_source = components
201            .time_source()
202            .expect("time source required for IMDS token caching");
203        IdentityFuture::new(async {
204            let now = time_source.now();
205            let preloaded_token = self.inner.cache.yield_or_clear_if_expired(now).await;
206            let token = match preloaded_token {
207                Some(token) => {
208                    tracing::trace!(
209                        buffer_time=?TOKEN_REFRESH_BUFFER,
210                        expiration=?token.expiry,
211                        now=?now,
212                        "loaded IMDS token from cache");
213                    Ok(token)
214                }
215                None => {
216                    tracing::debug!("IMDS token cache miss");
217                    self.inner
218                        .cache
219                        .get_or_load(|| async { self.get_token(time_source).await })
220                        .await
221                }
222            }?;
223
224            let expiry = token.expiry;
225            Ok(Identity::new(token, Some(expiry)))
226        })
227    }
228}
229
230#[derive(Debug)]
231struct TokenAuthScheme {
232    signer: TokenSigner,
233}
234
235impl TokenAuthScheme {
236    fn new() -> Self {
237        Self {
238            signer: TokenSigner,
239        }
240    }
241}
242
243impl AuthScheme for TokenAuthScheme {
244    fn scheme_id(&self) -> AuthSchemeId {
245        IMDS_TOKEN_AUTH_SCHEME
246    }
247
248    fn identity_resolver(
249        &self,
250        identity_resolvers: &dyn GetIdentityResolver,
251    ) -> Option<SharedIdentityResolver> {
252        identity_resolvers.identity_resolver(IMDS_TOKEN_AUTH_SCHEME)
253    }
254
255    fn signer(&self) -> &dyn Sign {
256        &self.signer
257    }
258}
259
/// Stateless signer that copies the resolved IMDS token onto outgoing requests.
#[derive(Debug)]
struct TokenSigner;
262
263impl Sign for TokenSigner {
264    fn sign_http_request(
265        &self,
266        request: &mut HttpRequest,
267        identity: &Identity,
268        _auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
269        _runtime_components: &RuntimeComponents,
270        _config_bag: &ConfigBag,
271    ) -> Result<(), BoxError> {
272        let token = identity.data::<Token>().expect("correct type");
273        request
274            .headers_mut()
275            .append(X_AWS_EC2_METADATA_TOKEN, token.value.clone());
276        Ok(())
277    }
278}