openvm_prof/
aggregate.rs

1use std::{collections::HashMap, io::Write};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5
6use crate::types::{BencherValue, BenchmarkOutput, Labels, MdTableCell, MetricDb};
7
8type MetricName = String;
9type MetricsByName = HashMap<MetricName, Vec<(f64, Labels)>>;
10
11#[derive(Clone, Debug, Default)]
12pub struct GroupedMetrics {
13    /// "group" label => metrics with that "group" label, further grouped by metric name
14    pub by_group: HashMap<String, MetricsByName>,
15    pub ungrouped: MetricsByName,
16}
17
18#[derive(Clone, Debug, Default, Serialize, Deserialize)]
19pub struct AggregateMetrics {
20    /// "group" label => metric aggregate statistics
21    #[serde(flatten)]
22    pub by_group: HashMap<String, HashMap<MetricName, Stats>>,
23    /// In seconds
24    pub total_proof_time: MdTableCell,
25    /// In seconds
26    pub total_par_proof_time: MdTableCell,
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct BencherAggregateMetrics {
31    #[serde(flatten)]
32    pub by_group: HashMap<String, HashMap<String, BencherValue>>,
33    /// In seconds
34    pub total_proof_time: BencherValue,
35    /// In seconds
36    pub total_par_proof_time: BencherValue,
37}
38
39#[derive(Clone, Debug, Serialize, Deserialize)]
40pub struct Stats {
41    pub sum: MdTableCell,
42    pub max: MdTableCell,
43    pub min: MdTableCell,
44    pub avg: MdTableCell,
45    #[serde(skip)]
46    pub count: usize,
47}
48
49impl Default for Stats {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl Stats {
56    pub fn new() -> Self {
57        Self {
58            sum: MdTableCell::default(),
59            max: MdTableCell::default(),
60            min: MdTableCell::new(f64::MAX, None),
61            avg: MdTableCell::default(),
62            count: 0,
63        }
64    }
65    pub fn push(&mut self, value: f64) {
66        self.sum.val += value;
67        self.count += 1;
68        if value > self.max.val {
69            self.max.val = value;
70        }
71        if value < self.min.val {
72            self.min.val = value;
73        }
74    }
75
76    pub fn finalize(&mut self) {
77        assert!(self.count != 0);
78        self.avg.val = self.sum.val / self.count as f64;
79    }
80
81    pub fn set_diff(&mut self, prev: &Self) {
82        self.sum.diff = Some(self.sum.val - prev.sum.val);
83        self.max.diff = Some(self.max.val - prev.max.val);
84        self.min.diff = Some(self.min.val - prev.min.val);
85        self.avg.diff = Some(self.avg.val - prev.avg.val);
86    }
87}
88
89impl GroupedMetrics {
90    pub fn new(db: &MetricDb, group_label_name: &str) -> Result<Self> {
91        let mut by_group = HashMap::<String, MetricsByName>::new();
92        let mut ungrouped = MetricsByName::new();
93        for (labels, metrics) in db.flat_dict.iter() {
94            let group_name = labels.get(group_label_name);
95            if let Some(group_name) = group_name {
96                let group_entry = by_group.entry(group_name.to_string()).or_default();
97                let mut labels = labels.clone();
98                labels.remove(group_label_name);
99                for metric in metrics {
100                    group_entry
101                        .entry(metric.name.clone())
102                        .or_default()
103                        .push((metric.value, labels.clone()));
104                }
105            } else {
106                for metric in metrics {
107                    ungrouped
108                        .entry(metric.name.clone())
109                        .or_default()
110                        .push((metric.value, labels.clone()));
111                }
112            }
113        }
114        Ok(Self {
115            by_group,
116            ungrouped,
117        })
118    }
119
120    /// Validates that E1, metered, and preflight instruction counts all match each other
121    fn validate_instruction_counts(group_summaries: &HashMap<MetricName, Stats>) {
122        let e1_insns = group_summaries.get(EXECUTE_E1_INSNS_LABEL);
123        let metered_insns = group_summaries.get(EXECUTE_METERED_INSNS_LABEL);
124        let preflight_insns = group_summaries.get(EXECUTE_PREFLIGHT_INSNS_LABEL);
125
126        if let (Some(e1_insns), Some(preflight_insns)) = (e1_insns, preflight_insns) {
127            assert_eq!(e1_insns.sum.val as u64, preflight_insns.sum.val as u64);
128        }
129        if let (Some(e1_insns), Some(metered_insns)) = (e1_insns, metered_insns) {
130            assert_eq!(e1_insns.sum.val as u64, metered_insns.sum.val as u64);
131        }
132        if let (Some(metered_insns), Some(preflight_insns)) = (metered_insns, preflight_insns) {
133            assert_eq!(metered_insns.sum.val as u64, preflight_insns.sum.val as u64);
134        }
135    }
136
137    pub fn aggregate(&self) -> AggregateMetrics {
138        let by_group: HashMap<String, _> = self
139            .by_group
140            .iter()
141            .map(|(group_name, metrics)| {
142                let group_summaries: HashMap<MetricName, Stats> = metrics
143                    .iter()
144                    .map(|(metric_name, metrics)| {
145                        let mut summary = Stats::new();
146                        for (value, _) in metrics {
147                            summary.push(*value);
148                        }
149                        summary.finalize();
150                        (metric_name.clone(), summary)
151                    })
152                    .collect();
153
154                if !group_name.contains("keygen") {
155                    Self::validate_instruction_counts(&group_summaries);
156                }
157
158                (group_name.clone(), group_summaries)
159            })
160            .collect();
161        let mut metrics = AggregateMetrics {
162            by_group,
163            ..Default::default()
164        };
165        metrics.compute_total();
166        metrics
167    }
168}
169
/// Sort key used to order groups for display (a deliberately simple heuristic).
///
/// * Any name containing `"keygen"` sorts last.
/// * A name starting with one of the known stage prefixes gets that prefix's 1-based position.
///   When several prefixes match, the last one in the list wins.
/// * Any other name, typically the app group, gets 0.
pub(crate) fn group_weight(name: &str) -> usize {
    const STAGE_PREFIXES: [&str; 5] = ["leaf", "internal", "root", "halo2_outer", "halo2_wrapper"];
    if name.contains("keygen") {
        return STAGE_PREFIXES.len() + 1;
    }
    STAGE_PREFIXES
        .iter()
        .rposition(|prefix| name.starts_with(prefix))
        .map_or(0, |idx| idx + 1)
}
183
184impl AggregateMetrics {
185    pub fn compute_total(&mut self) {
186        let mut total_proof_time = MdTableCell::new(0.0, Some(0.0));
187        let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0));
188        for (group_name, metrics) in &self.by_group {
189            let stats = metrics.get(PROOF_TIME_LABEL);
190            let execute_metered_stats = metrics.get(EXECUTE_METERED_TIME_LABEL);
191            let execute_e1_stats = metrics.get(EXECUTE_E1_TIME_LABEL);
192            if stats.is_none() {
193                continue;
194            }
195            let stats = stats.unwrap_or_else(|| {
196                panic!("Missing proof time statistics for group '{}'", group_name)
197            });
198            let mut sum = stats.sum;
199            let mut max = stats.max;
200            // convert ms to s
201            sum.val /= 1000.0;
202            max.val /= 1000.0;
203            if let Some(diff) = &mut sum.diff {
204                *diff /= 1000.0;
205            }
206            if let Some(diff) = &mut max.diff {
207                *diff /= 1000.0;
208            }
209            if !group_name.contains("keygen") {
210                // Proving time in keygen group is dummy and not part of total.
211                total_proof_time.val += sum.val;
212                *total_proof_time
213                    .diff
214                    .as_mut()
215                    .expect("total_proof_time.diff should be initialized") +=
216                    sum.diff.unwrap_or(0.0);
217                total_par_proof_time.val += max.val;
218                *total_par_proof_time
219                    .diff
220                    .as_mut()
221                    .expect("total_par_proof_time.diff should be initialized") +=
222                    max.diff.unwrap_or(0.0);
223
224                // Account for the serial execute_metered and execute_e1 for app outside of segments
225                if group_name != "leaf"
226                    && group_name != "root"
227                    && group_name != "halo2_outer"
228                    && group_name != "halo2_wrapper"
229                    && !group_name.starts_with("internal")
230                {
231                    if let Some(execute_metered_stats) = execute_metered_stats {
232                        // For metered metrics without segment labels, we just use the value
233                        // directly Count is 1, so avg = sum = max = min =
234                        // value
235                        total_proof_time.val += execute_metered_stats.avg.val / 1000.0;
236                        total_par_proof_time.val += execute_metered_stats.avg.val / 1000.0;
237                        if let Some(diff) = execute_metered_stats.avg.diff {
238                            *total_proof_time
239                                .diff
240                                .as_mut()
241                                .expect("total_proof_time.diff should be initialized") +=
242                                diff / 1000.0;
243                            *total_par_proof_time
244                                .diff
245                                .as_mut()
246                                .expect("total_par_proof_time.diff should be initialized") +=
247                                diff / 1000.0;
248                        }
249                    }
250
251                    if let Some(execute_e1_stats) = execute_e1_stats {
252                        total_proof_time.val += execute_e1_stats.avg.val / 1000.0;
253                        total_par_proof_time.val += execute_e1_stats.avg.val / 1000.0;
254                        if let Some(diff) = execute_e1_stats.avg.diff {
255                            *total_proof_time
256                                .diff
257                                .as_mut()
258                                .expect("total_proof_time.diff should be initialized") +=
259                                diff / 1000.0;
260                            *total_par_proof_time
261                                .diff
262                                .as_mut()
263                                .expect("total_par_proof_time.diff should be initialized") +=
264                                diff / 1000.0;
265                        }
266                    }
267                }
268            }
269        }
270        self.total_proof_time = total_proof_time;
271        self.total_par_proof_time = total_par_proof_time;
272    }
273
274    pub fn set_diff(&mut self, prev: &Self) {
275        for (group_name, metrics) in self.by_group.iter_mut() {
276            if let Some(prev_metrics) = prev.by_group.get(group_name) {
277                for (metric_name, stats) in metrics.iter_mut() {
278                    if let Some(prev_stats) = prev_metrics.get(metric_name) {
279                        stats.set_diff(prev_stats);
280                    }
281                }
282            }
283        }
284        self.compute_total();
285    }
286
287    pub fn to_vec(&self) -> Vec<(String, HashMap<MetricName, Stats>)> {
288        let mut group_names: Vec<_> = self.by_group.keys().collect();
289        group_names.sort_by(|a, b| {
290            let a_wt = group_weight(a);
291            let b_wt = group_weight(b);
292            if a_wt == b_wt {
293                a.cmp(b)
294            } else {
295                a_wt.cmp(&b_wt)
296            }
297        });
298        group_names
299            .into_iter()
300            .map(|group_name| {
301                let key = group_name.clone();
302                let value = self
303                    .by_group
304                    .get(group_name)
305                    .unwrap_or_else(|| {
306                        panic!("Group '{}' should exist in by_group map", group_name)
307                    })
308                    .clone();
309                (key, value)
310            })
311            .collect()
312    }
313
314    pub fn to_bencher_metrics(&self) -> BencherAggregateMetrics {
315        let by_group = self
316            .by_group
317            .iter()
318            .map(|(group_name, metrics)| {
319                let metrics = metrics
320                    .iter()
321                    .filter(|(_, stats)| stats.avg.val.is_finite() && stats.sum.val.is_finite())
322                    .flat_map(|(metric_name, stats)| {
323                        [
324                            (format!("{metric_name}::sum"), stats.sum.into()),
325                            (
326                                metric_name.clone(),
327                                BencherValue {
328                                    value: stats.avg.val,
329                                    lower_value: Some(stats.min.val),
330                                    upper_value: Some(stats.max.val),
331                                },
332                            ),
333                        ]
334                    })
335                    .collect();
336                (group_name.clone(), metrics)
337            })
338            .collect();
339        let total_proof_time = self.total_proof_time.into();
340        let total_par_proof_time = self.total_par_proof_time.into();
341        BencherAggregateMetrics {
342            by_group,
343            total_proof_time,
344            total_par_proof_time,
345        }
346    }
347
348    pub fn write_markdown(&self, writer: &mut impl Write, metric_names: &[&str]) -> Result<()> {
349        self.write_summary_markdown(writer)?;
350        writeln!(writer)?;
351
352        let metric_names = metric_names.to_vec();
353        for (group_name, summaries) in self.to_vec() {
354            writeln!(writer, "| {} |||||", group_name)?;
355            writeln!(writer, "|:---|---:|---:|---:|---:|")?;
356            writeln!(writer, "|metric|avg|sum|max|min|")?;
357            let names = if metric_names.is_empty() {
358                summaries.keys().map(|s| s.as_str()).collect()
359            } else {
360                metric_names.clone()
361            };
362            for metric_name in names {
363                let summary = summaries.get(metric_name);
364                if let Some(summary) = summary {
365                    // Special handling for execute_metered metrics (not aggregated across segments
366                    // in the app proof case)
367                    if metric_name == EXECUTE_METERED_TIME_LABEL
368                        && group_name != "leaf"
369                        && group_name != "root"
370                        && group_name != "halo2_outer"
371                        && group_name != "halo2_wrapper"
372                        && !group_name.starts_with("internal")
373                    {
374                        writeln!(
375                            writer,
376                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
377                            metric_name, summary.avg, "-", "-", "-",
378                        )?;
379                    } else if metric_name == EXECUTE_E1_INSN_MI_S_LABEL
380                        || metric_name == EXECUTE_PREFLIGHT_INSN_MI_S_LABEL
381                        || metric_name == EXECUTE_METERED_INSN_MI_S_LABEL
382                    {
383                        // skip sum because it is misleading
384                        writeln!(
385                            writer,
386                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
387                            metric_name, summary.avg, "-", summary.max, summary.min,
388                        )?;
389                    } else {
390                        writeln!(
391                            writer,
392                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
393                            metric_name, summary.avg, summary.sum, summary.max, summary.min,
394                        )?;
395                    }
396                }
397            }
398            writeln!(writer)?;
399        }
400        writeln!(writer)?;
401
402        Ok(())
403    }
404
405    fn write_summary_markdown(&self, writer: &mut impl Write) -> Result<()> {
406        writeln!(
407            writer,
408            "| Summary | Proof Time (s) | Parallel Proof Time (s) |"
409        )?;
410        writeln!(writer, "|:---|---:|---:|")?;
411        let mut rows = Vec::new();
412        for (group_name, summaries) in self.to_vec() {
413            if group_name.contains("keygen") {
414                continue;
415            }
416            let stats = summaries.get(PROOF_TIME_LABEL);
417            if stats.is_none() {
418                continue;
419            }
420            let stats = stats.unwrap_or_else(|| {
421                panic!("Missing proof time statistics for group '{}'", group_name)
422            });
423            let mut sum = stats.sum;
424            let mut max = stats.max;
425            // convert ms to s
426            sum.val /= 1000.0;
427            max.val /= 1000.0;
428            if let Some(diff) = &mut sum.diff {
429                *diff /= 1000.0;
430            }
431            if let Some(diff) = &mut max.diff {
432                *diff /= 1000.0;
433            }
434            rows.push((group_name, sum, max));
435        }
436        writeln!(
437            writer,
438            "| Total | {} | {} |",
439            self.total_proof_time, self.total_par_proof_time
440        )?;
441        for (group_name, proof_time, par_proof_time) in rows {
442            writeln!(writer, "| {group_name} | {proof_time} | {par_proof_time} |")?;
443        }
444        writeln!(writer)?;
445        Ok(())
446    }
447
448    pub fn name(&self) -> String {
449        // A hacky way to determine the app name
450        self.by_group
451            .keys()
452            .find(|k| group_weight(k) == 0)
453            .unwrap_or_else(|| {
454                self.by_group
455                    .keys()
456                    .next()
457                    .expect("by_group should contain at least one group")
458            })
459            .clone()
460    }
461}
462
463impl BenchmarkOutput {
464    pub fn insert(&mut self, name: &str, metrics: BencherAggregateMetrics) {
465        for (group_name, metrics) in metrics.by_group {
466            self.by_name
467                .entry(format!("{name}::{group_name}"))
468                .or_default()
469                .extend(metrics);
470        }
471        if let Some(e) = self.by_name.insert(
472            name.to_owned(),
473            HashMap::from_iter([
474                ("total_proof_time".to_owned(), metrics.total_proof_time),
475                (
476                    "total_par_proof_time".to_owned(),
477                    metrics.total_par_proof_time,
478                ),
479            ]),
480        ) {
481            panic!("Duplicate metric: {e:?}");
482        }
483    }
484}
485
// ---- Proof time -------------------------------------------------------------------------

/// Per-group proving time in milliseconds; totals convert it to seconds.
pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms";

// ---- Cell usage -------------------------------------------------------------------------

/// Metric name `main_cells_used`.
pub const MAIN_CELLS_USED_LABEL: &str = "main_cells_used";
/// Metric name `total_cells_used`.
pub const TOTAL_CELLS_USED_LABEL: &str = "total_cells_used";

// ---- Instruction counts (must agree between execution modes) ---------------------------

/// Instruction count from E1 execution.
pub const EXECUTE_E1_INSNS_LABEL: &str = "execute_e1_insns";
/// Instruction count from metered execution.
pub const EXECUTE_METERED_INSNS_LABEL: &str = "execute_metered_insns";
/// Instruction count from preflight execution.
pub const EXECUTE_PREFLIGHT_INSNS_LABEL: &str = "execute_preflight_insns";

// ---- Execution timings (ms) and throughput (the sum column is hidden in tables) --------

/// E1 execution time, in milliseconds.
pub const EXECUTE_E1_TIME_LABEL: &str = "execute_e1_time_ms";
/// E1 execution throughput.
pub const EXECUTE_E1_INSN_MI_S_LABEL: &str = "execute_e1_insn_mi/s";
/// Metered execution time, in milliseconds.
pub const EXECUTE_METERED_TIME_LABEL: &str = "execute_metered_time_ms";
/// Metered execution throughput.
pub const EXECUTE_METERED_INSN_MI_S_LABEL: &str = "execute_metered_insn_mi/s";
/// Preflight execution time, in milliseconds.
pub const EXECUTE_PREFLIGHT_TIME_LABEL: &str = "execute_preflight_time_ms";
/// Preflight execution throughput.
pub const EXECUTE_PREFLIGHT_INSN_MI_S_LABEL: &str = "execute_preflight_insn_mi/s";

// ---- Trace generation and proving stages (ms) ------------------------------------------

/// Trace generation time, in milliseconds.
pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms";
/// Memory finalization time, in milliseconds.
pub const MEM_FIN_TIME_LABEL: &str = "memory_finalize_time_ms";
/// Boundary finalization time, in milliseconds.
pub const BOUNDARY_FIN_TIME_LABEL: &str = "boundary_finalize_time_ms";
/// Merkle finalization time, in milliseconds.
pub const MERKLE_FIN_TIME_LABEL: &str = "merkle_finalize_time_ms";
/// STARK proving time excluding trace generation, in milliseconds.
pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms";

/// Metrics shown in VM markdown tables, in display order.
pub const VM_METRIC_NAMES: &[&str] = &[
    // summary
    PROOF_TIME_LABEL,
    MAIN_CELLS_USED_LABEL,
    TOTAL_CELLS_USED_LABEL,
    // execution
    EXECUTE_E1_TIME_LABEL,
    EXECUTE_E1_INSN_MI_S_LABEL,
    EXECUTE_METERED_TIME_LABEL,
    EXECUTE_METERED_INSN_MI_S_LABEL,
    EXECUTE_PREFLIGHT_INSNS_LABEL,
    EXECUTE_PREFLIGHT_TIME_LABEL,
    EXECUTE_PREFLIGHT_INSN_MI_S_LABEL,
    // trace generation / finalization
    TRACE_GEN_TIME_LABEL,
    MEM_FIN_TIME_LABEL,
    BOUNDARY_FIN_TIME_LABEL,
    MERKLE_FIN_TIME_LABEL,
    // proving sub-stages
    PROVE_EXCL_TRACE_TIME_LABEL,
    "main_trace_commit_time_ms",
    "generate_perm_trace_time_ms",
    "perm_trace_commit_time_ms",
    "quotient_poly_compute_time_ms",
    "quotient_poly_commit_time_ms",
    "pcs_opening_time_ms",
];