openvm_prof/
lib.rs

1use std::{collections::HashMap, fs::File, path::Path};
2
3use aggregate::{PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, TRACE_GEN_TIME_LABEL};
4use eyre::Result;
5use memmap2::Mmap;
6
7use crate::{
8    aggregate::{EXECUTE_METERED_TIME_LABEL, EXECUTE_PREFLIGHT_TIME_LABEL},
9    types::{Labels, Metric, MetricDb, MetricsFile},
10};
11
12pub mod aggregate;
13pub mod instruction_count;
14pub mod summary;
15pub mod types;
16
17impl MetricDb {
18    pub fn new(metrics_file: impl AsRef<Path>) -> Result<Self> {
19        let file = File::open(metrics_file)?;
20        // SAFETY: File is read-only mapped. File will not be modified by other
21        // processes during the mapping's lifetime.
22        let mmap = unsafe { Mmap::map(&file)? };
23        let metrics: MetricsFile = serde_json::from_slice(&mmap)?;
24
25        let mut db = MetricDb::default();
26
27        // Process counters
28        for entry in metrics.counter {
29            if entry.value == 0.0 {
30                continue;
31            }
32            let labels = Labels::from(entry.labels);
33            db.add_to_flat_dict(labels, entry.metric, entry.value);
34        }
35
36        // Process gauges
37        for entry in metrics.gauge {
38            let labels = Labels::from(entry.labels);
39            db.add_to_flat_dict(labels, entry.metric, entry.value);
40        }
41
42        db.apply_aggregations();
43        db.separate_by_label_types();
44
45        Ok(db)
46    }
47
48    // Currently hardcoding aggregations
49    pub fn apply_aggregations(&mut self) {
50        for metrics in self.flat_dict.values_mut() {
51            let get = |key: &str| metrics.iter().find(|m| m.name == key).map(|m| m.value);
52            let total_proof_time = get(PROOF_TIME_LABEL);
53            if total_proof_time.is_some() {
54                // We have instrumented total_proof_time_ms
55                continue;
56            }
57            // otherwise, calculate it from sub-components
58            let execute_metered_time = get(EXECUTE_METERED_TIME_LABEL);
59            let execute_preflight_time = get(EXECUTE_PREFLIGHT_TIME_LABEL);
60            let trace_gen_time = get(TRACE_GEN_TIME_LABEL);
61            let prove_excl_trace_time = get(PROVE_EXCL_TRACE_TIME_LABEL);
62            if let (
63                Some(execute_preflight_time),
64                Some(trace_gen_time),
65                Some(prove_excl_trace_time),
66            ) = (
67                execute_preflight_time,
68                trace_gen_time,
69                prove_excl_trace_time,
70            ) {
71                let total_time = execute_metered_time.unwrap_or(0.0)
72                    + execute_preflight_time
73                    + trace_gen_time
74                    + prove_excl_trace_time;
75                metrics.push(Metric::new(PROOF_TIME_LABEL.to_string(), total_time));
76            }
77        }
78    }
79
80    pub fn add_to_flat_dict(&mut self, labels: Labels, metric: String, value: f64) {
81        self.flat_dict
82            .entry(labels)
83            .or_default()
84            .push(Metric::new(metric, value));
85    }
86
87    // Custom sorting function that ensures 'group' comes first.
88    // Other keys are sorted alphabetically.
89    pub fn custom_sort_label_keys(label_keys: &mut [String]) {
90        // Prioritize 'group' by giving it the lowest possible sort value
91        label_keys.sort_by_key(|key| {
92            if key == "group" {
93                (0, key.clone()) // Lowest priority for 'group'
94            } else {
95                (1, key.clone()) // Normal priority for other keys
96            }
97        });
98    }
99
100    pub fn separate_by_label_types(&mut self) {
101        self.dict_by_label_types.clear();
102
103        for (labels, metrics) in &self.flat_dict {
104            // Get sorted label keys
105            let mut label_keys: Vec<String> = labels.0.iter().map(|(key, _)| key.clone()).collect();
106            Self::custom_sort_label_keys(&mut label_keys);
107
108            // Create label_values based on sorted keys
109            let label_dict: HashMap<String, String> = labels.0.iter().cloned().collect();
110
111            let label_values: Vec<String> = label_keys
112                .iter()
113                .map(|key| {
114                    label_dict
115                        .get(key)
116                        .unwrap_or_else(|| panic!("Label key '{key}' should exist in label_dict"))
117                        .clone()
118                })
119                .collect();
120
121            // Remove cycle_tracker_span and dsl_ir if present as they are too long for markdown and
122            // visualized in flamegraphs
123            let mut keys = label_keys.clone();
124            let mut values = label_values.clone();
125
126            // Remove cycle_tracker_span if present
127            if let Some(index) = keys.iter().position(|k| k == "cycle_tracker_span") {
128                keys.remove(index);
129                values.remove(index);
130            }
131
132            // Remove dsl_ir if present
133            if let Some(index) = keys.iter().position(|k| k == "dsl_ir") {
134                keys.remove(index);
135                values.remove(index);
136            }
137
138            let (final_label_keys, final_label_values) = (keys, values);
139
140            // Add to dict_by_label_types, combining metrics with same name by summing values
141            let entry = self
142                .dict_by_label_types
143                .entry(final_label_keys)
144                .or_default()
145                .entry(final_label_values)
146                .or_default();
147
148            for metric in metrics.clone() {
149                if let Some(existing_metric) = entry.iter_mut().find(|m| m.name == metric.name) {
150                    // Sum the values for metrics with the same name
151                    existing_metric.value += metric.value;
152                } else {
153                    // Add new metric if no existing one with same name
154                    entry.push(metric);
155                }
156            }
157        }
158    }
159
160    pub fn generate_markdown_tables(&self) -> String {
161        let mut markdown_output = String::new();
162        // Get sorted keys to iterate in consistent order
163        let mut sorted_keys: Vec<_> = self.dict_by_label_types.keys().cloned().collect();
164        sorted_keys.sort();
165
166        for label_keys in sorted_keys {
167            let metrics_dict = &self.dict_by_label_types[&label_keys];
168            let mut metric_names: Vec<String> = metrics_dict
169                .values()
170                .flat_map(|metrics| metrics.iter().map(|m| m.name.clone()))
171                .collect::<std::collections::HashSet<_>>()
172                .into_iter()
173                .collect();
174            metric_names.sort_by(|a, b| b.cmp(a));
175
176            // Create table header
177            let header = format!(
178                "| {} | {} |",
179                label_keys.join(" | "),
180                metric_names.join(" | ")
181            );
182
183            let separator = "| ".to_string()
184                + &vec!["---"; label_keys.len() + metric_names.len()].join(" | ")
185                + " |";
186
187            markdown_output.push_str(&header);
188            markdown_output.push('\n');
189            markdown_output.push_str(&separator);
190            markdown_output.push('\n');
191
192            // Sort rows: first by segment (ascending) if present, then by frequency (descending) if
193            // present
194            let mut rows: Vec<_> = metrics_dict.iter().collect();
195            let segment_index = label_keys.iter().position(|k| k == "segment");
196            let has_frequency = metric_names.contains(&"frequency".to_string());
197
198            if segment_index.is_some() || has_frequency {
199                rows.sort_by(|(label_values_a, metrics_a), (label_values_b, metrics_b)| {
200                    // First, sort by segment (ascending) if present
201                    if let Some(seg_idx) = segment_index {
202                        let seg_a = label_values_a
203                            .get(seg_idx)
204                            .map(|s| s.as_str())
205                            .unwrap_or("");
206                        let seg_b = label_values_b
207                            .get(seg_idx)
208                            .map(|s| s.as_str())
209                            .unwrap_or("");
210                        let seg_cmp = seg_a.cmp(seg_b);
211                        if seg_cmp != std::cmp::Ordering::Equal {
212                            return seg_cmp;
213                        }
214                    }
215
216                    // Then, sort by frequency (descending) if present
217                    if has_frequency {
218                        let freq_a = metrics_a
219                            .iter()
220                            .find(|m| m.name == "frequency")
221                            .map(|m| m.value)
222                            .unwrap_or(0.0);
223                        let freq_b = metrics_b
224                            .iter()
225                            .find(|m| m.name == "frequency")
226                            .map(|m| m.value)
227                            .unwrap_or(0.0);
228                        return freq_b
229                            .partial_cmp(&freq_a)
230                            .unwrap_or(std::cmp::Ordering::Equal);
231                    }
232
233                    std::cmp::Ordering::Equal
234                });
235            }
236
237            // Fill table rows
238            for (label_values, metrics) in rows {
239                let mut row = String::new();
240                row.push_str("| ");
241                row.push_str(&label_values.join(" | "));
242                row.push_str(" | ");
243
244                // Add metric values
245                for metric_name in &metric_names {
246                    let metric_value = metrics
247                        .iter()
248                        .find(|m| &m.name == metric_name)
249                        .map(|m| Self::format_number(m.value))
250                        .unwrap_or_default();
251
252                    row.push_str(&format!("{metric_value} | "));
253                }
254
255                markdown_output.push_str(&row);
256                markdown_output.push('\n');
257            }
258
259            markdown_output.push('\n');
260        }
261
262        markdown_output
263    }
264}