openvm_stark_sdk/
metrics_tracing.rs1use std::{sync::Arc, time::Instant};
2
3use dashmap::DashMap;
4use tracing::{
5 field::{Field, Visit},
6 Id, Subscriber,
7};
8use tracing_subscriber::{registry::LookupSpan, Layer};
9
10#[derive(Clone, Default)]
13pub struct TimingMetricsLayer {
14 span_timings: Arc<DashMap<Id, SpanTiming>>,
16}
17
/// Bookkeeping for a single span whose lifetime is being measured.
#[derive(Debug)]
struct SpanTiming {
    /// Span name; used as the prefix of the reported `<name>_time_ms` gauge.
    name: String,
    /// Moment the span was created.
    start_time: Instant,
}

/// Field visitor used to detect whether an event carries a field named `return`.
struct ReturnValueVisitor {
    /// Set to `true` by the [`Visit`] impl when it sees a field named `return`.
    has_return: bool,
}

29impl Visit for ReturnValueVisitor {
30 fn record_debug(&mut self, field: &Field, _value: &dyn std::fmt::Debug) {
31 if field.name() == "return" {
32 self.has_return = true;
33 }
34 }
35
36 fn record_i64(&mut self, _field: &Field, _value: i64) {}
37 fn record_u64(&mut self, _field: &Field, _value: u64) {}
38 fn record_bool(&mut self, _field: &Field, _value: bool) {}
39 fn record_str(&mut self, _field: &Field, _value: &str) {}
40}
41
42impl TimingMetricsLayer {
43 pub fn new() -> Self {
45 Self::default()
46 }
47}
48
49impl<S> Layer<S> for TimingMetricsLayer
50where
51 S: Subscriber + for<'a> LookupSpan<'a>,
52{
53 fn on_new_span(
54 &self,
55 _attrs: &tracing::span::Attributes<'_>,
56 id: &Id,
57 ctx: tracing_subscriber::layer::Context<'_, S>,
58 ) {
59 if let Some(span) = ctx.span(id) {
60 let metadata = span.metadata();
61 let name = metadata.name();
62
63 if metadata.level() <= &tracing::Level::INFO {
65 self.span_timings.insert(
66 id.clone(),
67 SpanTiming {
68 name: name.to_string(),
69 start_time: Instant::now(),
70 },
71 );
72 }
73 }
74 }
75
76 fn on_event(&self, event: &tracing::Event<'_>, ctx: tracing_subscriber::layer::Context<'_, S>) {
77 let mut visitor = ReturnValueVisitor { has_return: false };
79 event.record(&mut visitor);
80
81 if visitor.has_return {
82 if let Some(span) = ctx.event_span(event) {
84 let span_id = span.id();
85
86 if let Some((_, timing)) = self.span_timings.remove(&span_id) {
88 let duration_ms = timing.start_time.elapsed().as_millis() as f64;
89
90 metrics::gauge!(format!("{}_time_ms", timing.name)).set(duration_ms);
93 }
94 }
95 }
96 }
97
98 fn on_close(&self, id: Id, _ctx: tracing_subscriber::layer::Context<'_, S>) {
99 if let Some((_, timing)) = self.span_timings.remove(&id) {
102 let duration_ms = timing.start_time.elapsed().as_millis() as f64;
103
104 metrics::gauge!(format!("{}_time_ms", timing.name)).set(duration_ms);
106 }
107 }
108}
109
#[cfg(test)]
mod tests {
    use tracing::instrument;
    use tracing_subscriber::{layer::SubscriberExt, Registry};

    use super::*;

    /// Instrumented without `ret`: its timing is reported when the span closes.
    #[instrument(level = "info")]
    fn example_function() -> i32 {
        std::thread::sleep(std::time::Duration::from_millis(10));
        42
    }

    /// Instrumented with `ret`: its timing is reported from the `return` event.
    #[instrument(level = "info", ret)]
    fn example_function_with_ret() -> i32 {
        std::thread::sleep(std::time::Duration::from_millis(10));
        7
    }

    #[test]
    fn test_metrics_layer() {
        let layer = TimingMetricsLayer::new();
        // Clones share the same timing map, so `layer` can still be inspected
        // after the subscriber (which owns the clone) is gone.
        let subscriber = Registry::default().with(layer.clone());

        tracing::subscriber::with_default(subscriber, || {
            let result = example_function();
            assert_eq!(result, 42);
            assert_eq!(example_function_with_ret(), 7);
        });

        // Both the on-close path and the `return`-event path must consume their entries.
        assert!(layer.span_timings.is_empty());
    }

    #[test]
    fn test_only_info_or_more_severe_spans_are_tracked() {
        let layer = TimingMetricsLayer::new();
        let subscriber = Registry::default().with(layer.clone());

        tracing::subscriber::with_default(subscriber, || {
            tracing::debug_span!("debug_span").in_scope(|| {
                assert!(layer.span_timings.is_empty());
            });
            tracing::info_span!("info_span").in_scope(|| {
                assert_eq!(layer.span_timings.len(), 1);
            });
        });

        assert!(layer.span_timings.is_empty());
    }
}