1use crate::time::Instant;
2use alloy_json_rpc::{RequestPacket, ResponsePacket};
3use core::time::Duration;
4use derive_more::{Deref, DerefMut};
5use futures::{stream::FuturesUnordered, StreamExt};
6use parking_lot::RwLock;
7use std::{
8 collections::VecDeque,
9 num::NonZeroUsize,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use tower::{Layer, Service};
14use tracing::trace;
15
16use crate::{TransportError, TransportErrorKind, TransportFut};
17
// ---- Scoring parameters -------------------------------------------------------------

/// Weight of the success-rate ("stability") component of a transport's score.
const STABILITY_WEIGHT: f64 = 0.7;
/// Weight of the latency component of a transport's score.
const LATENCY_WEIGHT: f64 = 0.3;
/// Maximum number of recent samples retained per transport in each sliding window.
const DEFAULT_SAMPLE_COUNT: usize = 10;
/// Number of transports queried concurrently per request when the layer is built with
/// `FallbackLayer::default()`.
const DEFAULT_ACTIVE_TRANSPORT_COUNT: usize = 3;
23
/// A `tower::Service` that sends each request to several transports concurrently and
/// returns the first successful response.
///
/// Transports are ranked by a score derived from their recent success rate and latency.
/// Only the top `active_transport_count` transports are queried for a given request.
#[derive(Debug, Clone)]
pub struct FallbackService<S> {
    /// The wrapped transports, each paired with its shared metrics. Held in an `Arc` so
    /// that cloning the service (which `Service::call` does per request) is cheap.
    transports: Arc<Vec<ScoredTransport<S>>>,
    /// Maximum number of transports queried concurrently for a single request.
    active_transport_count: usize,
}
36
37impl<S: Clone> FallbackService<S> {
38 pub fn new(transports: Vec<S>, active_transport_count: usize) -> Self {
43 let scored_transports = transports
44 .into_iter()
45 .enumerate()
46 .map(|(id, transport)| ScoredTransport::new(id, transport))
47 .collect::<Vec<_>>();
48
49 Self { transports: Arc::new(scored_transports), active_transport_count }
50 }
51
52 fn log_transport_rankings(&self) {
54 if !tracing::enabled!(tracing::Level::TRACE) {
55 return;
56 }
57
58 let mut ranked: Vec<(usize, f64, String)> =
60 self.transports.iter().map(|t| (t.id, t.score(), t.metrics_summary())).collect();
61
62 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
63
64 trace!("Current transport rankings:");
65 for (idx, (id, _score, summary)) in ranked.iter().enumerate() {
66 trace!(" #{}: Transport[{}] - {}", idx + 1, id, summary);
67 }
68 }
69}
70
71impl<S> FallbackService<S>
72where
73 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
74 + Send
75 + Clone
76 + 'static,
77{
78 async fn make_request(&self, req: RequestPacket) -> Result<ResponsePacket, TransportError> {
90 let top_transports = {
92 let mut transports_clone = (*self.transports).clone();
94 transports_clone.sort_by(|a, b| b.cmp(a));
95 transports_clone.into_iter().take(self.active_transport_count).collect::<Vec<_>>()
96 };
97
98 let mut futures = FuturesUnordered::new();
100
101 for mut transport in top_transports {
103 let req_clone = req.clone();
104
105 let future = async move {
106 let start = Instant::now();
107 let result = transport.call(req_clone).await;
108 trace!(
109 "Transport[{}] completed: latency={:?}, status={}",
110 transport.id,
111 start.elapsed(),
112 if result.is_ok() { "success" } else { "fail" }
113 );
114
115 (result, transport, start.elapsed())
116 };
117
118 futures.push(future);
119 }
120
121 let mut last_error = None;
123
124 while let Some((result, transport, duration)) = futures.next().await {
125 match result {
126 Ok(response) => {
127 transport.track_success(duration);
129
130 self.log_transport_rankings();
131
132 return Ok(response);
133 }
134 Err(error) => {
135 transport.track_failure();
137
138 last_error = Some(error);
139 }
140 }
141 }
142
143 Err(last_error.unwrap_or_else(|| {
144 TransportErrorKind::custom_str("All transport futures failed to complete")
145 }))
146 }
147}
148
149impl<S> Service<RequestPacket> for FallbackService<S>
150where
151 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
152 + Send
153 + Sync
154 + Clone
155 + 'static,
156{
157 type Response = ResponsePacket;
158 type Error = TransportError;
159 type Future = TransportFut<'static>;
160
161 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
162 Poll::Ready(Ok(()))
164 }
165
166 fn call(&mut self, req: RequestPacket) -> Self::Future {
167 let this = self.clone();
168 Box::pin(async move { this.make_request(req).await })
169 }
170}
171
/// A `tower::Layer` that turns a `Vec` of transports into a `FallbackService`.
///
/// `FallbackLayer::default()` queries `DEFAULT_ACTIVE_TRANSPORT_COUNT` transports per
/// request. Override this with `with_active_transport_count`.
#[derive(Debug, Clone)]
pub struct FallbackLayer {
    /// Number of transports the produced service queries concurrently per request.
    active_transport_count: usize,
}
195
196impl FallbackLayer {
197 pub const fn with_active_transport_count(mut self, count: NonZeroUsize) -> Self {
199 self.active_transport_count = count.get();
200 self
201 }
202}
203
204impl<S> Layer<Vec<S>> for FallbackLayer
205where
206 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
207 + Send
208 + Clone
209 + 'static,
210{
211 type Service = FallbackService<S>;
212
213 fn layer(&self, inner: Vec<S>) -> Self::Service {
214 FallbackService::new(inner, self.active_transport_count)
215 }
216}
217
218impl Default for FallbackLayer {
219 fn default() -> Self {
220 Self { active_transport_count: DEFAULT_ACTIVE_TRANSPORT_COUNT }
221 }
222}
223
/// A transport paired with an identifier and its rolling performance metrics.
///
/// It derefs (mutably too) to the wrapped transport, so service methods such as `call`
/// can be invoked on a `ScoredTransport` directly. The metrics live behind
/// `Arc<RwLock<..>>`, so every clone of a `ScoredTransport` shares and updates the same
/// metrics record.
#[derive(Debug, Clone, Deref, DerefMut)]
struct ScoredTransport<S> {
    /// The wrapped transport; target of the derived `Deref`/`DerefMut`.
    #[deref]
    #[deref_mut]
    transport: S,
    /// Identifier of this transport: its index in the list given to `FallbackService::new`.
    id: usize,
    /// Shared metrics from which this transport's score is computed.
    metrics: Arc<RwLock<TransportMetrics>>,
}
247
248impl<S> ScoredTransport<S> {
249 fn new(id: usize, transport: S) -> Self {
251 Self { id, transport, metrics: Arc::new(Default::default()) }
252 }
253
254 fn score(&self) -> f64 {
256 let metrics = self.metrics.read();
257 metrics.calculate_score()
258 }
259
260 fn metrics_summary(&self) -> String {
262 let metrics = self.metrics.read();
263 metrics.get_summary()
264 }
265
266 fn track_success(&self, duration: Duration) {
268 let mut metrics = self.metrics.write();
269 metrics.track_success(duration);
270 }
271
272 fn track_failure(&self) {
274 let mut metrics = self.metrics.write();
275 metrics.track_failure();
276 }
277}
278
279impl<S> PartialEq for ScoredTransport<S> {
280 fn eq(&self, other: &Self) -> bool {
281 self.score().eq(&other.score())
282 }
283}
284
285impl<S> Eq for ScoredTransport<S> {}
286
287#[expect(clippy::non_canonical_partial_ord_impl)]
288impl<S> PartialOrd for ScoredTransport<S> {
289 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
290 self.score().partial_cmp(&other.score())
291 }
292}
293
294impl<S> Ord for ScoredTransport<S> {
295 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
296 self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
297 }
298}
299
/// Rolling performance statistics for a single transport.
///
/// The two `VecDeque` sample windows are trimmed to `DEFAULT_SAMPLE_COUNT` entries. The
/// request counters are lifetime totals and are never trimmed.
#[derive(Debug)]
struct TransportMetrics {
    /// Latencies of recent *successful* requests; failed requests add no entry here.
    latencies: VecDeque<Duration>,
    /// Outcomes of recent requests (`true` = success), oldest first.
    successes: VecDeque<bool>,
    /// When the most recent success or failure was recorded.
    /// NOTE(review): not read anywhere in the code shown here.
    last_update: Instant,
    /// Lifetime count of recorded requests (successes + failures).
    /// NOTE(review): not read anywhere in the code shown here.
    total_requests: u64,
    /// Lifetime count of recorded successful requests.
    /// NOTE(review): not read anywhere in the code shown here.
    successful_requests: u64,
}
314
315impl TransportMetrics {
316 fn track_success(&mut self, duration: Duration) {
318 self.total_requests += 1;
319 self.successful_requests += 1;
320 self.last_update = Instant::now();
321
322 self.latencies.push_back(duration);
324 self.successes.push_back(true);
325
326 while self.latencies.len() > DEFAULT_SAMPLE_COUNT {
328 self.latencies.pop_front();
329 }
330 while self.successes.len() > DEFAULT_SAMPLE_COUNT {
331 self.successes.pop_front();
332 }
333 }
334
335 fn track_failure(&mut self) {
337 self.total_requests += 1;
338 self.last_update = Instant::now();
339
340 self.successes.push_back(false);
342
343 while self.successes.len() > DEFAULT_SAMPLE_COUNT {
345 self.successes.pop_front();
346 }
347 }
348
349 fn calculate_score(&self) -> f64 {
351 if self.successes.is_empty() {
353 return 0.0;
354 }
355
356 let success_count = self.successes.iter().filter(|&&s| s).count();
358 let stability_score = success_count as f64 / self.successes.len() as f64;
359
360 let latency_score = if !self.latencies.is_empty() {
362 let avg_latency = self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
363 / self.latencies.len() as f64;
364
365 1.0 / (1.0 + avg_latency)
367 } else {
368 0.0
369 };
370
371 (stability_score * STABILITY_WEIGHT) + (latency_score * LATENCY_WEIGHT)
373 }
374
375 fn get_summary(&self) -> String {
377 let success_rate = if !self.successes.is_empty() {
378 let success_count = self.successes.iter().filter(|&&s| s).count();
379 success_count as f64 / self.successes.len() as f64
380 } else {
381 0.0
382 };
383
384 let avg_latency = if !self.latencies.is_empty() {
385 self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
386 / self.latencies.len() as f64
387 } else {
388 0.0
389 };
390
391 format!(
392 "success_rate: {:.2}%, avg_latency: {:.2}ms, samples: {}, score: {:.4}",
393 success_rate * 100.0,
394 avg_latency * 1000.0,
395 self.successes.len(),
396 self.calculate_score()
397 )
398 }
399}
400
401impl Default for TransportMetrics {
402 fn default() -> Self {
403 Self {
404 latencies: VecDeque::new(),
405 successes: VecDeque::new(),
406 last_update: Instant::now(),
407 total_requests: 0,
408 successful_requests: 0,
409 }
410 }
411}