alloy_transport/layers/retry.rs

use crate::{
    error::{RpcErrorExt, TransportError, TransportErrorKind},
    TransportFut,
};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use core::fmt;
use std::{
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use tower::{Layer, Service};
use tracing::trace;

#[cfg(target_family = "wasm")]
use wasmtimer::tokio::sleep;

#[cfg(not(target_family = "wasm"))]
use tokio::time::sleep;

/// The default average cost of a request in Compute Units (CU).
const DEFAULT_AVG_COST: u64 = 20u64;

/// A Transport Layer that is responsible for retrying requests based on the
/// error type. See [`TransportError`].
///
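/// # Example
///
/// A minimal construction sketch; the retry count, backoff, and CU/s values below are
/// illustrative, not defaults, and the import path assumes the crate's public re-export:
///
/// ```ignore
/// use alloy_transport::layers::RetryBackoffLayer;
///
/// // Up to 10 retries, 1000 ms initial backoff, 500 compute units per second.
/// let retry_layer = RetryBackoffLayer::new(10, 1000, 500);
/// ```
///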
/// [`TransportError`]: crate::error::TransportError
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
    /// The maximum number of retries for rate limit errors.
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds.
    initial_backoff: u64,
    /// The number of Compute Units per second for this provider.
    compute_units_per_second: u64,
    /// The average cost of a request. Defaults to [DEFAULT_AVG_COST].
    avg_cost: u64,
    /// The [RetryPolicy] to use. Defaults to [RateLimitRetryPolicy].
    policy: P,
}

impl RetryBackoffLayer {
    /// Creates a new retry layer with the given parameters and the default [RateLimitRetryPolicy].
    pub const fn new(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            avg_cost: DEFAULT_AVG_COST,
            policy: RateLimitRetryPolicy,
        }
    }

    /// Sets the average Compute Unit (CU) cost per request. Defaults to `20` CU.
    ///
    /// Based on Alchemy’s published Compute Unit (CU) table, the most frequently used
    /// JSON-RPC methods fall within the `10–20` CU range, with only a small number
    /// of higher-cost outliers (such as log queries or transaction submissions).
    /// Consequently, an average cost of `20` CU per request serves as a practical
    /// and representative estimate for typical EVM workloads.
    ///
    /// Alchemy also uses this `20` CU figure when expressing throughput in
    /// requests per second. For example, the free tier maps `500 CU/s` to
    /// approximately `25 req/s` under this average, which is consistent with the `20` CU estimate.
    ///
    /// References:
    /// - <https://www.alchemy.com/docs/reference/compute-unit-costs#evm-standard-json-rpc-methods>
    /// - <https://www.alchemy.com/pricing#table-products>
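    ///
    /// # Example
    ///
    /// A minimal sketch; the `500` CU/s budget, `25` CU average, and import path are illustrative:
    ///
    /// ```ignore
    /// use alloy_transport::layers::RetryBackoffLayer;
    ///
    /// // 500 CU/s at an average of 25 CU per request allows roughly 20 requests per second.
    /// let layer = RetryBackoffLayer::new(10, 1000, 500).with_avg_unit_cost(25);
    /// ```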
    pub const fn with_avg_unit_cost(mut self, avg_cost: u64) -> Self {
        self.avg_cost = avg_cost;
        self
    }
}

impl<P: RetryPolicy> RetryBackoffLayer<P> {
    /// Creates a new retry layer with the given parameters and [RetryPolicy].
    pub const fn new_with_policy(
        max_rate_limit_retries: u32,
        initial_backoff: u64,
        compute_units_per_second: u64,
        policy: P,
    ) -> Self {
        Self {
            max_rate_limit_retries,
            initial_backoff,
            compute_units_per_second,
            policy,
            avg_cost: DEFAULT_AVG_COST,
        }
    }
}

/// [RateLimitRetryPolicy] implements [RetryPolicy] to determine whether a request should be
/// retried based on the error.
#[derive(Debug, Copy, Clone, Default)]
#[non_exhaustive]
pub struct RateLimitRetryPolicy;

impl RateLimitRetryPolicy {
    /// Creates a new [`RetryPolicy`] that, in addition to this policy, consults the given
    /// closure to decide whether an error should be retried.
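    ///
    /// # Example
    ///
    /// A minimal sketch; the closure, the layer parameters, and the import path are illustrative:
    ///
    /// ```ignore
    /// use alloy_transport::layers::{RateLimitRetryPolicy, RetryBackoffLayer};
    ///
    /// // Retry everything the rate-limit policy retries, plus any error whose
    /// // message mentions "header not found".
    /// let policy = RateLimitRetryPolicy::default()
    ///     .or(|err| err.to_string().contains("header not found"));
    /// let layer = RetryBackoffLayer::new_with_policy(10, 1000, 500, policy);
    /// ```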
    pub fn or<F>(self, f: F) -> OrRetryPolicyFn<Self>
    where
        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
    {
        OrRetryPolicyFn::new(self, f)
    }
}

/// [RetryPolicy] defines the logic for deciding which [TransportError] instances the client
/// should retry and attempt to recover from.
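///
/// # Example
///
/// A minimal sketch of a custom policy; the `RetryAll` name and the import paths are
/// illustrative:
///
/// ```ignore
/// use alloy_transport::{layers::RetryPolicy, TransportError};
///
/// #[derive(Debug, Clone, Copy)]
/// struct RetryAll;
///
/// impl RetryPolicy for RetryAll {
///     // Retry every error unconditionally.
///     fn should_retry(&self, _error: &TransportError) -> bool {
///         true
///     }
///
///     // No server-provided hint; the service falls back to its own backoff.
///     fn backoff_hint(&self, _error: &TransportError) -> Option<std::time::Duration> {
///         None
///     }
/// }
/// ```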
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
    /// Whether to retry the request based on the given `error`
    fn should_retry(&self, error: &TransportError) -> bool;

    /// Providers may include the `backoff` in the error response directly
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration>;
}

impl RetryPolicy for RateLimitRetryPolicy {
    fn should_retry(&self, error: &TransportError) -> bool {
        error.is_retryable()
    }

    /// Provides a backoff hint if the error response contains it
    fn backoff_hint(&self, error: &TransportError) -> Option<std::time::Duration> {
        error.backoff_hint()
    }
}

/// A [`RetryPolicy`] that supports an additional closure for deciding if an error should be
/// retried.
#[derive(Clone)]
pub struct OrRetryPolicyFn<P = RateLimitRetryPolicy> {
    inner: Arc<dyn Fn(&TransportError) -> bool + Send + Sync>,
    base: P,
}

impl<P> OrRetryPolicyFn<P> {
    /// Creates a new instance with the given base policy and the given closure
    pub fn new<F>(base: P, or: F) -> Self
    where
        F: Fn(&TransportError) -> bool + Send + Sync + 'static,
    {
        Self { inner: Arc::new(or), base }
    }
}

impl<P: RetryPolicy> RetryPolicy for OrRetryPolicyFn<P> {
    fn should_retry(&self, error: &TransportError) -> bool {
        self.inner.as_ref()(error) || self.base.should_retry(error)
    }

    fn backoff_hint(&self, error: &TransportError) -> Option<Duration> {
        self.base.backoff_hint(error)
    }
}

impl<P: fmt::Debug> fmt::Debug for OrRetryPolicyFn<P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OrRetryPolicyFn")
            .field("base", &self.base)
            .field("inner", &"{{..}}")
            .finish_non_exhaustive()
    }
}

impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
    type Service = RetryBackoffService<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        RetryBackoffService {
            inner,
            policy: self.policy.clone(),
            max_rate_limit_retries: self.max_rate_limit_retries,
            initial_backoff: self.initial_backoff,
            compute_units_per_second: self.compute_units_per_second,
            requests_enqueued: Arc::new(AtomicU32::new(0)),
            avg_cost: self.avg_cost,
        }
    }
}

/// A Tower Service used by the RetryBackoffLayer that is responsible for retrying requests based
/// on the error type. See [TransportError] and [RateLimitRetryPolicy].
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
    /// The inner service
    inner: S,
    /// The [RetryPolicy] to use.
    policy: P,
    /// The maximum number of retries for rate limit errors
    max_rate_limit_retries: u32,
    /// The initial backoff in milliseconds
    initial_backoff: u64,
    /// The number of compute units per second for this service
    compute_units_per_second: u64,
    /// The number of requests currently enqueued
    requests_enqueued: Arc<AtomicU32>,
    /// The average cost of a request.
    avg_cost: u64,
}

impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
    const fn initial_backoff(&self) -> Duration {
        Duration::from_millis(self.initial_backoff)
    }
}

impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + 'static
        + Clone,
    P: RetryPolicy + Clone + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Our middleware doesn't care about backpressure, so it's ready as long
        // as the inner service is ready.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: RequestPacket) -> Self::Future {
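        // Clone the inner service and swap the clone into `self` via `mem::replace`
        // below, so the future drives the instance that was checked in `poll_ready`
        // while `self` keeps a fresh clone for later calls (the usual pattern for
        // cloneable tower services).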
        let inner = self.inner.clone();
        let this = self.clone();
        let mut inner = std::mem::replace(&mut self.inner, inner);
        Box::pin(async move {
            let ahead_in_queue = this.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
            let mut rate_limit_retry_number: u32 = 0;
            loop {
                let err;
                let res = inner.call(request.clone()).await;

                match res {
                    Ok(res) => {
                        if let Some(e) = res.as_error() {
                            err = TransportError::ErrorResp(e.clone())
                        } else {
                            this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                            return Ok(res);
                        }
                    }
                    Err(e) => err = e,
                }

                let should_retry = this.policy.should_retry(&err);
                if should_retry {
                    rate_limit_retry_number += 1;
                    if rate_limit_retry_number > this.max_rate_limit_retries {
                        this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                        return Err(TransportErrorKind::custom_str(&format!(
                            "Max retries exceeded {err}"
                        )));
                    }
                    trace!(%err, "retrying request");

                    let current_queued_reqs = this.requests_enqueued.load(Ordering::SeqCst) as u64;

                    // try to extract the requested backoff from the error or compute the next
                    // backoff based on retry count
                    let backoff_hint = this.policy.backoff_hint(&err);
                    let next_backoff = backoff_hint.unwrap_or_else(|| this.initial_backoff());

                    let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                        this.avg_cost,
                        this.compute_units_per_second,
                        current_queued_reqs,
                        ahead_in_queue,
                    );
                    let total_backoff = next_backoff
                        + std::time::Duration::from_secs(seconds_to_wait_for_compute_budget);

                    trace!(
                        total_backoff_millis = total_backoff.as_millis(),
                        budget_backoff_millis = seconds_to_wait_for_compute_budget * 1000,
                        default_backoff_millis = next_backoff.as_millis(),
                        backoff_hint_millis = backoff_hint.map(|d| d.as_millis()),
                        "(all in ms) backing off due to rate limit"
                    );

                    sleep(total_backoff).await;
                } else {
                    this.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(err);
                }
            }
        })
    }
}

/// Calculates an offset in seconds by taking into account the number of currently queued requests,
/// the number of requests that were ahead in the queue when the request was first issued, the
/// average cost of a weighted request (heuristic), and the number of available compute units per
/// second.
///
/// Returns the number of seconds (the unit in which the remote endpoint measures its compute
/// budget) a request should wait to avoid being rate limited. The budget per second is
/// `compute_units_per_second`; assuming an average cost of `avg_cost`, this allows (in theory)
/// `compute_units_per_second / avg_cost` requests per second without getting rate limited.
/// Taking into account the number of concurrent requests and the position in the queue when the
/// request was first issued, this determines the number of seconds, if any, a request should wait.
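///
/// For example, with `avg_cost = 20` and `compute_units_per_second = 500` the budget allows
/// roughly `25` requests per second, so a request issued with `50` requests ahead of it while
/// `50` are queued waits `min(50, 50) / 25 = 2` seconds.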
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    let request_capacity_per_second = compute_units_per_second.saturating_div(avg_cost).max(1);
    if current_queued_requests > request_capacity_per_second {
        current_queued_requests.min(ahead_in_queue).saturating_div(request_capacity_per_second)
    } else {
        0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_units_per_second() {
        let offset = compute_unit_offset_in_secs(17, 10, 0, 0);
        assert_eq!(offset, 0);
        let offset = compute_unit_offset_in_secs(17, 10, 2, 2);
        assert_eq!(offset, 2);
    }
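
    // Illustrative sanity check (not part of the original suite): 500 CU/s at an
    // average of 20 CU per request allows ~25 requests per second, so 50 queued
    // requests with 50 ahead in the queue wait 2 seconds.
    #[test]
    fn test_compute_unit_offset_with_default_avg_cost() {
        let offset = compute_unit_offset_in_secs(DEFAULT_AVG_COST, 500, 50, 50);
        assert_eq!(offset, 2);
    }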
}