use crate::profile::credentials::ProfileFileError;
use crate::profile::{Profile, ProfileSet};
use crate::sensitive_command::CommandWithSensitiveArgs;
use aws_credential_types::Credentials;
/// Provider chain resolved from a profile set by `resolve_chain`: one base provider plus
/// the roles collected while following `source_profile` references.
#[derive(Debug)]
pub(crate) struct ProfileChain<'a> {
/// Provider selected for the profile at which resolution stopped.
pub(crate) base: BaseProvider<'a>,
/// Roles collected during resolution. `resolve_chain` reverses the collected list before
/// storing it, so the role of the initially selected profile comes last.
pub(crate) chain: Vec<RoleArn<'a>>,
}
impl<'a> ProfileChain<'a> {
    /// Borrows the base provider of this chain.
    pub(crate) fn base(&self) -> &BaseProvider<'a> {
        let Self { base, .. } = self;
        base
    }

    /// Borrows the collected roles as a slice, in stored order.
    pub(crate) fn chain(&self) -> &[RoleArn<'a>] {
        &self.chain[..]
    }
}
/// Provider that terminates a [`ProfileChain`].
///
/// Borrowed fields are string values read from the profile (or SSO session) the provider
/// was built from.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub(crate) enum BaseProvider<'a> {
/// Value of the profile's `credential_source` key.
NamedSource(&'a str),
/// Static credentials built from the profile's `aws_access_key_id`,
/// `aws_secret_access_key` and optional `aws_session_token` keys.
AccessKey(Credentials),
/// Built when a profile sets both `role_arn` and `web_identity_token_file`.
WebIdentityTokenRole {
/// Value of the `role_arn` key.
role_arn: &'a str,
/// Value of the `web_identity_token_file` key.
web_identity_token_file: &'a str,
/// Value of the `role_session_name` key, if set.
session_name: Option<&'a str>,
},
/// Built when a profile sets at least one `sso_*` key and the configuration validates.
Sso {
/// Value of the profile's `sso_session` key, if set.
sso_session_name: Option<&'a str>,
/// `sso_region`, read from the named SSO session when `sso_session` is set and from
/// the profile otherwise.
sso_region: &'a str,
/// `sso_start_url`, read from the named SSO session when `sso_session` is set and
/// from the profile otherwise.
sso_start_url: &'a str,
/// Value of `sso_account_id`; set if and only if `sso_role_name` is set.
sso_account_id: Option<&'a str>,
/// Value of `sso_role_name`; set if and only if `sso_account_id` is set.
sso_role_name: Option<&'a str>,
},
/// Value of the profile's `credential_process` key, wrapped in `CommandWithSensitiveArgs`.
CredentialProcess(CommandWithSensitiveArgs<&'a str>),
}
/// Role description collected from a profile that sets `role_arn` (and does not set
/// `web_identity_token_file`).
#[derive(Debug)]
pub(crate) struct RoleArn<'a> {
/// Value of the profile's `role_arn` key.
pub(crate) role_arn: &'a str,
/// Value of the profile's `external_id` key, if set.
pub(crate) external_id: Option<&'a str>,
/// Value of the profile's `role_session_name` key, if set.
pub(crate) session_name: Option<&'a str>,
}
/// Resolves the selected profile of `profile_set` into a [`ProfileChain`].
///
/// Starting at `profile_set.selected_profile()`, every profile that yields a role (see
/// `role_arn_from_profile`) contributes that role to the chain and names the next profile
/// to visit. The walk stops at the first profile that acts as a base provider.
///
/// # Errors
///
/// - [`ProfileFileError::NoProfilesDefined`] if `profile_set` is empty, or if the selected
///   profile is `default` and no `default` profile exists.
/// - [`ProfileFileError::MissingProfile`] if a profile named during the walk does not exist.
/// - [`ProfileFileError::CredentialLoop`] if a profile is reached a second time.
/// - Any error from `chain_provider` or `base_provider`. Base-provider errors raised for a
///   profile other than the first one are wrapped in
///   [`ProfileFileError::InvalidCredentialSource`].
pub(crate) fn resolve_chain(
    profile_set: &ProfileSet,
) -> Result<ProfileChain<'_>, ProfileFileError> {
    if profile_set.is_empty() {
        return Err(ProfileFileError::NoProfilesDefined);
    }
    // Selecting `default` when no `default` profile exists is reported the same way as
    // having no profiles at all.
    if profile_set.selected_profile() == "default" && profile_set.get_profile("default").is_none() {
        tracing::debug!("No default profile defined");
        return Err(ProfileFileError::NoProfilesDefined);
    }
    let mut source_profile_name = profile_set.selected_profile();
    let mut visited_profiles: Vec<&str> = Vec::new();
    let mut chain = vec![];
    let base = loop {
        // Build the error lazily: the message is only formatted when the lookup fails,
        // instead of allocating on every iteration of the walk.
        let profile = profile_set
            .get_profile(source_profile_name)
            .ok_or_else(|| ProfileFileError::MissingProfile {
                profile: source_profile_name.into(),
                message: format!(
                    "could not find source profile {} referenced from {}",
                    source_profile_name,
                    visited_profiles.last().unwrap_or(&"the root profile")
                )
                .into(),
            })?;
        // Reaching a profile a second time means the `source_profile` references form a loop.
        if visited_profiles.contains(&source_profile_name) {
            return Err(ProfileFileError::CredentialLoop {
                profiles: visited_profiles
                    .into_iter()
                    .map(|s| s.to_string())
                    .collect(),
                next: source_profile_name.to_string(),
            });
        }
        visited_profiles.push(source_profile_name);
        // For every profile after the first, static credentials take priority when present.
        // A failure here is deliberately ignored so the profile is examined further below.
        if visited_profiles.len() > 1 {
            let try_static = static_creds_from_profile(profile);
            if let Ok(static_credentials) = try_static {
                break BaseProvider::AccessKey(static_credentials);
            }
        }
        let next_profile = {
            // A profile that yields a role is a link in the chain; anything else must be
            // the base provider.
            if let Some(role_provider) = role_arn_from_profile(profile) {
                let next = chain_provider(profile)?;
                chain.push(role_provider);
                next
            } else {
                break base_provider(profile_set, profile).map_err(|err| {
                    // Errors for the first profile are surfaced unchanged; errors for a
                    // later (source) profile are reported as an invalid credential source.
                    if visited_profiles.len() == 1 {
                        err
                    } else {
                        ProfileFileError::InvalidCredentialSource {
                            profile: profile.name().into(),
                            message: format!("could not load source profile: {}", err).into(),
                        }
                    }
                })?;
            }
        };
        match next_profile {
            NextProfile::SelfReference => {
                // Looping again on the same profile would trip the loop check above, so
                // load this profile as the base provider and stop.
                break base_provider(profile_set, profile)?;
            }
            NextProfile::Named(name) => source_profile_name = name,
        }
    };
    // Roles were pushed starting from the selected profile; store them in the opposite order.
    chain.reverse();
    Ok(ProfileChain { base, chain })
}
/// Profile keys read when resolving roles and the source of their credentials.
mod role {
    /// Key naming the profile that supplies source credentials.
    pub(super) const SOURCE_PROFILE: &str = "source_profile";
    /// Key naming a credential source.
    pub(super) const CREDENTIAL_SOURCE: &str = "credential_source";

    /// Key holding the role ARN.
    pub(super) const ROLE_ARN: &str = "role_arn";
    /// Key holding the role session name.
    pub(super) const SESSION_NAME: &str = "role_session_name";
    /// Key holding the external ID.
    pub(super) const EXTERNAL_ID: &str = "external_id";
}
/// Profile and SSO-session keys read when building an SSO provider.
mod sso {
    /// Key naming the SSO session to use.
    pub(super) const SESSION_NAME: &str = "sso_session";

    /// Key holding the SSO start URL.
    pub(super) const START_URL: &str = "sso_start_url";
    /// Key holding the SSO region.
    pub(super) const REGION: &str = "sso_region";

    /// Key holding the SSO account ID.
    pub(super) const ACCOUNT_ID: &str = "sso_account_id";
    /// Key holding the SSO role name.
    pub(super) const ROLE_NAME: &str = "sso_role_name";
}
/// Profile key read when building a web identity token provider.
mod web_identity_token {
/// Key holding the web identity token file value.
pub(super) const TOKEN_FILE: &str = "web_identity_token_file";
}
/// Profile keys read when building static credentials.
mod static_credentials {
    /// Key holding the access key ID (required once any static credential key is set).
    pub(super) const AWS_ACCESS_KEY_ID: &str = "aws_access_key_id";

    /// Key holding the secret access key (required once any static credential key is set).
    pub(super) const AWS_SECRET_ACCESS_KEY: &str = "aws_secret_access_key";

    /// Key holding the optional session token.
    pub(super) const AWS_SESSION_TOKEN: &str = "aws_session_token";
}
/// Profile key read when building a credential-process provider.
mod credential_process {
/// Key holding the command string passed to `CommandWithSensitiveArgs::new`.
pub(super) const CREDENTIAL_PROCESS: &str = "credential_process";
}
/// Provider name passed to `Credentials::new` for credentials built from static profile keys.
const PROVIDER_NAME: &str = "ProfileFile";
/// Selects the base provider described by `profile`.
///
/// Candidates are tried in this order, and the first one that applies wins:
/// `credential_source`, web identity token, SSO, credential process, static credentials.
///
/// # Errors
///
/// Returns the error of the first applicable candidate whose configuration is invalid, or
/// the error from `static_creds_from_profile` when no other candidate applies.
fn base_provider<'a>(
    profile_set: &'a ProfileSet,
    profile: &'a Profile,
) -> Result<BaseProvider<'a>, ProfileFileError> {
    if let Some(source) = profile.get(role::CREDENTIAL_SOURCE) {
        return Ok(BaseProvider::NamedSource(source));
    }
    if let Some(web_identity) = web_identity_token_from_profile(profile) {
        return web_identity;
    }
    // An SSO configuration error ends the search; `Ok(None)` means "no SSO keys, keep going".
    if let Some(sso_provider) = sso_from_profile(profile_set, profile)? {
        return Ok(sso_provider);
    }
    if let Some(process) = credential_process_from_profile(profile) {
        return process;
    }
    static_creds_from_profile(profile).map(BaseProvider::AccessKey)
}
/// Where `resolve_chain` should look next after collecting a role from a profile.
enum NextProfile<'a> {
/// The current profile itself supplies the base provider: either its `source_profile`
/// names itself, or it sets `credential_source`.
SelfReference,
/// Continue with the profile of this name (the value of `source_profile`).
Named(&'a str),
}
/// Determines which profile follows `profile` in the chain.
///
/// Exactly one of `source_profile` and `credential_source` must be set. A `source_profile`
/// equal to the profile's own name, or any `credential_source`, yields
/// [`NextProfile::SelfReference`]; any other `source_profile` yields [`NextProfile::Named`].
///
/// # Errors
///
/// Returns [`ProfileFileError::InvalidCredentialSource`] when both keys or neither key is set.
fn chain_provider(profile: &Profile) -> Result<NextProfile<'_>, ProfileFileError> {
    let invalid_source = |message: &'static str| ProfileFileError::InvalidCredentialSource {
        profile: profile.name().to_string(),
        message: message.into(),
    };
    let source_profile = profile.get(role::SOURCE_PROFILE);
    let has_credential_source = profile.get(role::CREDENTIAL_SOURCE).is_some();

    match source_profile {
        Some(_) if has_credential_source => Err(invalid_source(
            "profile contained both source_profile and credential_source. \
             Only one or the other can be defined",
        )),
        Some(name) if name == profile.name() => Ok(NextProfile::SelfReference),
        Some(name) => Ok(NextProfile::Named(name)),
        None if has_credential_source => Ok(NextProfile::SelfReference),
        None => Err(invalid_source(
            "profile must contain `source_profile` or `credential_source` but neither were defined",
        )),
    }
}
/// Extracts the role described by `profile`, if it describes one.
///
/// Returns `None` when `web_identity_token_file` is set (such a profile is handled as a web
/// identity base provider instead) or when `role_arn` is absent.
fn role_arn_from_profile(profile: &Profile) -> Option<RoleArn<'_>> {
    match profile.get(web_identity_token::TOKEN_FILE) {
        Some(_) => None,
        None => profile.get(role::ROLE_ARN).map(|role_arn| RoleArn {
            role_arn,
            session_name: profile.get(role::SESSION_NAME),
            external_id: profile.get(role::EXTERNAL_ID),
        }),
    }
}
/// Builds an SSO base provider from `profile`, if the profile sets any SSO key.
///
/// Returns `Ok(None)` when none of `sso_account_id`, `sso_region`, `sso_role_name`,
/// `sso_start_url` or `sso_session` is present. When `sso_session` is set, the start URL and
/// region are read from that session in `profile_set` rather than from the profile.
///
/// # Errors
///
/// - [`ProfileFileError::InvalidSsoConfig`] if `sso_session` is combined with a profile-level
///   `sso_start_url` or `sso_region`, or if only one of `sso_account_id` and `sso_role_name`
///   is set.
/// - [`ProfileFileError::MissingSsoSession`] if the named session is not in `profile_set`.
/// - The error built by `ProfileFileError::missing_field` if the region or start URL is
///   still unavailable after the steps above.
fn sso_from_profile<'a>(
    profile_set: &'a ProfileSet,
    profile: &'a Profile,
) -> Result<Option<BaseProvider<'a>>, ProfileFileError> {
    let sso_account_id = profile.get(sso::ACCOUNT_ID);
    let profile_region = profile.get(sso::REGION);
    let sso_role_name = profile.get(sso::ROLE_NAME);
    let profile_start_url = profile.get(sso::START_URL);
    let sso_session_name = profile.get(sso::SESSION_NAME);

    let no_sso_keys = sso_account_id
        .or(profile_region)
        .or(sso_role_name)
        .or(profile_start_url)
        .or(sso_session_name)
        .is_none();
    if no_sso_keys {
        return Ok(None);
    }

    let (sso_start_url, sso_region) = match sso_session_name {
        None => (profile_start_url, profile_region),
        Some(session_name) => {
            // With a session name, these two keys belong to the session, not the profile.
            // The start URL is checked before the region.
            let misplaced_key = if profile_start_url.is_some() {
                Some(sso::START_URL)
            } else if profile_region.is_some() {
                Some(sso::REGION)
            } else {
                None
            };
            if let Some(s) = misplaced_key {
                return Err(ProfileFileError::InvalidSsoConfig {
                    profile: profile.name().into(),
                    message: format!(
                        "`{s}` can only be specified in the [sso-session] config when a session name is given"
                    )
                    .into(),
                });
            }
            let session = profile_set.sso_session(session_name).ok_or_else(|| {
                ProfileFileError::MissingSsoSession {
                    profile: profile.name().into(),
                    sso_session: session_name.into(),
                }
            })?;
            let session_start_url = session.get(sso::START_URL);
            let session_region = session.get(sso::REGION);
            (session_start_url, session_region)
        }
    };

    // The account ID and role name must be given together or not at all.
    let unpaired = match (sso_account_id, sso_role_name) {
        (Some(_), None) => Some((sso::ACCOUNT_ID, sso::ROLE_NAME)),
        (None, Some(_)) => Some((sso::ROLE_NAME, sso::ACCOUNT_ID)),
        _ => None,
    };
    if let Some((left, right)) = unpaired {
        return Err(ProfileFileError::InvalidSsoConfig {
            profile: profile.name().into(),
            message: format!("if `{left}` is set, then `{right}` must also be set").into(),
        });
    }

    // The region is validated before the start URL.
    let sso_region =
        sso_region.ok_or_else(|| ProfileFileError::missing_field(profile, sso::REGION))?;
    let sso_start_url =
        sso_start_url.ok_or_else(|| ProfileFileError::missing_field(profile, sso::START_URL))?;

    Ok(Some(BaseProvider::Sso {
        sso_account_id,
        sso_region,
        sso_role_name,
        sso_start_url,
        sso_session_name,
    }))
}
/// Builds a web identity token base provider from `profile`.
///
/// Returns `None` when `web_identity_token_file` is absent (whether or not `role_arn` is
/// set), `Some(Ok(..))` when both `web_identity_token_file` and `role_arn` are set, and
/// `Some(Err(..))` with [`ProfileFileError::InvalidCredentialSource`] when the token file is
/// set without a `role_arn`.
fn web_identity_token_from_profile(
    profile: &Profile,
) -> Option<Result<BaseProvider<'_>, ProfileFileError>> {
    let session_name = profile.get(role::SESSION_NAME);
    let role_arn = profile.get(role::ROLE_ARN);
    // Without a token file this is not a web identity profile at all.
    let web_identity_token_file = profile.get(web_identity_token::TOKEN_FILE)?;

    let provider = role_arn
        .map(|role_arn| BaseProvider::WebIdentityTokenRole {
            role_arn,
            web_identity_token_file,
            session_name,
        })
        .ok_or_else(|| ProfileFileError::InvalidCredentialSource {
            profile: profile.name().to_string(),
            message: "`web_identity_token_file` was specified but `role_arn` was missing".into(),
        });
    Some(provider)
}
fn static_creds_from_profile(profile: &Profile) -> Result<Credentials, ProfileFileError> {
use static_credentials::*;
let access_key = profile.get(AWS_ACCESS_KEY_ID);
let secret_key = profile.get(AWS_SECRET_ACCESS_KEY);
let session_token = profile.get(AWS_SESSION_TOKEN);
if let (None, None, None) = (access_key, secret_key, session_token) {
return Err(ProfileFileError::ProfileDidNotContainCredentials {
profile: profile.name().to_string(),
});
}
let access_key = access_key.ok_or_else(|| ProfileFileError::InvalidCredentialSource {
profile: profile.name().to_string(),
message: "profile missing aws_access_key_id".into(),
})?;
let secret_key = secret_key.ok_or_else(|| ProfileFileError::InvalidCredentialSource {
profile: profile.name().to_string(),
message: "profile missing aws_secret_access_key".into(),
})?;
Ok(Credentials::new(
access_key,
secret_key,
session_token.map(|s| s.to_string()),
None,
PROVIDER_NAME,
))
}
/// Builds a credential-process base provider from the `credential_process` key of
/// `profile`, or returns `None` when the key is absent. The `Some` value is always `Ok`.
fn credential_process_from_profile(
    profile: &Profile,
) -> Option<Result<BaseProvider<'_>, ProfileFileError>> {
    let command = profile.get(credential_process::CREDENTIAL_PROCESS)?;
    let provider = BaseProvider::CredentialProcess(CommandWithSensitiveArgs::new(command));
    Some(Ok(provider))
}
// Tests for chain resolution. The data-driven part needs the `test-util` feature; the
// `Debug` redaction test at the bottom always runs.
#[cfg(test)]
mod tests {
#[cfg(feature = "test-util")]
use super::ProfileChain;
use crate::profile::credentials::repr::BaseProvider;
use crate::sensitive_command::CommandWithSensitiveArgs;
use serde::Deserialize;
#[cfg(feature = "test-util")]
use std::collections::HashMap;
/// Loads every case from `./test-data/assume-role-tests.json` and runs `check` on each.
#[cfg(feature = "test-util")]
#[test]
fn run_test_cases() -> Result<(), Box<dyn std::error::Error>> {
let test_cases: Vec<TestCase> = serde_json::from_str(&std::fs::read_to_string(
"./test-data/assume-role-tests.json",
)?)?;
for test_case in test_cases {
// Print the case description first so a panic inside `check` can be attributed to it.
print!("checking: {}...", test_case.docs);
check(test_case);
println!("ok")
}
Ok(())
}
/// Runs `resolve_chain` on the input of `test_case` and compares the result with the
/// expected output, panicking on any mismatch.
#[cfg(feature = "test-util")]
fn check(test_case: TestCase) {
use super::resolve_chain;
use aws_runtime::env_config::property::Properties;
use aws_runtime::env_config::section::EnvConfigSections;
// NOTE(review): `resolve_chain` takes `&ProfileSet`; presumably `ProfileSet` is an alias
// of `EnvConfigSections` — confirm in `crate::profile`.
let source = EnvConfigSections::new(
test_case.input.profiles,
test_case.input.selected_profile,
test_case.input.sso_sessions,
Properties::new(),
);
let actual = resolve_chain(&source);
let expected = test_case.output;
match (expected, actual) {
// Expected errors only need to appear as a substring of the displayed error.
(TestOutput::Error(s), Err(e)) => assert!(
format!("{}", e).contains(&s),
"expected\n{}\nto contain\n{}\n",
e,
s
),
(TestOutput::ProfileChain(expected), Ok(actual)) => {
assert_eq!(to_test_output(actual), expected)
}
// Expected success but got an error, or the other way around.
(expected, actual) => panic!(
"error/success mismatch. Expected:\n {:?}\nActual:\n {:?}",
&expected, actual
),
}
}
/// One entry of the JSON test file.
#[derive(Deserialize)]
#[cfg(feature = "test-util")]
struct TestCase {
/// Description printed while the case runs.
docs: String,
/// Profile configuration handed to `resolve_chain`.
input: TestInput,
/// Expected chain or expected error text.
output: TestOutput,
}
/// Input section of a test case.
#[derive(Deserialize)]
#[cfg(feature = "test-util")]
struct TestInput {
/// Passed as the first argument of `EnvConfigSections::new`.
profiles: HashMap<String, HashMap<String, String>>,
/// Passed as the second argument of `EnvConfigSections::new`.
selected_profile: String,
/// Passed as the third argument of `EnvConfigSections::new`; empty when the JSON omits it.
#[serde(default)]
sso_sessions: HashMap<String, HashMap<String, String>>,
}
/// Flattens a resolved chain into the comparable form used by the test file: the base
/// provider first, followed by one `AssumeRole` entry per role in stored chain order.
#[cfg(feature = "test-util")]
fn to_test_output(profile_chain: ProfileChain<'_>) -> Vec<Provider> {
let mut output = vec![];
match profile_chain.base {
BaseProvider::NamedSource(name) => output.push(Provider::NamedSource(name.into())),
BaseProvider::AccessKey(creds) => output.push(Provider::AccessKey {
access_key_id: creds.access_key_id().into(),
secret_access_key: creds.secret_access_key().into(),
session_token: creds.session_token().map(|tok| tok.to_string()),
}),
// Compare against the unredacted command line rather than its `Debug` form.
BaseProvider::CredentialProcess(credential_process) => output.push(
Provider::CredentialProcess(credential_process.unredacted().into()),
),
BaseProvider::WebIdentityTokenRole {
role_arn,
web_identity_token_file,
session_name,
} => output.push(Provider::WebIdentityToken {
role_arn: role_arn.into(),
web_identity_token_file: web_identity_token_file.into(),
role_session_name: session_name.map(|sess| sess.to_string()),
}),
BaseProvider::Sso {
sso_region,
sso_start_url,
sso_session_name,
sso_account_id,
sso_role_name,
} => output.push(Provider::Sso {
sso_region: sso_region.into(),
sso_start_url: sso_start_url.into(),
sso_session: sso_session_name.map(|s| s.to_string()),
sso_account_id: sso_account_id.map(|s| s.to_string()),
sso_role_name: sso_role_name.map(|s| s.to_string()),
}),
};
for role in profile_chain.chain {
output.push(Provider::AssumeRole {
role_arn: role.role_arn.into(),
external_id: role.external_id.map(ToString::to_string),
role_session_name: role.session_name.map(ToString::to_string),
})
}
output
}
/// Expected result of a test case: a provider list or a substring of the error message.
// NOTE(review): `TestOutput` and `Provider` are not gated behind `test-util`, unlike the
// helpers in this module that use them — confirm whether that is intentional.
#[derive(Deserialize, Debug, PartialEq, Eq)]
enum TestOutput {
ProfileChain(Vec<Provider>),
Error(String),
}
/// Owned, comparable mirror of `BaseProvider` and `RoleArn`, deserialized from the test file.
#[derive(Deserialize, Debug, Eq, PartialEq)]
enum Provider {
AssumeRole {
role_arn: String,
external_id: Option<String>,
role_session_name: Option<String>,
},
AccessKey {
access_key_id: String,
secret_access_key: String,
session_token: Option<String>,
},
NamedSource(String),
CredentialProcess(String),
WebIdentityToken {
role_arn: String,
web_identity_token_file: String,
role_session_name: Option<String>,
},
Sso {
sso_region: String,
sso_start_url: String,
sso_session: Option<String>,
sso_account_id: Option<String>,
sso_role_name: Option<String>,
},
}
/// Checks the `Debug` output of `BaseProvider::CredentialProcess`: a bare program name is
/// printed as-is, while anything after the program name (separated by a space or a tab) is
/// replaced by a redaction marker.
#[test]
fn base_provider_process_credentials_args_redaction() {
assert_eq!(
"CredentialProcess(\"program\")",
format!(
"{:?}",
BaseProvider::CredentialProcess(CommandWithSensitiveArgs::new("program"))
)
);
assert_eq!(
"CredentialProcess(\"program ** arguments redacted **\")",
format!(
"{:?}",
BaseProvider::CredentialProcess(CommandWithSensitiveArgs::new("program arg1 arg2"))
)
);
assert_eq!(
"CredentialProcess(\"program ** arguments redacted **\")",
format!(
"{:?}",
BaseProvider::CredentialProcess(CommandWithSensitiveArgs::new(
"program\targ1 arg2"
))
)
);
}
}