// openvm_prof/aggregate.rs

1use std::{collections::HashMap, io::Write};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5
6use crate::types::{BencherValue, BenchmarkOutput, Labels, MdTableCell, MetricDb};
7
/// Name of a metric, e.g. `total_proof_time_ms`.
type MetricName = String;
/// Metric name => every recorded sample of that metric, as `(value, labels)` pairs.
type MetricsByName = HashMap<MetricName, Vec<(f64, Labels)>>;

/// Raw metric samples from a [`MetricDb`], bucketed by the value of a chosen "group" label.
///
/// Built by [`GroupedMetrics::new`]; reduced to statistics by [`GroupedMetrics::aggregate`].
#[derive(Clone, Debug, Default)]
pub struct GroupedMetrics {
    /// "group" label => metrics with that "group" label, further grouped by metric name.
    /// The group label itself is removed from the stored labels.
    pub by_group: HashMap<String, MetricsByName>,
    /// Metrics whose labels do not contain the group label, grouped by metric name.
    pub ungrouped: MetricsByName,
}

/// Per-group summary statistics plus run-wide proof-time totals.
///
/// When serialized, the per-group map is flattened into the top level next to the
/// `total_proof_time` / `total_par_proof_time` fields.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AggregateMetrics {
    /// "group" label => metric aggregate statistics
    #[serde(flatten)]
    pub by_group: HashMap<String, HashMap<MetricName, Stats>>,
    /// Total serial proof time (sum over all proofs of non-keygen groups), in seconds
    pub total_proof_time: MdTableCell,
    /// Total proof time assuming infinite parallelism (per-group max, summed), in seconds
    pub total_par_proof_time: MdTableCell,
    /// Per-group bounded parallel proof time in seconds.
    ///
    /// Not serialized (`#[serde(skip)]`), so it is empty on a deserialized value.
    #[serde(skip)]
    pub bounded_par_by_group: HashMap<String, MdTableCell>,
}

/// [`AggregateMetrics`] converted into Bencher values (see `AggregateMetrics::to_bencher_metrics`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BencherAggregateMetrics {
    /// "group" label => Bencher metric name => value
    #[serde(flatten)]
    pub by_group: HashMap<String, HashMap<String, BencherValue>>,
    /// In seconds
    pub total_proof_time: BencherValue,
    /// In seconds (infinite parallelism)
    pub total_par_proof_time: BencherValue,
}

42#[derive(Clone, Debug, Serialize, Deserialize)]
43pub struct Stats {
44    pub sum: MdTableCell,
45    pub max: MdTableCell,
46    pub min: MdTableCell,
47    pub avg: MdTableCell,
48    #[serde(skip)]
49    pub count: usize,
50}
51
52impl Default for Stats {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl Stats {
59    pub fn new() -> Self {
60        Self {
61            sum: MdTableCell::default(),
62            max: MdTableCell::default(),
63            min: MdTableCell::new(f64::MAX, None),
64            avg: MdTableCell::default(),
65            count: 0,
66        }
67    }
68    pub fn push(&mut self, value: f64) {
69        self.sum.val += value;
70        self.count += 1;
71        if value > self.max.val {
72            self.max.val = value;
73        }
74        if value < self.min.val {
75            self.min.val = value;
76        }
77    }
78
79    pub fn finalize(&mut self) {
80        assert!(self.count != 0);
81        self.avg.val = self.sum.val / self.count as f64;
82    }
83
84    pub fn set_diff(&mut self, prev: &Self) {
85        self.sum.diff = Some(self.sum.val - prev.sum.val);
86        self.max.diff = Some(self.max.val - prev.max.val);
87        self.min.diff = Some(self.min.val - prev.min.val);
88        self.avg.diff = Some(self.avg.val - prev.avg.val);
89    }
90}
91
92impl GroupedMetrics {
93    pub fn new(db: &MetricDb, group_label_name: &str) -> Result<Self> {
94        let mut by_group = HashMap::<String, MetricsByName>::new();
95        let mut ungrouped = MetricsByName::new();
96        for (labels, metrics) in db.flat_dict.iter() {
97            let group_name = labels.get(group_label_name);
98            if let Some(group_name) = group_name {
99                let group_entry = by_group.entry(group_name.to_string()).or_default();
100                let mut labels = labels.clone();
101                labels.remove(group_label_name);
102                for metric in metrics {
103                    group_entry
104                        .entry(metric.name.clone())
105                        .or_default()
106                        .push((metric.value, labels.clone()));
107                }
108            } else {
109                for metric in metrics {
110                    ungrouped
111                        .entry(metric.name.clone())
112                        .or_default()
113                        .push((metric.value, labels.clone()));
114                }
115            }
116        }
117        Ok(Self {
118            by_group,
119            ungrouped,
120        })
121    }
122
123    /// Validates that E1, metered, and preflight instruction counts all match each other
124    fn validate_instruction_counts(group_summaries: &HashMap<MetricName, Stats>) {
125        let e1_insns = group_summaries.get(EXECUTE_E1_INSNS_LABEL);
126        let metered_insns = group_summaries.get(EXECUTE_METERED_INSNS_LABEL);
127        let preflight_insns = group_summaries.get(EXECUTE_PREFLIGHT_INSNS_LABEL);
128
129        if let (Some(e1_insns), Some(preflight_insns)) = (e1_insns, preflight_insns) {
130            assert_eq!(e1_insns.sum.val as u64, preflight_insns.sum.val as u64);
131        }
132        if let (Some(e1_insns), Some(metered_insns)) = (e1_insns, metered_insns) {
133            assert_eq!(e1_insns.sum.val as u64, metered_insns.sum.val as u64);
134        }
135        if let (Some(metered_insns), Some(preflight_insns)) = (metered_insns, preflight_insns) {
136            assert_eq!(metered_insns.sum.val as u64, preflight_insns.sum.val as u64);
137        }
138    }
139
140    pub fn aggregate(&self, num_parallel: usize) -> AggregateMetrics {
141        let by_group: HashMap<String, _> = self
142            .by_group
143            .iter()
144            .map(|(group_name, metrics)| {
145                let group_summaries: HashMap<MetricName, Stats> = metrics
146                    .iter()
147                    .map(|(metric_name, metrics)| {
148                        let mut summary = Stats::new();
149                        for (value, _) in metrics {
150                            summary.push(*value);
151                        }
152                        summary.finalize();
153                        (metric_name.clone(), summary)
154                    })
155                    .collect();
156
157                if !group_name.contains("keygen") {
158                    Self::validate_instruction_counts(&group_summaries);
159                }
160
161                (group_name.clone(), group_summaries)
162            })
163            .collect();
164        let mut metrics = AggregateMetrics {
165            by_group,
166            ..Default::default()
167        };
168        metrics.compute_total();
169        metrics.bounded_par_by_group = self
170            .compute_bounded_par_times(num_parallel, &metrics.by_group)
171            .into_iter()
172            .map(|(k, v)| (k, MdTableCell::new(v, Some(0.0))))
173            .collect();
174
175        metrics
176    }
177
178    /// Compute per-group parallel proof time with bounded parallelism.
179    fn compute_bounded_par_times(
180        &self,
181        num_parallel: usize,
182        stats_by_group: &HashMap<String, HashMap<MetricName, Stats>>,
183    ) -> HashMap<String, f64> {
184        let mut per_group = HashMap::new();
185
186        for (group_name, metrics) in &self.by_group {
187            if group_name.contains("keygen") {
188                continue;
189            }
190
191            let mut group_time = 0.0;
192
193            // Add serial execution time for app_proof groups
194            if is_app_proof_group(group_name) {
195                if let Some(stats) = stats_by_group.get(group_name) {
196                    if let Some(metered) = stats.get(EXECUTE_METERED_TIME_LABEL) {
197                        group_time += metered.avg.val / 1000.0;
198                    }
199                    if let Some(e1) = stats.get(EXECUTE_E1_TIME_LABEL) {
200                        group_time += e1.avg.val / 1000.0;
201                    }
202                }
203            }
204
205            // Schedule proofs in parallel
206            if let Some(proof_times) = metrics.get(PROOF_TIME_LABEL) {
207                let times_s: Vec<f64> = proof_times.iter().map(|(ms, _)| ms / 1000.0).collect();
208                group_time += schedule_parallel(&times_s, num_parallel);
209            }
210
211            per_group.insert(group_name.clone(), group_time);
212        }
213
214        per_group
215    }
216}
217
/// Estimates wall-clock time for running `proof_times` on `num_parallel` provers.
///
/// Proofs are dealt round-robin in input order: proof `i` goes to prover `i % num_parallel`.
/// The result is the busiest prover's total, floored at `0.0`. Returns `0.0` when there are
/// no proofs or no provers.
fn schedule_parallel(proof_times: &[f64], num_parallel: usize) -> f64 {
    if num_parallel == 0 || proof_times.is_empty() {
        return 0.0;
    }

    let mut per_prover = vec![0.0_f64; num_parallel];
    // Each chunk of `num_parallel` consecutive proofs hands one proof to each prover, in order.
    for round in proof_times.chunks(num_parallel) {
        for (load, duration) in per_prover.iter_mut().zip(round) {
            *load += *duration;
        }
    }
    per_prover.into_iter().fold(0.0_f64, f64::max)
}

/// Returns `true` unless the group is `leaf`, `root`, `halo2_outer`, `halo2_wrapper`, or has a
/// name starting with `internal`.
fn is_app_proof_group(name: &str) -> bool {
    let is_named_layer = matches!(name, "leaf" | "root" | "halo2_outer" | "halo2_wrapper")
        || name.starts_with("internal");
    !is_named_layer
}

/// A hacky way to order the groups for display.
///
/// Returns a sort weight:
/// - `0` for any group not matched below (e.g. app groups);
/// - `1` for `leaf*`;
/// - `2` for `internal*`;
/// - `3` for `root*`;
/// - `4` for `halo2_outer*`;
/// - `5` for `halo2_wrapper*`;
/// - `6` for any name containing `keygen`.
///
/// Prefixes are checked from last to first.
pub(crate) fn group_weight(name: &str) -> usize {
    const DISPLAY_ORDER: [&str; 5] = ["leaf", "internal", "root", "halo2_outer", "halo2_wrapper"];
    if name.contains("keygen") {
        return DISPLAY_ORDER.len() + 1;
    }
    DISPLAY_ORDER
        .iter()
        .rposition(|prefix| name.starts_with(prefix))
        .map_or(0, |index| index + 1)
}

253impl AggregateMetrics {
254    pub fn compute_total(&mut self) {
255        let mut total_proof_time = MdTableCell::new(0.0, Some(0.0));
256        let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0));
257        for (group_name, metrics) in &self.by_group {
258            let stats = metrics.get(PROOF_TIME_LABEL);
259            let execute_metered_stats = metrics.get(EXECUTE_METERED_TIME_LABEL);
260            let execute_e1_stats = metrics.get(EXECUTE_E1_TIME_LABEL);
261            if stats.is_none() {
262                continue;
263            }
264            let stats = stats.unwrap_or_else(|| {
265                panic!("Missing proof time statistics for group '{group_name}'")
266            });
267            let mut sum = stats.sum;
268            let mut max = stats.max;
269            // convert ms to s
270            sum.val /= 1000.0;
271            max.val /= 1000.0;
272            if let Some(diff) = &mut sum.diff {
273                *diff /= 1000.0;
274            }
275            if let Some(diff) = &mut max.diff {
276                *diff /= 1000.0;
277            }
278            if !group_name.contains("keygen") {
279                // Proving time in keygen group is dummy and not part of total.
280                total_proof_time.val += sum.val;
281                *total_proof_time
282                    .diff
283                    .as_mut()
284                    .expect("total_proof_time.diff should be initialized") +=
285                    sum.diff.unwrap_or(0.0);
286                total_par_proof_time.val += max.val;
287                *total_par_proof_time
288                    .diff
289                    .as_mut()
290                    .expect("total_par_proof_time.diff should be initialized") +=
291                    max.diff.unwrap_or(0.0);
292
293                // Account for the serial execute_metered and execute_e1 for app outside of segments
294                if is_app_proof_group(group_name) {
295                    if let Some(execute_metered_stats) = execute_metered_stats {
296                        // For metered metrics without segment labels, we just use the value
297                        // directly Count is 1, so avg = sum = max = min =
298                        // value
299                        total_proof_time.val += execute_metered_stats.avg.val / 1000.0;
300                        total_par_proof_time.val += execute_metered_stats.avg.val / 1000.0;
301                        if let Some(diff) = execute_metered_stats.avg.diff {
302                            *total_proof_time
303                                .diff
304                                .as_mut()
305                                .expect("total_proof_time.diff should be initialized") +=
306                                diff / 1000.0;
307                            *total_par_proof_time
308                                .diff
309                                .as_mut()
310                                .expect("total_par_proof_time.diff should be initialized") +=
311                                diff / 1000.0;
312                        }
313                    }
314
315                    if let Some(execute_e1_stats) = execute_e1_stats {
316                        total_proof_time.val += execute_e1_stats.avg.val / 1000.0;
317                        total_par_proof_time.val += execute_e1_stats.avg.val / 1000.0;
318                        if let Some(diff) = execute_e1_stats.avg.diff {
319                            *total_proof_time
320                                .diff
321                                .as_mut()
322                                .expect("total_proof_time.diff should be initialized") +=
323                                diff / 1000.0;
324                            *total_par_proof_time
325                                .diff
326                                .as_mut()
327                                .expect("total_par_proof_time.diff should be initialized") +=
328                                diff / 1000.0;
329                        }
330                    }
331                }
332            }
333        }
334        self.total_proof_time = total_proof_time;
335        self.total_par_proof_time = total_par_proof_time;
336    }
337
338    pub fn set_diff(&mut self, prev: &Self) {
339        for (group_name, metrics) in self.by_group.iter_mut() {
340            if let Some(prev_metrics) = prev.by_group.get(group_name) {
341                for (metric_name, stats) in metrics.iter_mut() {
342                    if let Some(prev_stats) = prev_metrics.get(metric_name) {
343                        stats.set_diff(prev_stats);
344                    }
345                }
346            }
347        }
348        self.compute_total();
349    }
350
351    pub fn to_vec(&self) -> Vec<(String, HashMap<MetricName, Stats>)> {
352        let mut group_names: Vec<_> = self.by_group.keys().collect();
353        group_names.sort_by(|a, b| {
354            let a_wt = group_weight(a);
355            let b_wt = group_weight(b);
356            if a_wt == b_wt {
357                a.cmp(b)
358            } else {
359                a_wt.cmp(&b_wt)
360            }
361        });
362        group_names
363            .into_iter()
364            .map(|group_name| {
365                let key = group_name.clone();
366                let value = self
367                    .by_group
368                    .get(group_name)
369                    .unwrap_or_else(|| panic!("Group '{group_name}' should exist in by_group map"))
370                    .clone();
371                (key, value)
372            })
373            .collect()
374    }
375
376    pub fn to_bencher_metrics(&self) -> BencherAggregateMetrics {
377        let by_group = self
378            .by_group
379            .iter()
380            .map(|(group_name, metrics)| {
381                let metrics = metrics
382                    .iter()
383                    .filter(|(_, stats)| stats.avg.val.is_finite() && stats.sum.val.is_finite())
384                    .flat_map(|(metric_name, stats)| {
385                        [
386                            (format!("{metric_name}::sum"), stats.sum.into()),
387                            (
388                                metric_name.clone(),
389                                BencherValue {
390                                    value: stats.avg.val,
391                                    lower_value: Some(stats.min.val),
392                                    upper_value: Some(stats.max.val),
393                                },
394                            ),
395                        ]
396                    })
397                    .collect();
398                (group_name.clone(), metrics)
399            })
400            .collect();
401        let total_proof_time = self.total_proof_time.into();
402        let total_par_proof_time = self.total_par_proof_time.into();
403        BencherAggregateMetrics {
404            by_group,
405            total_proof_time,
406            total_par_proof_time,
407        }
408    }
409
410    pub fn write_markdown(
411        &self,
412        writer: &mut impl Write,
413        metric_names: &[&str],
414        num_parallel: usize,
415    ) -> Result<()> {
416        self.write_summary_markdown(writer, num_parallel)?;
417        writeln!(writer)?;
418
419        let metric_names = metric_names.to_vec();
420        for (group_name, summaries) in self.to_vec() {
421            writeln!(writer, "| {group_name} |||||")?;
422            writeln!(writer, "|:---|---:|---:|---:|---:|")?;
423            writeln!(writer, "|metric|avg|sum|max|min|")?;
424            let names = if metric_names.is_empty() {
425                summaries.keys().map(|s| s.as_str()).collect()
426            } else {
427                metric_names.clone()
428            };
429            for metric_name in names {
430                let summary = summaries.get(metric_name);
431                if let Some(summary) = summary {
432                    // Special handling for execute_metered metrics (not aggregated across segments
433                    // in the app proof case)
434                    if metric_name == EXECUTE_METERED_TIME_LABEL
435                        && group_name != "leaf"
436                        && group_name != "root"
437                        && group_name != "halo2_outer"
438                        && group_name != "halo2_wrapper"
439                        && !group_name.starts_with("internal")
440                    {
441                        writeln!(
442                            writer,
443                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
444                            metric_name, summary.avg, "-", "-", "-",
445                        )?;
446                    } else if metric_name == EXECUTE_E1_INSN_MI_S_LABEL
447                        || metric_name == EXECUTE_PREFLIGHT_INSN_MI_S_LABEL
448                        || metric_name == EXECUTE_METERED_INSN_MI_S_LABEL
449                    {
450                        // skip sum because it is misleading
451                        writeln!(
452                            writer,
453                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
454                            metric_name, summary.avg, "-", summary.max, summary.min,
455                        )?;
456                    } else {
457                        writeln!(
458                            writer,
459                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
460                            metric_name, summary.avg, summary.sum, summary.max, summary.min,
461                        )?;
462                    }
463                }
464            }
465            writeln!(writer)?;
466        }
467        writeln!(writer)?;
468
469        Ok(())
470    }
471
472    fn write_summary_markdown(&self, writer: &mut impl Write, num_parallel: usize) -> Result<()> {
473        writeln!(
474            writer,
475            "| Summary | Proof Time (s) | Parallel Proof Time (s) | Parallel Proof Time ({} provers) (s) |",
476            num_parallel
477        )?;
478        writeln!(writer, "|:---|---:|---:|---:|")?;
479        let mut rows = Vec::new();
480        for (group_name, summaries) in self.to_vec() {
481            if group_name.contains("keygen") {
482                continue;
483            }
484            let stats = summaries.get(PROOF_TIME_LABEL);
485            if stats.is_none() {
486                continue;
487            }
488            let stats = stats.unwrap_or_else(|| {
489                panic!("Missing proof time statistics for group '{group_name}'")
490            });
491            let mut sum = stats.sum;
492            let mut max = stats.max;
493            // convert ms to s
494            sum.val /= 1000.0;
495            max.val /= 1000.0;
496            if let Some(diff) = &mut sum.diff {
497                *diff /= 1000.0;
498            }
499            if let Some(diff) = &mut max.diff {
500                *diff /= 1000.0;
501            }
502            // Add serial execution time for app_proof groups
503            if is_app_proof_group(&group_name) {
504                if let Some(metered) = summaries.get(EXECUTE_METERED_TIME_LABEL) {
505                    sum.val += metered.avg.val / 1000.0;
506                    max.val += metered.avg.val / 1000.0;
507                }
508                if let Some(e1) = summaries.get(EXECUTE_E1_TIME_LABEL) {
509                    sum.val += e1.avg.val / 1000.0;
510                    max.val += e1.avg.val / 1000.0;
511                }
512            }
513            rows.push((group_name, sum, max));
514        }
515        let total_bounded: f64 = self.bounded_par_by_group.values().map(|v| v.val).sum();
516        writeln!(
517            writer,
518            "| Total | {} | {} | {:.2} |",
519            self.total_proof_time, self.total_par_proof_time, total_bounded
520        )?;
521        for (group_name, proof_time, par_proof_time) in rows {
522            let bounded = self
523                .bounded_par_by_group
524                .get(&group_name)
525                .map(|v| v.to_string())
526                .unwrap_or_else(|| "-".to_string());
527            writeln!(
528                writer,
529                "| {group_name} | {proof_time} | {par_proof_time} | {bounded} |"
530            )?;
531        }
532        writeln!(writer)?;
533        Ok(())
534    }
535
536    pub fn name(&self) -> String {
537        // A hacky way to determine the app name
538        self.by_group
539            .keys()
540            .find(|k| group_weight(k) == 0)
541            .unwrap_or_else(|| {
542                self.by_group
543                    .keys()
544                    .next()
545                    .expect("by_group should contain at least one group")
546            })
547            .clone()
548    }
549}
550
551impl BenchmarkOutput {
552    pub fn insert(&mut self, name: &str, metrics: BencherAggregateMetrics) {
553        for (group_name, metrics) in metrics.by_group {
554            self.by_name
555                .entry(format!("{name}::{group_name}"))
556                .or_default()
557                .extend(metrics);
558        }
559        if let Some(e) = self.by_name.insert(
560            name.to_owned(),
561            HashMap::from_iter([
562                ("total_proof_time".to_owned(), metrics.total_proof_time),
563                (
564                    "total_par_proof_time".to_owned(),
565                    metrics.total_par_proof_time,
566                ),
567            ]),
568        ) {
569            panic!("Duplicate metric: {e:?}");
570        }
571    }
572}
573
// Metric names used by the aggregation and the markdown / Bencher output.

/// Per-proof total proving time, in milliseconds.
pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms";
pub const MAIN_CELLS_USED_LABEL: &str = "main_cells_used";
pub const TOTAL_CELLS_USED_LABEL: &str = "total_cells_used";
/// Instruction count from E1 execution; must agree with the metered and preflight counts.
pub const EXECUTE_E1_INSNS_LABEL: &str = "execute_e1_insns";
/// Instruction count from metered execution; must agree with the E1 and preflight counts.
pub const EXECUTE_METERED_INSNS_LABEL: &str = "execute_metered_insns";
/// Instruction count from preflight execution; must agree with the E1 and metered counts.
pub const EXECUTE_PREFLIGHT_INSNS_LABEL: &str = "execute_preflight_insns";
/// Serial E1 execution time in ms; added to app-proof group times.
pub const EXECUTE_E1_TIME_LABEL: &str = "execute_e1_time_ms";
/// Execution throughput (presumably million instructions per second); its sum is not shown.
pub const EXECUTE_E1_INSN_MI_S_LABEL: &str = "execute_e1_insn_mi/s";
/// Serial metered execution time in ms; added to app-proof group times.
pub const EXECUTE_METERED_TIME_LABEL: &str = "execute_metered_time_ms";
/// Execution throughput (presumably million instructions per second); its sum is not shown.
pub const EXECUTE_METERED_INSN_MI_S_LABEL: &str = "execute_metered_insn_mi/s";
pub const EXECUTE_PREFLIGHT_TIME_LABEL: &str = "execute_preflight_time_ms";
/// Execution throughput (presumably million instructions per second); its sum is not shown.
pub const EXECUTE_PREFLIGHT_INSN_MI_S_LABEL: &str = "execute_preflight_insn_mi/s";
pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms";
pub const MEM_FIN_TIME_LABEL: &str = "memory_finalize_time_ms";
pub const BOUNDARY_FIN_TIME_LABEL: &str = "boundary_finalize_time_ms";
pub const MERKLE_FIN_TIME_LABEL: &str = "merkle_finalize_time_ms";
pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms";

/// Ordered list of VM metric names for the per-group tables.
///
/// NOTE(review): presumably passed as `metric_names` to `AggregateMetrics::write_markdown`,
/// where row order follows this list — confirm at the call site.
pub const VM_METRIC_NAMES: &[&str] = &[
    PROOF_TIME_LABEL,
    MAIN_CELLS_USED_LABEL,
    TOTAL_CELLS_USED_LABEL,
    EXECUTE_E1_TIME_LABEL,
    EXECUTE_E1_INSN_MI_S_LABEL,
    EXECUTE_METERED_TIME_LABEL,
    EXECUTE_METERED_INSN_MI_S_LABEL,
    EXECUTE_PREFLIGHT_INSNS_LABEL,
    EXECUTE_PREFLIGHT_TIME_LABEL,
    EXECUTE_PREFLIGHT_INSN_MI_S_LABEL,
    TRACE_GEN_TIME_LABEL,
    MEM_FIN_TIME_LABEL,
    BOUNDARY_FIN_TIME_LABEL,
    MERKLE_FIN_TIME_LABEL,
    PROVE_EXCL_TRACE_TIME_LABEL,
    "main_trace_commit_time_ms",
    "generate_perm_trace_time_ms",
    "perm_trace_commit_time_ms",
    "quotient_poly_compute_time_ms",
    "quotient_poly_commit_time_ms",
    "pcs_opening_time_ms",
];