alloy_transport/layers/
fallback.rs

1use crate::time::Instant;
2use alloy_json_rpc::{RequestPacket, ResponsePacket};
3use core::time::Duration;
4use derive_more::{Deref, DerefMut};
5use futures::{stream::FuturesUnordered, StreamExt};
6use parking_lot::RwLock;
7use std::{
8    collections::VecDeque,
9    num::NonZeroUsize,
10    sync::Arc,
11    task::{Context, Poll},
12};
13use tower::{Layer, Service};
14use tracing::trace;
15
16use crate::{TransportError, TransportErrorKind, TransportFut};
17
// Constants for the transport ranking algorithm

/// Weight of the stability (success-rate) component in a transport's score.
const STABILITY_WEIGHT: f64 = 0.7;
/// Weight of the latency component in a transport's score.
const LATENCY_WEIGHT: f64 = 0.3;
/// Number of recent samples retained in each metrics window.
const DEFAULT_SAMPLE_COUNT: usize = 10;
/// Default number of transports queried in parallel per request.
const DEFAULT_ACTIVE_TRANSPORT_COUNT: usize = 3;
23
/// The [`FallbackService`] consumes multiple transports and is able to
/// query them in parallel, returning the first successful response.
///
/// The service ranks transports based on latency and stability metrics,
/// and will attempt to always use the best available transports.
#[derive(Debug, Clone)]
pub struct FallbackService<S> {
    /// The list of transports to use.
    ///
    /// Shared behind an [`Arc`] so the clones made per `call` all observe
    /// the same per-transport metrics.
    transports: Arc<Vec<ScoredTransport<S>>>,
    /// The maximum number of transports to use in parallel
    active_transport_count: usize,
}
36
37impl<S: Clone> FallbackService<S> {
38    /// Create a new fallback service from a list of transports.
39    ///
40    /// The `active_transport_count` parameter controls how many transports are used for requests
41    /// at any one time.
42    pub fn new(transports: Vec<S>, active_transport_count: usize) -> Self {
43        let scored_transports = transports
44            .into_iter()
45            .enumerate()
46            .map(|(id, transport)| ScoredTransport::new(id, transport))
47            .collect::<Vec<_>>();
48
49        Self { transports: Arc::new(scored_transports), active_transport_count }
50    }
51
52    /// Log the current ranking of transports
53    fn log_transport_rankings(&self) {
54        if !tracing::enabled!(tracing::Level::TRACE) {
55            return;
56        }
57
58        // Prepare lightweight ranking data without cloning transports
59        let mut ranked: Vec<(usize, f64, String)> =
60            self.transports.iter().map(|t| (t.id, t.score(), t.metrics_summary())).collect();
61
62        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
63
64        trace!("Current transport rankings:");
65        for (idx, (id, _score, summary)) in ranked.iter().enumerate() {
66            trace!("  #{}: Transport[{}] - {}", idx + 1, id, summary);
67        }
68    }
69}
70
71impl<S> FallbackService<S>
72where
73    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
74        + Send
75        + Clone
76        + 'static,
77{
78    /// Make a request to the fallback service middleware.
79    ///
80    /// Here is a high-level overview of how requests are handled:
81    ///
82    /// - At the start of each request, we sort transports by score
83    /// - We take the top `self.active_transport_count` and call them in parallel
84    /// - If any of them succeeds, we update the transport scores and return the response
85    /// - If all transports fail, we update the scores and return the last error that occurred
86    ///
87    /// This strategy allows us to always make requests to the best available transports
88    /// while keeping them available.
89    async fn make_request(&self, req: RequestPacket) -> Result<ResponsePacket, TransportError> {
90        // Get the top transports to use for this request
91        let top_transports = {
92            // Clone the vec, sort it, and take the top `self.active_transport_count`
93            let mut transports_clone = (*self.transports).clone();
94            transports_clone.sort_by(|a, b| b.cmp(a));
95            transports_clone.into_iter().take(self.active_transport_count).collect::<Vec<_>>()
96        };
97
98        // Create a collection of future requests
99        let mut futures = FuturesUnordered::new();
100
101        // Launch requests to all active transports in parallel
102        for mut transport in top_transports {
103            let req_clone = req.clone();
104
105            let future = async move {
106                let start = Instant::now();
107                let result = transport.call(req_clone).await;
108                trace!(
109                    "Transport[{}] completed: latency={:?}, status={}",
110                    transport.id,
111                    start.elapsed(),
112                    if result.is_ok() { "success" } else { "fail" }
113                );
114
115                (result, transport, start.elapsed())
116            };
117
118            futures.push(future);
119        }
120
121        // Wait for the first successful response or until all fail
122        let mut last_error = None;
123
124        while let Some((result, transport, duration)) = futures.next().await {
125            match result {
126                Ok(response) => {
127                    // Record success
128                    transport.track_success(duration);
129
130                    self.log_transport_rankings();
131
132                    return Ok(response);
133                }
134                Err(error) => {
135                    // Record failure
136                    transport.track_failure();
137
138                    last_error = Some(error);
139                }
140            }
141        }
142
143        Err(last_error.unwrap_or_else(|| {
144            TransportErrorKind::custom_str("All transport futures failed to complete")
145        }))
146    }
147}
148
149impl<S> Service<RequestPacket> for FallbackService<S>
150where
151    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
152        + Send
153        + Sync
154        + Clone
155        + 'static,
156{
157    type Response = ResponsePacket;
158    type Error = TransportError;
159    type Future = TransportFut<'static>;
160
161    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
162        // Service is always ready
163        Poll::Ready(Ok(()))
164    }
165
166    fn call(&mut self, req: RequestPacket) -> Self::Future {
167        let this = self.clone();
168        Box::pin(async move { this.make_request(req).await })
169    }
170}
171
/// Fallback layer for transparent transport failover. This layer will
/// consume a list of transports to provide better availability and
/// reliability.
///
/// The [`FallbackService`] will attempt to make requests to multiple
/// transports in parallel, and return the first successful response.
///
/// If all transports fail, the fallback service will return an error.
///
/// # Automatic Transport Ranking
///
/// Each transport is automatically ranked based on latency & stability
/// using a weighted algorithm. By default:
///
/// - Stability (success rate) is weighted at 70%
/// - Latency (response time) is weighted at 30%
/// - The `active_transport_count` parameter controls how many transports are queried at any one
///   time.
#[derive(Debug, Clone)]
pub struct FallbackLayer {
    /// The maximum number of transports to use in parallel.
    /// Defaults to [`DEFAULT_ACTIVE_TRANSPORT_COUNT`] via `Default`.
    active_transport_count: usize,
}
195
196impl FallbackLayer {
197    /// Set the number of active transports to use (must be greater than 0)
198    pub const fn with_active_transport_count(mut self, count: NonZeroUsize) -> Self {
199        self.active_transport_count = count.get();
200        self
201    }
202}
203
204impl<S> Layer<Vec<S>> for FallbackLayer
205where
206    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
207        + Send
208        + Clone
209        + 'static,
210{
211    type Service = FallbackService<S>;
212
213    fn layer(&self, inner: Vec<S>) -> Self::Service {
214        FallbackService::new(inner, self.active_transport_count)
215    }
216}
217
218impl Default for FallbackLayer {
219    fn default() -> Self {
220        Self { active_transport_count: DEFAULT_ACTIVE_TRANSPORT_COUNT }
221    }
222}
223
/// A scored transport that can be ordered in a heap.
///
/// The transport is scored every time it is used according to
/// a simple weighted algorithm that favors latency and stability.
///
/// The score is calculated as follows (by default):
///
/// - Stability (success rate) is weighted at 70%
/// - Latency (response time) is weighted at 30%
///
/// The score is then used to determine which transport to use next in
/// the [`FallbackService`].
#[derive(Debug, Clone, Deref, DerefMut)]
struct ScoredTransport<S> {
    /// The transport itself. `Deref`/`DerefMut` target, so `Service`
    /// methods like `call` can be invoked directly on a `ScoredTransport`.
    #[deref]
    #[deref_mut]
    transport: S,
    /// Unique identifier for the transport (its index in the original list)
    id: usize,
    /// Metrics for the transport. `Arc`-shared so every clone records into
    /// (and scores from) the same sample windows.
    metrics: Arc<RwLock<TransportMetrics>>,
}
247
248impl<S> ScoredTransport<S> {
249    /// Create a new scored transport
250    fn new(id: usize, transport: S) -> Self {
251        Self { id, transport, metrics: Arc::new(Default::default()) }
252    }
253
254    /// Returns the current score of the transport based on the weighted algorithm.
255    fn score(&self) -> f64 {
256        let metrics = self.metrics.read();
257        metrics.calculate_score()
258    }
259
260    /// Get metrics summary for debugging
261    fn metrics_summary(&self) -> String {
262        let metrics = self.metrics.read();
263        metrics.get_summary()
264    }
265
266    /// Track a successful request and its latency.
267    fn track_success(&self, duration: Duration) {
268        let mut metrics = self.metrics.write();
269        metrics.track_success(duration);
270    }
271
272    /// Track a failed request.
273    fn track_failure(&self) {
274        let mut metrics = self.metrics.write();
275        metrics.track_failure();
276    }
277}
278
279impl<S> PartialEq for ScoredTransport<S> {
280    fn eq(&self, other: &Self) -> bool {
281        self.score().eq(&other.score())
282    }
283}
284
// Scores produced by `TransportMetrics::calculate_score` are always finite
// (ratios and `1.0 / (1.0 + avg)` of finite inputs), so f64 equality is a
// total equivalence here and `Eq` is sound.
impl<S> Eq for ScoredTransport<S> {}
286
287#[expect(clippy::non_canonical_partial_ord_impl)]
288impl<S> PartialOrd for ScoredTransport<S> {
289    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
290        self.score().partial_cmp(&other.score())
291    }
292}
293
294impl<S> Ord for ScoredTransport<S> {
295    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
296        self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
297    }
298}
299
/// Represents performance metrics for a transport.
#[derive(Debug)]
struct TransportMetrics {
    // Latency history - tracks the last N *successful* responses
    // (failures record no latency sample)
    latencies: VecDeque<Duration>,
    // Success history - tracks last N successes (true) or failures (false)
    successes: VecDeque<bool>,
    // Last time this transport was checked/used. Written on every update
    // but not read within this file — presumably kept for diagnostics;
    // TODO(review): confirm it is used elsewhere.
    last_update: Instant,
    // Total number of requests made to this transport (lifetime, unwindowed)
    total_requests: u64,
    // Total number of successful requests (lifetime, unwindowed)
    successful_requests: u64,
}
314
315impl TransportMetrics {
316    /// Track a successful request and its latency.
317    fn track_success(&mut self, duration: Duration) {
318        self.total_requests += 1;
319        self.successful_requests += 1;
320        self.last_update = Instant::now();
321
322        // Add to sample windows
323        self.latencies.push_back(duration);
324        self.successes.push_back(true);
325
326        // Limit to sample count
327        while self.latencies.len() > DEFAULT_SAMPLE_COUNT {
328            self.latencies.pop_front();
329        }
330        while self.successes.len() > DEFAULT_SAMPLE_COUNT {
331            self.successes.pop_front();
332        }
333    }
334
335    /// Track a failed request.
336    fn track_failure(&mut self) {
337        self.total_requests += 1;
338        self.last_update = Instant::now();
339
340        // Add to sample windows (no latency for failures)
341        self.successes.push_back(false);
342
343        // Limit to sample count
344        while self.successes.len() > DEFAULT_SAMPLE_COUNT {
345            self.successes.pop_front();
346        }
347    }
348
349    /// Calculate weighted score based on stability and latency
350    fn calculate_score(&self) -> f64 {
351        // If no data yet, return initial neutral score
352        if self.successes.is_empty() {
353            return 0.0;
354        }
355
356        // Calculate stability score (percentage of successful requests)
357        let success_count = self.successes.iter().filter(|&&s| s).count();
358        let stability_score = success_count as f64 / self.successes.len() as f64;
359
360        // Calculate latency score (lower is better)
361        let latency_score = if !self.latencies.is_empty() {
362            let avg_latency = self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
363                / self.latencies.len() as f64;
364
365            // Normalize latency score (1.0 for 0ms, approaches 0.0 as latency increases)
366            1.0 / (1.0 + avg_latency)
367        } else {
368            0.0
369        };
370
371        // Apply weights to calculate final score
372        (stability_score * STABILITY_WEIGHT) + (latency_score * LATENCY_WEIGHT)
373    }
374
375    /// Get a summary of metrics for debugging
376    fn get_summary(&self) -> String {
377        let success_rate = if !self.successes.is_empty() {
378            let success_count = self.successes.iter().filter(|&&s| s).count();
379            success_count as f64 / self.successes.len() as f64
380        } else {
381            0.0
382        };
383
384        let avg_latency = if !self.latencies.is_empty() {
385            self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
386                / self.latencies.len() as f64
387        } else {
388            0.0
389        };
390
391        format!(
392            "success_rate: {:.2}%, avg_latency: {:.2}ms, samples: {}, score: {:.4}",
393            success_rate * 100.0,
394            avg_latency * 1000.0,
395            self.successes.len(),
396            self.calculate_score()
397        )
398    }
399}
400
401impl Default for TransportMetrics {
402    fn default() -> Self {
403        Self {
404            latencies: VecDeque::new(),
405            successes: VecDeque::new(),
406            last_update: Instant::now(),
407            total_requests: 0,
408            successful_requests: 0,
409        }
410    }
411}