use crate::{
    error::{RpcErrorExt, TransportError, TransportErrorKind},
    TransportFut,
};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use core::fmt;
use std::{
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use tower::{Layer, Service};
use tracing::trace;

#[cfg(target_family = "wasm")]
use wasmtimer::tokio::sleep;

#[cfg(not(target_family = "wasm"))]
use tokio::time::sleep;
/// The default average cost of a request, in compute units.
const DEFAULT_AVG_COST: u64 = 20u64;

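/// A [`tower::Layer`] that adds rate-limit-aware retries with backoff to a
/// transport service.
///
/// Requests that fail with an error the configured [`RetryPolicy`] deems
/// retryable are re-sent up to `max_rate_limit_retries` times. The wait
/// between attempts combines the policy's backoff hint (or the configured
/// initial backoff) with an estimate of how long the compute-unit budget
/// needs to recover.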
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
    /// The maximum number of retries for rate limit errors.
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds.
    initial_backoff: u64,
    /// The number of compute units available per second.
    compute_units_per_second: u64,
    /// The average cost of a request, in compute units.
    avg_cost: u64,
    /// The policy that decides whether an error is retryable.
    policy: P,
}

impl RetryBackoffLayer {
    /// Creates a new [`RetryBackoffLayer`] with the given maximum number of
    /// retries, initial backoff (in milliseconds), and compute units per
    /// second.
    pub const fn new(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            avg_cost: DEFAULT_AVG_COST,
            policy: RateLimitRetryPolicy,
        }
    }

    /// Sets the average cost of a request, in compute units.
    ///
    /// Together with `compute_units_per_second`, this is used to estimate how
    /// long to wait for the compute budget to recover when requests queue up.
    pub const fn with_avg_unit_cost(mut self, avg_cost: u64) -> Self {
        self.avg_cost = avg_cost;
        self
    }
}

impl<P: RetryPolicy> RetryBackoffLayer<P> {
    /// Creates a new [`RetryBackoffLayer`] with a custom [`RetryPolicy`].
    pub const fn new_with_policy(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
        policy: P,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            policy,
            avg_cost: DEFAULT_AVG_COST,
        }
    }
}

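/// The default [`RetryPolicy`]: retries any error the transport reports as
/// retryable (e.g. rate limit responses) and forwards the error's own backoff
/// hint when one is present.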
#[derive(Debug, Copy, Clone, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;

impl RateLimitRetryPolicy {
    /// Combines this policy with an additional retry predicate.
    ///
    /// The resulting [`OrRetryPolicyFn`] retries if either this policy or the
    /// given closure considers the error retryable.
    pub fn or<F>(self, f: F) -> OrRetryPolicyFn<Self>
    where
        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
    {
        OrRetryPolicyFn::new(self, f)
    }
}

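/// Determines whether a failed request should be retried and, optionally, how
/// long to wait before retrying it.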
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
    /// Whether to retry the request based on the given `error`.
    fn should_retry(&self, error: &TransportError) -> bool;

    /// Providers may include the time to wait before retrying the request in
    /// the error itself; this returns that hint, if any.
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
}

impl RetryPolicy for RateLimitRetryPolicy {
    fn should_retry(&self, error: &TransportError) -> bool {
        error.is_retryable()
    }

    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
        error.backoff_hint()
    }
}

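/// A [`RetryPolicy`] that retries if either an additional user-provided
/// closure or the wrapped base policy considers the error retryable.
///
/// Backoff hints are always delegated to the base policy.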
#[derive(Clone)]
pub struct OrRetryPolicyFn<P = RateLimitRetryPolicy> {
    /// The additional retry predicate, consulted before the base policy.
    inner: Arc<dyn Fn(&TransportError) -> bool + Send + Sync>,
    /// The base retry policy.
    base: P,
}

impl<P> OrRetryPolicyFn<P> {
    /// Creates a new [`OrRetryPolicyFn`] from a base policy and an additional
    /// retry predicate.
    pub fn new<F>(base: P, or: F) -> Self
    where
        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
    {
        Self { inner: Arc::new(or), base }
    }
}

impl<P: RetryPolicy> RetryPolicy for OrRetryPolicyFn<P> {
    fn should_retry(&self, error: &TransportError) -> bool {
        self.inner.as_ref()(error) || self.base.should_retry(error)
    }

    fn backoff_hint(&self, error: &TransportError) -> Option<Duration> {
        self.base.backoff_hint(error)
    }
}

impl<P: fmt::Debug> fmt::Debug for OrRetryPolicyFn<P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The retry closure itself is not `Debug`, so it is elided.
        f.debug_struct("OrRetryPolicyFn")
            .field("base", &self.base)
            .field("inner", &"..")
            .finish_non_exhaustive()
    }
}

impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
    type Service = RetryBackoffService<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        RetryBackoffService {
            inner,
            policy: self.policy.clone(),
            max_rate_limit_retries: self.max_rate_limit_retries,
            initial_backoff: self.initial_backoff,
            compute_units_per_second: self.compute_units_per_second,
            requests_enqueued: Arc::new(AtomicU32::new(0)),
            avg_cost: self.avg_cost,
        }
    }
}

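/// A [`tower::Service`] that retries failed requests with backoff. Created by
/// applying a [`RetryBackoffLayer`] to an inner transport service.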
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
    /// The inner service that actually sends the request.
    inner: S,
    /// The retry policy.
    policy: P,
    /// The maximum number of retries for rate limit errors.
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds.
    initial_backoff: u64,
    /// The number of compute units available per second.
    compute_units_per_second: u64,
    /// The number of requests currently enqueued, shared across clones.
    requests_enqueued: Arc<AtomicU32>,
    /// The average cost of a request, in compute units.
    avg_cost: u64,
}

impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
    /// Returns the configured initial backoff as a [`Duration`].
    const fn initial_backoff(&self) -> Duration {
        Duration::from_millis(self.initial_backoff)
    }
}

impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + 'static
        + Clone,
    P: RetryPolicy + Clone + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: RequestPacket) -> Self::Future {
        let inner = self.inner.clone();
        let this = self.clone();
        // Take the service that was driven to readiness and leave a fresh
        // clone in its place (standard `tower` practice).
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
            let mut rate_limit_retry_number: u32 = 0;
            loop {
                let err;
                let res = inner.call(request.clone()).await;

                match res {
                    Ok(res) => {
                        if let Some(e) = res.as_error() {
                            err = TransportError::ErrorResp(e.clone())
                        } else {
                            this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                            return Ok(res);
                        }
                    }
                    Err(e) => err = e,
                }

                let should_retry = this.policy.should_retry(&err);
                if should_retry {
                    rate_limit_retry_number += 1;
                    if rate_limit_retry_number > this.max_rate_limit_retries {
                        this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                        return Err(TransportErrorKind::custom_str(&format!(
                            "Max retries exceeded {err}"
                        )));
                    }
                    trace!(%err, "retrying request");

                    let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;

                    // Prefer the backoff hint carried by the error (if any),
                    // otherwise fall back to the configured initial backoff.
                    let backoff_hint = this.policy.backoff_hint(&err);
                    let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());

                    // Estimate how long the compute-unit budget needs to
                    // recover given the requests currently queued.
                    let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                        this.avg_cost,
                        this.compute_units_per_second,
                        current_queued_reqs,
                        ahead_in_queue,
                    );
                    let total_backoff = next_backoff
                        + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);

                    trace!(
                        total_backoff_millis = total_backoff.as_millis(),
                        budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
                        default_backoff_millis = next_backoff.as_millis(),
                        backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
                        "(all in ms) backing off due to rate limit"
                    );

                    sleep(total_backoff).await;
                } else {
                    this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(err);
                }
            }
        })
    }
}

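/// Calculates how many seconds a request should wait so that the compute-unit
/// budget is not exhausted.
///
/// The per-second request capacity is approximated as
/// `compute_units_per_second / avg_cost` (at least 1). If more requests are
/// queued than fit into one second, the wait is the caller's position in the
/// queue (capped at the number of queued requests) divided by that capacity;
/// otherwise no extra wait is needed. For example, with an average cost of 17
/// CU, a budget of 10 CU/s, and 2 queued requests (2 ahead of the caller),
/// the capacity is `max(10 / 17, 1) = 1` request per second, so the offset is
/// `min(2, 2) / 1 = 2` seconds, matching the test below.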
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    let request_capacity_per_second = compute_units_per_second.saturating_div(avg_cost).max(1);
    if current_queued_requests > request_capacity_per_second {
        current_queued_requests.min(ahead_in_queue).saturating_div(request_capacity_per_second)
    } else {
        0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
    }
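
    // A minimal added sanity check for `OrRetryPolicyFn`, assuming only items
    // defined or imported in this module (`TransportErrorKind::custom_str` is
    // used above to build a `TransportError`).
    #[test]
    fn test_or_retry_policy_fn() {
        // The closure arm alone is enough to force a retry, regardless of
        // what the base policy would decide for this error.
        let policy = RateLimitRetryPolicy.or(|_| true);
        let err = TransportErrorKind::custom_str("some error");
        assert!(policy.should_retry(&err));
    }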
}