openvm_prof/
aggregate.rs

1use std::{collections::HashMap, io::Write};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5
6use crate::types::{BencherValue, BenchmarkOutput, Labels, MdTableCell, MetricDb};
7
/// Name of a metric, used as a map key throughout this module.
type MetricName = String;
/// Metric name => every recorded `(value, labels)` pair for that metric.
type MetricsByName = HashMap<MetricName, Vec<(f64, Labels)>>;
10
/// Raw metric values partitioned by the value of a "group" label.
///
/// Built by `GroupedMetrics::new` and reduced to per-metric statistics by
/// `GroupedMetrics::aggregate`.
#[derive(Clone, Debug, Default)]
pub struct GroupedMetrics {
    /// "group" label => metrics with that "group" label, further grouped by metric name
    pub by_group: HashMap<String, MetricsByName>,
    /// Metrics whose label set does not contain the "group" label, grouped by metric name.
    pub ungrouped: MetricsByName,
}
17
/// Per-group aggregate statistics plus proof-time totals across groups.
///
/// The totals are (re)computed by `compute_total` from each group's
/// `PROOF_TIME_LABEL` statistics; groups whose name contains "keygen" are
/// excluded from the totals.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AggregateMetrics {
    /// "group" label => metric aggregate statistics
    // Flattened by serde: each group name becomes a key at the same level as
    // the two total fields below.
    #[serde(flatten)]
    pub by_group: HashMap<String, HashMap<MetricName, Stats>>,
    /// In seconds
    ///
    /// Sum over the included groups of each group's proof time `sum`.
    pub total_proof_time: MdTableCell,
    /// In seconds
    ///
    /// Sum over the included groups of each group's proof time `max`.
    pub total_par_proof_time: MdTableCell,
}
28
/// The same aggregate data as `AggregateMetrics`, expressed as `BencherValue`s.
///
/// Produced by `AggregateMetrics::to_bencher_metrics`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BencherAggregateMetrics {
    /// "group" label => measure key => value.
    ///
    /// `to_bencher_metrics` emits two keys per metric: `"<metric>::sum"` and
    /// `"<metric>"` (average, with min/max as lower/upper values).
    // Flattened by serde: each group name becomes a key at the same level as
    // the two total fields below.
    #[serde(flatten)]
    pub by_group: HashMap<String, HashMap<String, BencherValue>>,
    /// In seconds
    pub total_proof_time: BencherValue,
    /// In seconds
    pub total_par_proof_time: BencherValue,
}
38
/// Running summary statistics for one metric: fed by `push`, completed by `finalize`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Stats {
    /// Sum of all pushed values.
    pub sum: MdTableCell,
    /// Largest pushed value.
    pub max: MdTableCell,
    /// Smallest pushed value.
    pub min: MdTableCell,
    /// `sum / count`; only set once `finalize` has been called.
    pub avg: MdTableCell,
    /// Number of pushed values. Skipped by serde, so it is neither written
    /// out nor restored when a `Stats` is deserialized.
    #[serde(skip)]
    pub count: usize,
}
48
49impl Default for Stats {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl Stats {
56    pub fn new() -> Self {
57        Self {
58            sum: MdTableCell::default(),
59            max: MdTableCell::default(),
60            min: MdTableCell::new(f64::MAX, None),
61            avg: MdTableCell::default(),
62            count: 0,
63        }
64    }
65    pub fn push(&mut self, value: f64) {
66        self.sum.val += value;
67        self.count += 1;
68        if value > self.max.val {
69            self.max.val = value;
70        }
71        if value < self.min.val {
72            self.min.val = value;
73        }
74    }
75
76    pub fn finalize(&mut self) {
77        assert!(self.count != 0);
78        self.avg.val = self.sum.val / self.count as f64;
79    }
80
81    pub fn set_diff(&mut self, prev: &Self) {
82        self.sum.diff = Some(self.sum.val - prev.sum.val);
83        self.max.diff = Some(self.max.val - prev.max.val);
84        self.min.diff = Some(self.min.val - prev.min.val);
85        self.avg.diff = Some(self.avg.val - prev.avg.val);
86    }
87}
88
89impl GroupedMetrics {
90    pub fn new(db: &MetricDb, group_label_name: &str) -> Result<Self> {
91        let mut by_group = HashMap::<String, MetricsByName>::new();
92        let mut ungrouped = MetricsByName::new();
93        for (labels, metrics) in db.flat_dict.iter() {
94            let group_name = labels.get(group_label_name);
95            if let Some(group_name) = group_name {
96                let group_entry = by_group.entry(group_name.to_string()).or_default();
97                let mut labels = labels.clone();
98                labels.remove(group_label_name);
99                for metric in metrics {
100                    group_entry
101                        .entry(metric.name.clone())
102                        .or_default()
103                        .push((metric.value, labels.clone()));
104                }
105            } else {
106                for metric in metrics {
107                    ungrouped
108                        .entry(metric.name.clone())
109                        .or_default()
110                        .push((metric.value, labels.clone()));
111                }
112            }
113        }
114        Ok(Self {
115            by_group,
116            ungrouped,
117        })
118    }
119
120    pub fn aggregate(&self) -> AggregateMetrics {
121        let by_group: HashMap<String, _> = self
122            .by_group
123            .iter()
124            .map(|(group_name, metrics)| {
125                let group_summaries: HashMap<MetricName, Stats> = metrics
126                    .iter()
127                    .map(|(metric_name, metrics)| {
128                        let mut summary = Stats::new();
129                        for (value, _) in metrics {
130                            summary.push(*value);
131                        }
132                        summary.finalize();
133                        (metric_name.clone(), summary)
134                    })
135                    .collect();
136                (group_name.clone(), group_summaries)
137            })
138            .collect();
139        let mut metrics = AggregateMetrics {
140            by_group,
141            ..Default::default()
142        };
143        metrics.compute_total();
144        metrics
145    }
146}
147
// A hacky way to order the groups for display.
//
// Returns a sort weight for a group name:
// - any name containing "keygen" gets the largest weight (prefix count + 1),
// - a name starting with one of the known stage prefixes gets that prefix's
//   1-based position in the list,
// - every other name gets 0.
pub(crate) fn group_weight(name: &str) -> usize {
    const STAGE_PREFIXES: [&str; 5] = ["leaf", "internal", "root", "halo2_outer", "halo2_wrapper"];

    if name.contains("keygen") {
        return STAGE_PREFIXES.len() + 1;
    }
    // Search from the back of the list; `rposition` still reports the index
    // counted from the front.
    STAGE_PREFIXES
        .iter()
        .rposition(|prefix| name.starts_with(*prefix))
        .map_or(0, |idx| idx + 1)
}
161
162impl AggregateMetrics {
163    pub fn compute_total(&mut self) {
164        let mut total_proof_time = MdTableCell::new(0.0, Some(0.0));
165        let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0));
166        for (group_name, metrics) in &self.by_group {
167            let stats = metrics.get(PROOF_TIME_LABEL);
168            if stats.is_none() {
169                continue;
170            }
171            let stats = stats.unwrap();
172            let mut sum = stats.sum;
173            let mut max = stats.max;
174            // convert ms to s
175            sum.val /= 1000.0;
176            max.val /= 1000.0;
177            if let Some(diff) = &mut sum.diff {
178                *diff /= 1000.0;
179            }
180            if let Some(diff) = &mut max.diff {
181                *diff /= 1000.0;
182            }
183            if !group_name.contains("keygen") {
184                // Proving time in keygen group is dummy and not part of total.
185                total_proof_time.val += sum.val;
186                *total_proof_time.diff.as_mut().unwrap() += sum.diff.unwrap_or(0.0);
187                total_par_proof_time.val += max.val;
188                *total_par_proof_time.diff.as_mut().unwrap() += max.diff.unwrap_or(0.0);
189            }
190        }
191        self.total_proof_time = total_proof_time;
192        self.total_par_proof_time = total_par_proof_time;
193    }
194
195    pub fn set_diff(&mut self, prev: &Self) {
196        for (group_name, metrics) in self.by_group.iter_mut() {
197            if let Some(prev_metrics) = prev.by_group.get(group_name) {
198                for (metric_name, stats) in metrics.iter_mut() {
199                    if let Some(prev_stats) = prev_metrics.get(metric_name) {
200                        stats.set_diff(prev_stats);
201                    }
202                }
203            }
204        }
205        self.compute_total();
206    }
207
208    pub fn to_vec(&self) -> Vec<(String, HashMap<MetricName, Stats>)> {
209        let mut group_names: Vec<_> = self.by_group.keys().collect();
210        group_names.sort_by(|a, b| {
211            let a_wt = group_weight(a);
212            let b_wt = group_weight(b);
213            if a_wt == b_wt {
214                a.cmp(b)
215            } else {
216                a_wt.cmp(&b_wt)
217            }
218        });
219        group_names
220            .into_iter()
221            .map(|group_name| {
222                let key = group_name.clone();
223                let value = self.by_group.get(group_name).unwrap().clone();
224                (key, value)
225            })
226            .collect()
227    }
228
229    pub fn to_bencher_metrics(&self) -> BencherAggregateMetrics {
230        let by_group = self
231            .by_group
232            .iter()
233            .map(|(group_name, metrics)| {
234                let metrics = metrics
235                    .iter()
236                    .flat_map(|(metric_name, stats)| {
237                        [
238                            (format!("{metric_name}::sum"), stats.sum.into()),
239                            (
240                                metric_name.clone(),
241                                BencherValue {
242                                    value: stats.avg.val,
243                                    lower_value: Some(stats.min.val),
244                                    upper_value: Some(stats.max.val),
245                                },
246                            ),
247                        ]
248                    })
249                    .collect();
250                (group_name.clone(), metrics)
251            })
252            .collect();
253        let total_proof_time = self.total_proof_time.into();
254        let total_par_proof_time = self.total_par_proof_time.into();
255        BencherAggregateMetrics {
256            by_group,
257            total_proof_time,
258            total_par_proof_time,
259        }
260    }
261
262    pub fn write_markdown(&self, writer: &mut impl Write, metric_names: &[&str]) -> Result<()> {
263        self.write_summary_markdown(writer)?;
264        writeln!(writer)?;
265
266        let metric_names = metric_names.to_vec();
267        for (group_name, summaries) in self.to_vec() {
268            writeln!(writer, "| {} |||||", group_name)?;
269            writeln!(writer, "|:---|---:|---:|---:|---:|")?;
270            writeln!(writer, "|metric|avg|sum|max|min|")?;
271            let names = if metric_names.is_empty() {
272                summaries.keys().map(|s| s.as_str()).collect()
273            } else {
274                metric_names.clone()
275            };
276            for metric_name in names {
277                let summary = summaries.get(metric_name);
278                if let Some(summary) = summary {
279                    writeln!(
280                        writer,
281                        "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
282                        metric_name, summary.avg, summary.sum, summary.max, summary.min,
283                    )?;
284                }
285            }
286            writeln!(writer)?;
287        }
288        writeln!(writer)?;
289
290        Ok(())
291    }
292
293    fn write_summary_markdown(&self, writer: &mut impl Write) -> Result<()> {
294        writeln!(
295            writer,
296            "| Summary | Proof Time (s) | Parallel Proof Time (s) |"
297        )?;
298        writeln!(writer, "|:---|---:|---:|")?;
299        let mut rows = Vec::new();
300        for (group_name, summaries) in self.to_vec() {
301            let stats = summaries.get(PROOF_TIME_LABEL);
302            if stats.is_none() {
303                continue;
304            }
305            let stats = stats.unwrap();
306            let mut sum = stats.sum;
307            let mut max = stats.max;
308            // convert ms to s
309            sum.val /= 1000.0;
310            max.val /= 1000.0;
311            if let Some(diff) = &mut sum.diff {
312                *diff /= 1000.0;
313            }
314            if let Some(diff) = &mut max.diff {
315                *diff /= 1000.0;
316            }
317            rows.push((group_name, sum, max));
318        }
319        writeln!(
320            writer,
321            "| Total | {} | {} |",
322            self.total_proof_time, self.total_par_proof_time
323        )?;
324        for (group_name, proof_time, par_proof_time) in rows {
325            writeln!(writer, "| {group_name} | {proof_time} | {par_proof_time} |")?;
326        }
327        writeln!(writer)?;
328        Ok(())
329    }
330
331    pub fn name(&self) -> String {
332        // A hacky way to determine the app name
333        self.by_group
334            .keys()
335            .find(|k| group_weight(k) == 0)
336            .unwrap_or_else(|| self.by_group.keys().next().unwrap())
337            .clone()
338    }
339}
340
341impl BenchmarkOutput {
342    pub fn insert(&mut self, name: &str, metrics: BencherAggregateMetrics) {
343        for (group_name, metrics) in metrics.by_group {
344            self.by_name
345                .entry(format!("{name}::{group_name}"))
346                .or_default()
347                .extend(metrics);
348        }
349        if let Some(e) = self.by_name.insert(
350            name.to_owned(),
351            HashMap::from_iter([
352                ("total_proof_time".to_owned(), metrics.total_proof_time),
353                (
354                    "total_par_proof_time".to_owned(),
355                    metrics.total_par_proof_time,
356                ),
357            ]),
358        ) {
359            panic!("Duplicate metric: {e:?}");
360        }
361    }
362}
363
// Metric names used as lookup keys into the per-group metric maps.

/// Metric whose per-group `sum` and `max` feed the proof-time totals and the
/// summary table; its values are treated as milliseconds there.
pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms";
/// Metric name `main_cells_used`.
pub const CELLS_USED_LABEL: &str = "main_cells_used";
/// Metric name `total_cycles`.
pub const CYCLES_LABEL: &str = "total_cycles";
/// Metric name `execute_time_ms` (presumably milliseconds, per the suffix).
pub const EXECUTE_TIME_LABEL: &str = "execute_time_ms";
/// Metric name `trace_gen_time_ms` (presumably milliseconds, per the suffix).
pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms";
/// Metric name `stark_prove_excluding_trace_time_ms` (presumably milliseconds, per the suffix).
pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms";
370
/// Ordered list of VM metric names.
///
/// NOTE(review): presumably passed as `metric_names` to
/// `AggregateMetrics::write_markdown`, where the slice order determines the
/// row order — confirm against callers.
pub const VM_METRIC_NAMES: &[&str] = &[
    PROOF_TIME_LABEL,
    CELLS_USED_LABEL,
    CYCLES_LABEL,
    EXECUTE_TIME_LABEL,
    TRACE_GEN_TIME_LABEL,
    PROVE_EXCL_TRACE_TIME_LABEL,
    "main_trace_commit_time_ms",
    "generate_perm_trace_time_ms",
    "perm_trace_commit_time_ms",
    "quotient_poly_compute_time_ms",
    "quotient_poly_commit_time_ms",
    "pcs_opening_time_ms",
];