// aws_smithy_runtime/client/retries/token_bucket.rs
1use aws_smithy_types::config_bag::{Storable, StoreReplace};
7use aws_smithy_types::retry::ErrorKind;
8use std::sync::Arc;
9use tokio::sync::{OwnedSemaphorePermit, Semaphore};
10use tracing::trace;
11
/// Permits a default bucket starts with; also the ceiling that regeneration refills to.
const DEFAULT_CAPACITY: usize = 500;
/// Permits taken to retry any error other than `ErrorKind::TransientError`.
const RETRY_COST: u32 = 5;
/// Permits taken to retry an `ErrorKind::TransientError` (double the normal cost).
const RETRY_TIMEOUT_COST: u32 = 2 * RETRY_COST;
/// Permits handed back to the bucket by each call to `regenerate_a_token`.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
16
17#[derive(Clone, Debug)]
19pub struct TokenBucket {
20 semaphore: Arc<Semaphore>,
21 max_permits: usize,
22 timeout_retry_cost: u32,
23 retry_cost: u32,
24}
25
26impl Storable for TokenBucket {
27 type Storer = StoreReplace<Self>;
28}
29
30impl Default for TokenBucket {
31 fn default() -> Self {
32 Self {
33 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
34 max_permits: DEFAULT_CAPACITY,
35 timeout_retry_cost: RETRY_TIMEOUT_COST,
36 retry_cost: RETRY_COST,
37 }
38 }
39}
40
41impl TokenBucket {
42 pub fn new(initial_quota: usize) -> Self {
44 Self {
45 semaphore: Arc::new(Semaphore::new(initial_quota)),
46 max_permits: initial_quota,
47 retry_cost: RETRY_COST,
48 timeout_retry_cost: RETRY_TIMEOUT_COST,
49 }
50 }
51
52 pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
53 let retry_cost = if err == &ErrorKind::TransientError {
54 self.timeout_retry_cost
55 } else {
56 self.retry_cost
57 };
58
59 self.semaphore
60 .clone()
61 .try_acquire_many_owned(retry_cost)
62 .ok()
63 }
64
65 pub(crate) fn regenerate_a_token(&self) {
66 if self.semaphore.available_permits() < (self.max_permits) {
67 trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
68 self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
69 }
70 }
71
72 #[cfg(all(test, feature = "test-util"))]
73 pub(crate) fn available_permits(&self) -> usize {
74 self.semaphore.available_permits()
75 }
76}