aws_smithy_runtime/
expiring_cache.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::{OnceCell, RwLock};
11
12/// Expiry-aware cache
13///
14/// [`ExpiringCache`] implements two important features:
15/// 1. Respect expiry of contents
16/// 2. Deduplicate load requests to prevent thundering herds when no value is present.
17#[derive(Debug)]
18pub struct ExpiringCache<T, E> {
19    /// Amount of time before the actual expiration time
20    /// when the value is considered expired.
21    buffer_time: Duration,
22    value: Arc<RwLock<OnceCell<(T, SystemTime)>>>,
23    _phantom: PhantomData<E>,
24}
25
26impl<T, E> Clone for ExpiringCache<T, E> {
27    fn clone(&self) -> Self {
28        Self {
29            buffer_time: self.buffer_time,
30            value: self.value.clone(),
31            _phantom: Default::default(),
32        }
33    }
34}
35
36impl<T, E> ExpiringCache<T, E>
37where
38    T: Clone,
39{
40    /// Creates `ExpiringCache` with the given `buffer_time`.
41    pub fn new(buffer_time: Duration) -> Self {
42        ExpiringCache {
43            buffer_time,
44            value: Arc::new(RwLock::new(OnceCell::new())),
45            _phantom: Default::default(),
46        }
47    }
48
49    #[cfg(all(test, feature = "client", feature = "http-auth"))]
50    async fn get(&self) -> Option<T>
51    where
52        T: Clone,
53    {
54        self.value
55            .read()
56            .await
57            .get()
58            .cloned()
59            .map(|(creds, _expiry)| creds)
60    }
61
62    /// Attempts to refresh the cached value with the given future.
63    /// If multiple threads attempt to refresh at the same time, one of them will win,
64    /// and the others will await that thread's result rather than multiple refreshes occurring.
65    /// The function given to acquire a value future, `f`, will not be called
66    /// if another thread is chosen to load the value.
67    pub async fn get_or_load<F, Fut>(&self, f: F) -> Result<T, E>
68    where
69        F: FnOnce() -> Fut,
70        Fut: Future<Output = Result<(T, SystemTime), E>>,
71    {
72        let lock = self.value.read().await;
73        let future = lock.get_or_try_init(f);
74        future.await.map(|(value, _expiry)| value.clone())
75    }
76
77    /// If the value is expired, clears the cache. Otherwise, yields the current value.
78    pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<T> {
79        // Short-circuit if the value is not expired
80        if let Some((value, expiry)) = self.value.read().await.get() {
81            if !expired(*expiry, self.buffer_time, now) {
82                return Some(value.clone());
83            } else {
84                tracing::debug!(expiry = ?expiry, delta= ?now.duration_since(*expiry), "An item existed but it expired.")
85            }
86        }
87
88        // Acquire a write lock to clear the cache, but then once the lock is acquired,
89        // check again that the value is not already cleared. If it has been cleared,
90        // then another thread is refreshing the cache by the time the write lock was acquired.
91        let mut lock = self.value.write().await;
92        if let Some((_value, expiration)) = lock.get() {
93            // Also check that we're clearing the expired value and not a value
94            // that has been refreshed by another thread.
95            if expired(*expiration, self.buffer_time, now) {
96                *lock = OnceCell::new();
97            }
98        }
99        None
100    }
101}
102
/// Returns `true` if a value that expires at `expiration` should be treated as expired at `now`.
///
/// The value counts as expired once `now` reaches `expiration - buffer_time`, which leaves
/// room to refresh it before it actually lapses. If `expiration - buffer_time` cannot be
/// represented as a `SystemTime` (e.g. an enormous `buffer_time`), the refresh point lies
/// before any representable `now`. In that case the value is reported as expired rather than
/// panicking.
fn expired(expiration: SystemTime, buffer_time: Duration, now: SystemTime) -> bool {
    match expiration.checked_sub(buffer_time) {
        Some(refresh_at) => now >= refresh_at,
        // Plain `SystemTime - Duration` panics on underflow; an unrepresentable refresh
        // point has necessarily already passed.
        None => true,
    }
}
106
#[cfg(all(test, feature = "client", feature = "http-auth"))]
mod tests {
    use super::{expired, ExpiringCache};
    use aws_smithy_runtime_api::box_error::BoxError;
    use aws_smithy_runtime_api::client::identity::http::Token;
    use aws_smithy_runtime_api::client::identity::Identity;
    use std::time::{Duration, SystemTime};
    use tracing_test::traced_test;

    /// `secs` seconds after the Unix epoch.
    fn epoch_secs(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    /// Builds a token identity expiring `expired_secs` after the epoch, paired with that expiry.
    fn identity(expired_secs: u64) -> Result<(Identity, SystemTime), BoxError> {
        let expires_at = epoch_secs(expired_secs);
        let token = Token::new("test", Some(expires_at));
        Ok((Identity::new(token, Some(expires_at)), expires_at))
    }

    #[test]
    fn expired_check() {
        let expiration = epoch_secs(100);
        let buffer = Duration::from_secs(10);
        // Well past the expiration.
        assert!(expired(expiration, buffer, epoch_secs(1000)));
        // Exactly where the buffer window begins.
        assert!(expired(expiration, buffer, epoch_secs(90)));
        // Long before the buffer window begins.
        assert!(!expired(expiration, buffer, epoch_secs(10)));
    }

    #[traced_test]
    #[tokio::test]
    async fn cache_clears_if_expired_only() {
        let cache = ExpiringCache::new(Duration::from_secs(10));

        // An empty cache has nothing to yield.
        let initial = cache.yield_or_clear_if_expired(epoch_secs(100)).await;
        assert!(initial.is_none());

        cache.get_or_load(|| async { identity(100) }).await.unwrap();
        let loaded = cache.get().await.unwrap();
        assert_eq!(Some(epoch_secs(100)), loaded.expiration());

        // A value that has not expired is yielded and left in place.
        let fresh = cache
            .yield_or_clear_if_expired(epoch_secs(10))
            .await
            .unwrap();
        assert_eq!(Some(epoch_secs(100)), fresh.expiration());

        // An expired value is cleared from the cache.
        let stale = cache.yield_or_clear_if_expired(epoch_secs(500)).await;
        assert!(stale.is_none());
        assert!(cache.get().await.is_none());
    }
}