// aws_smithy_runtime/client/waiters/backoff.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

6use std::{fmt, time::Duration};
7
8#[derive(Debug)]
9pub(super) struct Backoff {
10    min_delay: Duration,
11    max_delay: Duration,
12    max_wait: Duration,
13    attempt_ceiling: u32,
14    random: RandomImpl,
15}
16
17impl Backoff {
18    pub(super) fn new(
19        min_delay: Duration,
20        max_delay: Duration,
21        max_wait: Duration,
22        random: RandomImpl,
23    ) -> Self {
24        Self {
25            min_delay,
26            max_delay,
27            max_wait,
28            // Attempt ceiling calculation taken from the Smithy spec: https://smithy.io/2.0/additional-specs/waiters.html#waiter-retries
29            attempt_ceiling: (((max_delay.as_secs_f64() / min_delay.as_secs_f64()).ln()
30                / 2f64.ln())
31                + 1.0) as u32,
32            random,
33        }
34    }
35
36    // Calculates backoff delay time according to the Smithy spec: https://smithy.io/2.0/additional-specs/waiters.html#waiter-retries
37    pub(super) fn delay(&self, attempt: u32, elapsed: Duration) -> Duration {
38        let delay = if attempt > self.attempt_ceiling {
39            self.max_delay.as_secs()
40        } else {
41            self.min_delay.as_secs() * 2u64.pow(attempt - 1)
42        };
43        let mut delay = Duration::from_secs(self.random.random(self.min_delay.as_secs(), delay));
44
45        let remaining_time = self.max_wait.saturating_sub(elapsed);
46        if remaining_time.saturating_sub(delay) <= self.min_delay {
47            // Note: deviating from the spec here. Subtracting `min_delay` doesn't fulfill the original intent.
48            delay = remaining_time;
49        }
50        delay
51    }
52
53    #[inline]
54    pub(super) fn max_wait(&self) -> Duration {
55        self.max_wait
56    }
57}
58
59#[derive(Default)]
60pub(super) enum RandomImpl {
61    #[default]
62    Default,
63    #[cfg(test)]
64    Override(Box<dyn Fn(u64, u64) -> u64 + Send + Sync>),
65}
66
67impl fmt::Debug for RandomImpl {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        match self {
70            Self::Default => write!(f, "Default"),
71            #[cfg(test)]
72            Self::Override(_) => f.debug_tuple("Override").field(&"** function **").finish(),
73        }
74    }
75}
76
77impl RandomImpl {
78    fn random(&self, min_inclusive: u64, max_inclusive: u64) -> u64 {
79        match self {
80            Self::Default => fastrand::u64(min_inclusive..=max_inclusive),
81            #[cfg(test)]
82            Self::Override(overrid) => (overrid)(min_inclusive, max_inclusive),
83        }
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Deterministic stand-in for jitter: always the midpoint of the range, rounded down.
    fn midpoint(min: u64, max: u64) -> u64 {
        (min + max) / 2
    }

    /// Drives a `Backoff` with a fixed 300 second `max_wait` through consecutive attempts.
    ///
    /// `min_delay` and `max_delay` are in seconds. Entry `i` of `attempt_delays` holds
    /// `(expected delay, elapsed time)` in seconds for attempt `i + 1`.
    fn test_backoff(
        min_delay: u64,
        max_delay: u64,
        test_random: impl Fn(u64, u64) -> u64 + Send + Sync + 'static,
        attempt_delays: &[(u64, u64)],
    ) {
        let backoff = dbg!(Backoff::new(
            Duration::from_secs(min_delay),
            Duration::from_secs(max_delay),
            Duration::from_secs(300),
            RandomImpl::Override(Box::new(test_random)),
        ));

        let attempts = 1u32..;
        for (attempt, &(delay, time)) in attempts.zip(attempt_delays) {
            println!("attempt: {attempt}, delay: {delay}, time: {time}");
            let actual = backoff.delay(attempt, Duration::from_secs(time));
            assert_eq!(Duration::from_secs(delay), actual);
        }
    }

    #[test]
    fn backoff_jitter_as_average() {
        test_backoff(
            2,
            120,
            midpoint,
            &[
                // delay, time
                (2, 2),
                (3, 4),
                (5, 7),
                (9, 12),
                (17, 21),
                (33, 38),
                (61, 71),
                (61, 132),
                (61, 193),
                (46, 254),
                (0, 300),
            ],
        );
    }

    #[test]
    fn backoff_with_seeded_jitter() {
        let rng = Mutex::new(fastrand::Rng::with_seed(1));
        let seeded = move |min: u64, max: u64| rng.lock().unwrap().u64(min..=max);
        test_backoff(
            2,
            120,
            seeded,
            &[
                // delay, time
                (2, 2),
                (3, 4),
                (3, 7),
                (13, 12),
                (2, 14),
                (51, 16),
                (93, 73),
                (102, 164),
                (73, 170),
                (21, 227),
                (9, 256),
                (17, 283),
                (0, 300),
            ],
        );
    }

    #[test]
    fn backoff_with_large_min_delay() {
        test_backoff(
            15,
            120,
            midpoint,
            &[
                // delay, time
                (15, 1),
                (22, 16),
                (37, 38),
                (67, 75),
                (67, 142),
                (24, 276),
                (0, 300),
            ],
        );
    }
}