openvm_stark_sdk/
metrics_tracing.rs

1use std::{sync::Arc, time::Instant};
2
3use dashmap::DashMap;
4use tracing::{
5    field::{Field, Visit},
6    Id, Subscriber,
7};
8use tracing_subscriber::{registry::LookupSpan, Layer};
9
10/// A tracing layer that automatically emits metric gauges for all span durations.
11/// This replaces the need for manual metrics_span calls by leveraging the tracing infrastructure.
12#[derive(Clone, Default)]
13pub struct TimingMetricsLayer {
14    /// Store span timings indexed by span ID
15    span_timings: Arc<DashMap<Id, SpanTiming>>,
16}
17
/// Bookkeeping kept for one span between its creation and the moment its
/// duration gauge is emitted.
#[derive(Debug)]
struct SpanTiming {
    /// Name of the span, copied from its metadata; becomes the prefix of the gauge name.
    name: String,
    /// Instant captured when the span was first observed by the layer.
    start_time: Instant,
}
23
/// A field visitor that detects whether an event carries a field named `return`.
///
/// Only the presence of the field is remembered; the returned value itself is not stored.
struct ReturnValueVisitor {
    /// Set to `true` once a `return` field has been visited.
    has_return: bool,
}
28
29impl Visit for ReturnValueVisitor {
30    fn record_debug(&mut self, field: &Field, _value: &dyn std::fmt::Debug) {
31        if field.name() == "return" {
32            self.has_return = true;
33        }
34    }
35
36    fn record_i64(&mut self, _field: &Field, _value: i64) {}
37    fn record_u64(&mut self, _field: &Field, _value: u64) {}
38    fn record_bool(&mut self, _field: &Field, _value: bool) {}
39    fn record_str(&mut self, _field: &Field, _value: &str) {}
40}
41
42impl TimingMetricsLayer {
43    /// Create a new TimingMetricsLayer
44    pub fn new() -> Self {
45        Self::default()
46    }
47}
48
49impl<S> Layer<S> for TimingMetricsLayer
50where
51    S: Subscriber + for<'a> LookupSpan<'a>,
52{
53    fn on_new_span(
54        &self,
55        _attrs: &tracing::span::Attributes<'_>,
56        id: &Id,
57        ctx: tracing_subscriber::layer::Context<'_, S>,
58    ) {
59        if let Some(span) = ctx.span(id) {
60            let metadata = span.metadata();
61            let name = metadata.name();
62
63            // Only track spans at INFO level or higher to match metrics_span behavior
64            if metadata.level() <= &tracing::Level::INFO {
65                self.span_timings.insert(
66                    id.clone(),
67                    SpanTiming {
68                        name: name.to_string(),
69                        start_time: Instant::now(),
70                    },
71                );
72            }
73        }
74    }
75
76    fn on_event(&self, event: &tracing::Event<'_>, ctx: tracing_subscriber::layer::Context<'_, S>) {
77        // Check if this is a return event in an instrumented function
78        let mut visitor = ReturnValueVisitor { has_return: false };
79        event.record(&mut visitor);
80
81        if visitor.has_return {
82            // Get the current span
83            if let Some(span) = ctx.event_span(event) {
84                let span_id = span.id();
85
86                // Emit metric for the span that's returning
87                if let Some((_, timing)) = self.span_timings.remove(&span_id) {
88                    let duration_ms = timing.start_time.elapsed().as_millis() as f64;
89
90                    // Emit the metric gauge with the span name
91                    // This matches the behavior of metrics_span
92                    metrics::gauge!(format!("{}_time_ms", timing.name)).set(duration_ms);
93                }
94            }
95        }
96    }
97
98    fn on_close(&self, id: Id, _ctx: tracing_subscriber::layer::Context<'_, S>) {
99        // Clean up any spans that weren't emitted via return events
100        // This handles spans that don't have instrumented return values
101        if let Some((_, timing)) = self.span_timings.remove(&id) {
102            let duration_ms = timing.start_time.elapsed().as_millis() as f64;
103
104            // Emit the metric gauge with the span name
105            metrics::gauge!(format!("{}_time_ms", timing.name)).set(duration_ms);
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    use tracing::instrument;
    use tracing_subscriber::{layer::SubscriberExt, Registry};

    use super::*;

    /// INFO-level instrumented function that emits no return event, so its span is
    /// reported when the span closes.
    #[instrument(level = "info")]
    fn example_function() -> i32 {
        std::thread::sleep(std::time::Duration::from_millis(10));
        42
    }

    /// INFO-level instrumented function that emits a `return` event (`ret`), so its span
    /// is reported from the event handler.
    #[instrument(level = "info", ret)]
    fn example_function_with_ret() -> i32 {
        7
    }

    #[test]
    fn test_metrics_layer() {
        // Keep a clone of the layer: clones share the timing table, which lets the test
        // inspect the layer's state after the subscriber has consumed the other clone.
        let layer = TimingMetricsLayer::new();
        let subscriber = Registry::default().with(layer.clone());

        tracing::subscriber::with_default(subscriber, || {
            let result = example_function();
            assert_eq!(result, 42);
        });

        // The span has closed, so its timing record must have been emitted and removed.
        assert!(layer.span_timings.is_empty());
    }

    #[test]
    fn test_metrics_layer_return_event() {
        let layer = TimingMetricsLayer::new();
        let subscriber = Registry::default().with(layer.clone());

        tracing::subscriber::with_default(subscriber, || {
            let result = example_function_with_ret();
            assert_eq!(result, 7);
        });

        // Whether reported on the return event or on close, no record may be left behind.
        assert!(layer.span_timings.is_empty());
    }
}