1use crate::expiring_cache::ExpiringCache;
7use aws_smithy_async::future::timeout::Timeout;
8use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
9use aws_smithy_async::time::{SharedTimeSource, TimeSource};
10use aws_smithy_runtime_api::box_error::BoxError;
11use aws_smithy_runtime_api::client::identity::{
12 Identity, IdentityCachePartition, IdentityFuture, ResolveCachedIdentity, ResolveIdentity,
13 SharedIdentityCache, SharedIdentityResolver,
14};
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_runtime_api::shared::IntoShared;
17use aws_smithy_types::config_bag::ConfigBag;
18use aws_smithy_types::DateTime;
19use std::collections::HashMap;
20use std::fmt;
21use std::sync::RwLock;
22use std::time::Duration;
23use tracing::Instrument;
24
25const DEFAULT_LOAD_TIMEOUT: Duration = Duration::from_secs(5);
26const DEFAULT_EXPIRATION: Duration = Duration::from_secs(15 * 60);
27const DEFAULT_BUFFER_TIME: Duration = Duration::from_secs(10);
28const DEFAULT_BUFFER_TIME_JITTER_FRACTION: fn() -> f64 = || fastrand::f64() * 0.5;
29
30#[derive(Default, Debug)]
32pub struct LazyCacheBuilder {
33 time_source: Option<SharedTimeSource>,
34 sleep_impl: Option<SharedAsyncSleep>,
35 load_timeout: Option<Duration>,
36 buffer_time: Option<Duration>,
37 buffer_time_jitter_fraction: Option<fn() -> f64>,
38 default_expiration: Option<Duration>,
39}
40
41impl LazyCacheBuilder {
42 pub fn new() -> Self {
44 Default::default()
45 }
46
47 pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
49 self.set_time_source(time_source.into_shared());
50 self
51 }
52 pub fn set_time_source(&mut self, time_source: SharedTimeSource) -> &mut Self {
54 self.time_source = Some(time_source.into_shared());
55 self
56 }
57
58 pub fn sleep_impl(mut self, sleep_impl: impl AsyncSleep + 'static) -> Self {
60 self.set_sleep_impl(sleep_impl.into_shared());
61 self
62 }
63 pub fn set_sleep_impl(&mut self, sleep_impl: SharedAsyncSleep) -> &mut Self {
65 self.sleep_impl = Some(sleep_impl);
66 self
67 }
68
69 pub fn load_timeout(mut self, timeout: Duration) -> Self {
73 self.set_load_timeout(Some(timeout));
74 self
75 }
76
77 pub fn set_load_timeout(&mut self, timeout: Option<Duration>) -> &mut Self {
81 self.load_timeout = timeout;
82 self
83 }
84
85 pub fn buffer_time(mut self, buffer_time: Duration) -> Self {
94 self.set_buffer_time(Some(buffer_time));
95 self
96 }
97
98 pub fn set_buffer_time(&mut self, buffer_time: Option<Duration>) -> &mut Self {
107 self.buffer_time = buffer_time;
108 self
109 }
110
111 #[allow(unused)]
119 #[cfg(test)]
120 fn buffer_time_jitter_fraction(mut self, buffer_time_jitter_fraction: fn() -> f64) -> Self {
121 self.set_buffer_time_jitter_fraction(Some(buffer_time_jitter_fraction));
122 self
123 }
124
125 #[allow(unused)]
133 #[cfg(test)]
134 fn set_buffer_time_jitter_fraction(
135 &mut self,
136 buffer_time_jitter_fraction: Option<fn() -> f64>,
137 ) -> &mut Self {
138 self.buffer_time_jitter_fraction = buffer_time_jitter_fraction;
139 self
140 }
141
142 pub fn default_expiration(mut self, duration: Duration) -> Self {
149 self.set_default_expiration(Some(duration));
150 self
151 }
152
153 pub fn set_default_expiration(&mut self, duration: Option<Duration>) -> &mut Self {
160 self.default_expiration = duration;
161 self
162 }
163
164 pub fn build(self) -> SharedIdentityCache {
170 let default_expiration = self.default_expiration.unwrap_or(DEFAULT_EXPIRATION);
171 assert!(
172 default_expiration >= DEFAULT_EXPIRATION,
173 "default_expiration must be at least 15 minutes"
174 );
175 LazyCache::new(
176 self.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT),
177 self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME),
178 self.buffer_time_jitter_fraction
179 .unwrap_or(DEFAULT_BUFFER_TIME_JITTER_FRACTION),
180 default_expiration,
181 )
182 .into_shared()
183 }
184}
185
186#[derive(Debug)]
187struct CachePartitions {
188 partitions: RwLock<HashMap<IdentityCachePartition, ExpiringCache<Identity, BoxError>>>,
189 buffer_time: Duration,
190}
191
192impl CachePartitions {
193 fn new(buffer_time: Duration) -> Self {
194 Self {
195 partitions: RwLock::new(HashMap::new()),
196 buffer_time,
197 }
198 }
199
200 fn partition(&self, key: IdentityCachePartition) -> ExpiringCache<Identity, BoxError> {
201 let mut partition = self.partitions.read().unwrap().get(&key).cloned();
202 if partition.is_none() {
205 let mut partitions = self.partitions.write().unwrap();
206 partitions
209 .entry(key)
210 .or_insert_with(|| ExpiringCache::new(self.buffer_time));
211 drop(partitions);
212
213 partition = self.partitions.read().unwrap().get(&key).cloned();
214 }
215 partition.expect("inserted above if not present")
216 }
217}
218
219#[derive(Debug)]
220struct LazyCache {
221 partitions: CachePartitions,
222 load_timeout: Duration,
223 buffer_time: Duration,
224 buffer_time_jitter_fraction: fn() -> f64,
225 default_expiration: Duration,
226}
227
228impl LazyCache {
229 fn new(
230 load_timeout: Duration,
231 buffer_time: Duration,
232 buffer_time_jitter_fraction: fn() -> f64,
233 default_expiration: Duration,
234 ) -> Self {
235 Self {
236 partitions: CachePartitions::new(buffer_time),
237 load_timeout,
238 buffer_time,
239 buffer_time_jitter_fraction,
240 default_expiration,
241 }
242 }
243}
244
/// Builds the `BoxError` reported when a runtime component needed by lazy caching is missing.
///
/// `$thing` names the missing component and `$how` tells the user how to configure it.
macro_rules! required_err {
    ($thing:literal, $how:literal) => {
        BoxError::from(concat!(
            "Lazy identity caching requires ",
            $thing,
            " to be configured. ",
            $how,
            " If this isn't possible, then disable identity caching by calling the `identity_cache` method on config with `IdentityCache::no_cache()`",
        ))
    };
}

/// Returns early from the enclosing function with an error if `$components` lacks a
/// time source or a sleep implementation. The time source is checked first.
macro_rules! validate_components {
    ($components:ident) => {
        if $components.time_source().is_none() {
            return Err(required_err!(
                "a time source",
                "Set a time source using the `time_source` method on config."
            ));
        }
        if $components.sleep_impl().is_none() {
            return Err(required_err!(
                "an async sleep implementation",
                "Set a sleep impl using the `sleep_impl` method on config."
            ));
        }
    };
}
273
impl ResolveCachedIdentity for LazyCache {
    /// Ensures the runtime components builder provides both a time source and a sleep
    /// implementation. `resolve_cached_identity` `expect`s both of them.
    fn validate_base_client_config(
        &self,
        runtime_components: &aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder,
        _cfg: &ConfigBag,
    ) -> Result<(), BoxError> {
        validate_components!(runtime_components);
        Ok(())
    }

    /// Same check as `validate_base_client_config`, applied to the final runtime components.
    fn validate_final_config(
        &self,
        runtime_components: &RuntimeComponents,
        _cfg: &ConfigBag,
    ) -> Result<(), BoxError> {
        validate_components!(runtime_components);
        Ok(())
    }

    /// Returns the identity cached in `resolver`'s partition, or loads a new one through
    /// `resolver` if the cache yields nothing for the current time.
    ///
    /// A load is bounded by `load_timeout`. On timeout, the resolver's
    /// `fallback_on_interrupt` identity is used if it has one; otherwise a `TimedOutError`
    /// is returned. Errors from the resolver itself are propagated. An identity with no
    /// expiration is given one `default_expiration` after the time this method was called.
    fn resolve_cached_identity<'a>(
        &'a self,
        resolver: SharedIdentityResolver,
        runtime_components: &'a RuntimeComponents,
        config_bag: &'a ConfigBag,
    ) -> IdentityFuture<'a> {
        // Presence of both components is checked by the `validate_*` hooks above.
        // NOTE(review): this assumes the runtime always runs those hooks before resolving.
        let (time_source, sleep_impl) = (
            runtime_components.time_source().expect("validated"),
            runtime_components.sleep_impl().expect("validated"),
        );

        let now = time_source.now();
        // The timeout future is created eagerly but is only awaited if a load happens.
        // NOTE(review): whether its deadline starts now or at first poll depends on the
        // `AsyncSleep` implementation.
        let timeout_future = sleep_impl.sleep(self.load_timeout);
        // Copy the values the inner `async move` blocks need into locals.
        let load_timeout = self.load_timeout;
        let partition = resolver.cache_partition();
        let cache = self.partitions.partition(partition);
        let default_expiration = self.default_expiration;

        IdentityFuture::new(async move {
            // Use the cached identity if it is still valid at `now`. Judging by its name,
            // this call also clears an expired entry.
            if let Some(identity) = cache.yield_or_clear_if_expired(now).await {
                tracing::debug!(
                    buffer_time=?self.buffer_time,
                    cached_expiration=?identity.expiration(),
                    now=?now,
                    "loaded identity from cache"
                );
                Ok(identity)
            } else {
                let start_time = time_source.now();
                // The closure builds the load future. Whether the cache actually runs it is
                // decided by `get_or_load`. The `load_contention` test relies on concurrent
                // callers not each triggering their own load.
                let result = cache
                    .get_or_load(|| {
                        let span = tracing::info_span!("lazy_load_identity");
                        async move {
                            let fut = Timeout::new(
                                resolver.resolve_identity(runtime_components, config_bag),
                                timeout_future,
                            );
                            let identity = match fut.await {
                                // The resolver finished in time; propagate its error, if any.
                                Ok(result) => result?,
                                // The timeout elapsed first: use the fallback identity, if any.
                                Err(_err) => match resolver.fallback_on_interrupt() {
                                    Some(identity) => identity,
                                    None => {
                                        return Err(BoxError::from(TimedOutError(load_timeout)))
                                    }
                                },
                            };
                            // Identities without their own expiration get a default one.
                            let expiration =
                                identity.expiration().unwrap_or(now + default_expiration);

                            // A random fraction of the buffer time is added to the stored
                            // expiry, which shortens the effective buffer for this entry
                            // (see the `buffer_time_jitter` test).
                            let jitter = self
                                .buffer_time
                                .mul_f64((self.buffer_time_jitter_fraction)());

                            // Logged inside the load future, so only the caller whose load
                            // actually ran emits the cache-miss message.
                            let printable = DateTime::from(expiration);
                            tracing::debug!(
                                new_expiration=%printable,
                                valid_for=?expiration.duration_since(time_source.now()).unwrap_or_default(),
                                partition=?partition,
                                "identity cache miss occurred; added new identity (took {:?})",
                                time_source.now().duration_since(start_time).unwrap_or_default()
                            );

                            Ok((identity, expiration + jitter))
                        }
                        // Only the load future is instrumented, so no span is entered if the
                        // cache does not run it.
                        .instrument(span)
                    })
                    .await;
                tracing::debug!("loaded identity");
                result
            }
        })
    }
}
378
/// Error returned when an identity load exceeds the configured load timeout and the
/// resolver offers no fallback identity. Holds the timeout that was exceeded.
#[derive(Debug)]
struct TimedOutError(Duration);

impl fmt::Display for TimedOutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TimedOutError(timeout) = self;
        write!(f, "identity resolver timed out after {timeout:?}")
    }
}

impl std::error::Error for TimedOutError {}
389
#[cfg(all(test, feature = "client", feature = "http-auth"))]
mod tests {
    use super::*;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::test_util::{instant_time_and_sleep, ManualTimeSource};
    use aws_smithy_async::time::TimeSource;
    use aws_smithy_runtime_api::client::identity::http::Token;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use tracing::info;

    /// Jitter function that always returns zero, making cache expiry deterministic.
    const BUFFER_TIME_NO_JITTER: fn() -> f64 = || 0_f64;

    /// Adapts a closure into an identity resolver.
    struct ResolverFn<F>(F);
    impl<F> fmt::Debug for ResolverFn<F> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ResolverFn")
        }
    }
    impl<F> ResolveIdentity for ResolverFn<F>
    where
        F: Fn() -> IdentityFuture<'static> + Send + Sync,
    {
        fn resolve_identity<'a>(
            &'a self,
            _: &'a RuntimeComponents,
            _config_bag: &'a ConfigBag,
        ) -> IdentityFuture<'a> {
            (self.0)()
        }
    }

    /// Wraps a closure in a `SharedIdentityResolver`. The result is already shared, so
    /// callers should not wrap it in `SharedIdentityResolver::new` again.
    fn resolver_fn<F>(f: F) -> SharedIdentityResolver
    where
        F: Fn() -> IdentityFuture<'static> + Send + Sync + 'static,
    {
        SharedIdentityResolver::new(ResolverFn(f))
    }

    /// Creates a cache with default settings (except for the jitter function) and a
    /// resolver that hands out `load_list` in order. The resolver panics once the list
    /// is exhausted.
    fn test_cache(
        buffer_time_jitter_fraction: fn() -> f64,
        load_list: Vec<Result<Identity, BoxError>>,
    ) -> (LazyCache, SharedIdentityResolver) {
        #[derive(Debug)]
        struct Resolver(Mutex<Vec<Result<Identity, BoxError>>>);
        impl ResolveIdentity for Resolver {
            fn resolve_identity<'a>(
                &'a self,
                _: &'a RuntimeComponents,
                _config_bag: &'a ConfigBag,
            ) -> IdentityFuture<'a> {
                let mut list = self.0.lock().unwrap();
                if !list.is_empty() {
                    let next = list.remove(0);
                    info!("refreshing the identity to {:?}", next);
                    IdentityFuture::ready(next)
                } else {
                    // Release the lock before panicking so the mutex is not poisoned.
                    drop(list);
                    panic!("no more identities")
                }
            }
        }

        let identity_resolver = SharedIdentityResolver::new(Resolver(Mutex::new(load_list)));
        let cache = LazyCache::new(
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_BUFFER_TIME,
            buffer_time_jitter_fraction,
            DEFAULT_EXPIRATION,
        );
        (cache, identity_resolver)
    }

    /// Returns the Unix epoch plus `secs` seconds.
    fn epoch_secs(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    /// Returns a token identity that expires `expired_secs` seconds after the epoch.
    fn test_identity(expired_secs: u64) -> Identity {
        let expiration = Some(epoch_secs(expired_secs));
        Identity::new(Token::new("test", expiration), expiration)
    }

    /// Resolves through `cache` and asserts that the identity expires at `expired_secs`.
    async fn expect_identity(
        expired_secs: u64,
        cache: &LazyCache,
        components: &RuntimeComponents,
        resolver: SharedIdentityResolver,
    ) {
        let config_bag = ConfigBag::base();
        let identity = cache
            .resolve_cached_identity(resolver, components, &config_bag)
            .await
            .expect("expected identity");
        assert_eq!(Some(epoch_secs(expired_secs)), identity.expiration());
    }

    #[tokio::test]
    async fn initial_populate_test_identity() {
        let time = ManualTimeSource::new(UNIX_EPOCH);
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let config_bag = ConfigBag::base();
        let resolver = resolver_fn(|| {
            info!("refreshing the test_identity");
            IdentityFuture::ready(Ok(test_identity(1000)))
        });
        let cache = LazyCache::new(
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_BUFFER_TIME,
            BUFFER_TIME_NO_JITTER,
            DEFAULT_EXPIRATION,
        );
        assert_eq!(
            epoch_secs(1000),
            cache
                .resolve_cached_identity(resolver, &components, &config_bag)
                .await
                .unwrap()
                .expiration()
                .unwrap()
        );
    }

    #[tokio::test]
    async fn reload_expired_test_identity() {
        let time = ManualTimeSource::new(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, resolver) = test_cache(
            BUFFER_TIME_NO_JITTER,
            vec![
                Ok(test_identity(1000)),
                Ok(test_identity(2000)),
                Ok(test_identity(3000)),
            ],
        );

        expect_identity(1000, &cache, &components, resolver.clone()).await;
        expect_identity(1000, &cache, &components, resolver.clone()).await;
        time.set_time(epoch_secs(1500));
        expect_identity(2000, &cache, &components, resolver.clone()).await;
        expect_identity(2000, &cache, &components, resolver.clone()).await;
        time.set_time(epoch_secs(2500));
        expect_identity(3000, &cache, &components, resolver.clone()).await;
        expect_identity(3000, &cache, &components, resolver.clone()).await;
    }

    #[tokio::test]
    async fn load_failed_error() {
        let config_bag = ConfigBag::base();
        let time = ManualTimeSource::new(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, resolver) = test_cache(
            BUFFER_TIME_NO_JITTER,
            vec![Ok(test_identity(1000)), Err("failed".into())],
        );

        expect_identity(1000, &cache, &components, resolver.clone()).await;
        time.set_time(epoch_secs(1500));
        assert!(cache
            .resolve_cached_identity(resolver.clone(), &components, &config_bag)
            .await
            .is_err());
    }

    #[test]
    fn load_contention() {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_time()
            .worker_threads(16)
            .build()
            .unwrap();

        let time = ManualTimeSource::new(epoch_secs(0));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, resolver) = test_cache(
            BUFFER_TIME_NO_JITTER,
            vec![
                Ok(test_identity(500)),
                Ok(test_identity(1500)),
                Ok(test_identity(2500)),
                Ok(test_identity(3500)),
                Ok(test_identity(4500)),
            ],
        );
        let cache: SharedIdentityCache = cache.into_shared();

        // Far more tasks than available identities: the resolver panics if each cache miss
        // triggers its own load instead of sharing one.
        for _ in 0..4 {
            let mut tasks = Vec::new();
            for _ in 0..50 {
                let resolver = resolver.clone();
                let cache = cache.clone();
                let time = time.clone();
                let components = components.clone();
                tasks.push(rt.spawn(async move {
                    let now = time.advance(Duration::from_secs(22));

                    let config_bag = ConfigBag::base();
                    let identity = cache
                        .resolve_cached_identity(resolver, &components, &config_bag)
                        .await
                        .unwrap();
                    assert!(
                        identity.expiration().unwrap() >= now,
                        "{:?} >= {:?}",
                        identity.expiration(),
                        now
                    );
                }));
            }
            for task in tasks {
                rt.block_on(task).unwrap();
            }
        }
    }

    #[tokio::test]
    async fn load_timeout() {
        let config_bag = ConfigBag::base();
        let (time, sleep) = instant_time_and_sleep(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(sleep))
            .build()
            .unwrap();
        // A resolver that never completes, so the load must time out.
        let resolver = resolver_fn(|| {
            IdentityFuture::new(async {
                aws_smithy_async::future::never::Never::new().await;
                Ok(test_identity(1000))
            })
        });
        let cache = LazyCache::new(
            Duration::from_secs(5),
            DEFAULT_BUFFER_TIME,
            BUFFER_TIME_NO_JITTER,
            DEFAULT_EXPIRATION,
        );

        let err: BoxError = cache
            .resolve_cached_identity(resolver, &components, &config_bag)
            .await
            .expect_err("it should return an error");
        let downcasted = err.downcast_ref::<TimedOutError>();
        assert!(
            downcasted.is_some(),
            "expected a BoxError of TimedOutError, but was {err:?}"
        );
        assert_eq!(time.now(), epoch_secs(105));
    }

    #[tokio::test]
    async fn buffer_time_jitter() {
        let time = ManualTimeSource::new(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let buffer_time_jitter_fraction = || 0.5_f64;
        let (cache, resolver) = test_cache(
            buffer_time_jitter_fraction,
            vec![Ok(test_identity(1000)), Ok(test_identity(2000))],
        );

        expect_identity(1000, &cache, &components, resolver.clone()).await;
        let buffer_time_with_jitter =
            (DEFAULT_BUFFER_TIME.as_secs_f64() * buffer_time_jitter_fraction()) as u64;
        assert_eq!(buffer_time_with_jitter, 5);
        // One second before the jittered refresh point, the cached identity is still used.
        let almost_expired_secs = 1000 - buffer_time_with_jitter - 1;
        time.set_time(epoch_secs(almost_expired_secs));
        expect_identity(1000, &cache, &components, resolver.clone()).await;
        // At the jittered refresh point, a new identity is loaded.
        let expired_secs = almost_expired_secs + 1;
        time.set_time(epoch_secs(expired_secs));
        expect_identity(2000, &cache, &components, resolver.clone()).await;
    }

    #[tokio::test]
    async fn cache_partitioning() {
        let time = ManualTimeSource::new(epoch_secs(0));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, _) = test_cache(BUFFER_TIME_NO_JITTER, Vec::new());

        #[allow(clippy::disallowed_methods)]
        let far_future = SystemTime::now() + Duration::from_secs(10_000);

        // Count resolver invocations to confirm that each partition loads exactly once.
        let resolver_a_calls = Arc::new(AtomicUsize::new(0));
        let resolver_b_calls = Arc::new(AtomicUsize::new(0));
        let resolver_a = resolver_fn({
            let calls = resolver_a_calls.clone();
            move || {
                calls.fetch_add(1, Ordering::Relaxed);
                IdentityFuture::ready(Ok(Identity::new(
                    Token::new("A", Some(far_future)),
                    Some(far_future),
                )))
            }
        });
        let resolver_b = resolver_fn({
            let calls = resolver_b_calls.clone();
            move || {
                calls.fetch_add(1, Ordering::Relaxed);
                IdentityFuture::ready(Ok(Identity::new(
                    Token::new("B", Some(far_future)),
                    Some(far_future),
                )))
            }
        });
        assert_ne!(
            resolver_a.cache_partition(),
            resolver_b.cache_partition(),
            "pre-condition: they should have different partition keys"
        );

        let config_bag = ConfigBag::base();

        // Populate partition A; the second call must be served from the cache.
        let identity = cache
            .resolve_cached_identity(resolver_a.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("A", identity.data::<Token>().unwrap().token());
        let identity = cache
            .resolve_cached_identity(resolver_a.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("A", identity.data::<Token>().unwrap().token());
        assert_eq!(1, resolver_a_calls.load(Ordering::Relaxed));

        // Partition B must load its own identity rather than reuse A's.
        let identity = cache
            .resolve_cached_identity(resolver_b.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("B", identity.data::<Token>().unwrap().token());
        let identity = cache
            .resolve_cached_identity(resolver_b.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("B", identity.data::<Token>().unwrap().token());
        assert_eq!(1, resolver_a_calls.load(Ordering::Relaxed));
        assert_eq!(1, resolver_b_calls.load(Ordering::Relaxed));

        // Loading B must not have evicted A.
        let identity = cache
            .resolve_cached_identity(resolver_a.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("A", identity.data::<Token>().unwrap().token());
        assert_eq!(1, resolver_a_calls.load(Ordering::Relaxed));
        assert_eq!(1, resolver_b_calls.load(Ordering::Relaxed));
    }
}