1use std::{collections::HashMap, io::Write};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5
6use crate::types::{BencherValue, BenchmarkOutput, Labels, MdTableCell, MetricDb};
7
8type MetricName = String;
9type MetricsByName = HashMap<MetricName, Vec<(f64, Labels)>>;
10
11#[derive(Clone, Debug, Default)]
12pub struct GroupedMetrics {
13 pub by_group: HashMap<String, MetricsByName>,
15 pub ungrouped: MetricsByName,
16}
17
18#[derive(Clone, Debug, Default, Serialize, Deserialize)]
19pub struct AggregateMetrics {
20 #[serde(flatten)]
22 pub by_group: HashMap<String, HashMap<MetricName, Stats>>,
23 pub total_proof_time: MdTableCell,
25 pub total_par_proof_time: MdTableCell,
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct BencherAggregateMetrics {
31 #[serde(flatten)]
32 pub by_group: HashMap<String, HashMap<String, BencherValue>>,
33 pub total_proof_time: BencherValue,
35 pub total_par_proof_time: BencherValue,
37}
38
39#[derive(Clone, Debug, Serialize, Deserialize)]
40pub struct Stats {
41 pub sum: MdTableCell,
42 pub max: MdTableCell,
43 pub min: MdTableCell,
44 pub avg: MdTableCell,
45 #[serde(skip)]
46 pub count: usize,
47}
48
49impl Default for Stats {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl Stats {
56 pub fn new() -> Self {
57 Self {
58 sum: MdTableCell::default(),
59 max: MdTableCell::default(),
60 min: MdTableCell::new(f64::MAX, None),
61 avg: MdTableCell::default(),
62 count: 0,
63 }
64 }
65 pub fn push(&mut self, value: f64) {
66 self.sum.val += value;
67 self.count += 1;
68 if value > self.max.val {
69 self.max.val = value;
70 }
71 if value < self.min.val {
72 self.min.val = value;
73 }
74 }
75
76 pub fn finalize(&mut self) {
77 assert!(self.count != 0);
78 self.avg.val = self.sum.val / self.count as f64;
79 }
80
81 pub fn set_diff(&mut self, prev: &Self) {
82 self.sum.diff = Some(self.sum.val - prev.sum.val);
83 self.max.diff = Some(self.max.val - prev.max.val);
84 self.min.diff = Some(self.min.val - prev.min.val);
85 self.avg.diff = Some(self.avg.val - prev.avg.val);
86 }
87}
88
89impl GroupedMetrics {
90 pub fn new(db: &MetricDb, group_label_name: &str) -> Result<Self> {
91 let mut by_group = HashMap::<String, MetricsByName>::new();
92 let mut ungrouped = MetricsByName::new();
93 for (labels, metrics) in db.flat_dict.iter() {
94 let group_name = labels.get(group_label_name);
95 if let Some(group_name) = group_name {
96 let group_entry = by_group.entry(group_name.to_string()).or_default();
97 let mut labels = labels.clone();
98 labels.remove(group_label_name);
99 for metric in metrics {
100 group_entry
101 .entry(metric.name.clone())
102 .or_default()
103 .push((metric.value, labels.clone()));
104 }
105 } else {
106 for metric in metrics {
107 ungrouped
108 .entry(metric.name.clone())
109 .or_default()
110 .push((metric.value, labels.clone()));
111 }
112 }
113 }
114 Ok(Self {
115 by_group,
116 ungrouped,
117 })
118 }
119
120 fn validate_instruction_counts(group_summaries: &HashMap<MetricName, Stats>) {
122 let e1_insns = group_summaries.get(EXECUTE_E1_INSNS_LABEL);
123 let metered_insns = group_summaries.get(EXECUTE_METERED_INSNS_LABEL);
124 let preflight_insns = group_summaries.get(EXECUTE_PREFLIGHT_INSNS_LABEL);
125
126 if let (Some(e1_insns), Some(preflight_insns)) = (e1_insns, preflight_insns) {
127 assert_eq!(e1_insns.sum.val as u64, preflight_insns.sum.val as u64);
128 }
129 if let (Some(e1_insns), Some(metered_insns)) = (e1_insns, metered_insns) {
130 assert_eq!(e1_insns.sum.val as u64, metered_insns.sum.val as u64);
131 }
132 if let (Some(metered_insns), Some(preflight_insns)) = (metered_insns, preflight_insns) {
133 assert_eq!(metered_insns.sum.val as u64, preflight_insns.sum.val as u64);
134 }
135 }
136
137 pub fn aggregate(&self) -> AggregateMetrics {
138 let by_group: HashMap<String, _> = self
139 .by_group
140 .iter()
141 .map(|(group_name, metrics)| {
142 let group_summaries: HashMap<MetricName, Stats> = metrics
143 .iter()
144 .map(|(metric_name, metrics)| {
145 let mut summary = Stats::new();
146 for (value, _) in metrics {
147 summary.push(*value);
148 }
149 summary.finalize();
150 (metric_name.clone(), summary)
151 })
152 .collect();
153
154 if !group_name.contains("keygen") {
155 Self::validate_instruction_counts(&group_summaries);
156 }
157
158 (group_name.clone(), group_summaries)
159 })
160 .collect();
161 let mut metrics = AggregateMetrics {
162 by_group,
163 ..Default::default()
164 };
165 metrics.compute_total();
166 metrics
167 }
168}
169
/// Sort weight used to order groups in reports (ascending).
///
/// - Any name containing `keygen` gets 6, so it sorts last.
/// - Otherwise, a name starting with `leaf`, `internal`, `root`, `halo2_outer` or
///   `halo2_wrapper` gets 1 through 5 respectively.
/// - Every other name gets 0, so it sorts first.
pub(crate) fn group_weight(name: &str) -> usize {
    const STAGE_PREFIXES: [&str; 5] = ["leaf", "internal", "root", "halo2_outer", "halo2_wrapper"];

    if name.contains("keygen") {
        STAGE_PREFIXES.len() + 1
    } else {
        STAGE_PREFIXES
            .iter()
            .rposition(|prefix| name.starts_with(prefix))
            .map_or(0, |idx| idx + 1)
    }
}
183
184impl AggregateMetrics {
185 pub fn compute_total(&mut self) {
186 let mut total_proof_time = MdTableCell::new(0.0, Some(0.0));
187 let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0));
188 for (group_name, metrics) in &self.by_group {
189 let stats = metrics.get(PROOF_TIME_LABEL);
190 let execute_metered_stats = metrics.get(EXECUTE_METERED_TIME_LABEL);
191 let execute_e1_stats = metrics.get(EXECUTE_E1_TIME_LABEL);
192 if stats.is_none() {
193 continue;
194 }
195 let stats = stats.unwrap_or_else(|| {
196 panic!("Missing proof time statistics for group '{}'", group_name)
197 });
198 let mut sum = stats.sum;
199 let mut max = stats.max;
200 sum.val /= 1000.0;
202 max.val /= 1000.0;
203 if let Some(diff) = &mut sum.diff {
204 *diff /= 1000.0;
205 }
206 if let Some(diff) = &mut max.diff {
207 *diff /= 1000.0;
208 }
209 if !group_name.contains("keygen") {
210 total_proof_time.val += sum.val;
212 *total_proof_time
213 .diff
214 .as_mut()
215 .expect("total_proof_time.diff should be initialized") +=
216 sum.diff.unwrap_or(0.0);
217 total_par_proof_time.val += max.val;
218 *total_par_proof_time
219 .diff
220 .as_mut()
221 .expect("total_par_proof_time.diff should be initialized") +=
222 max.diff.unwrap_or(0.0);
223
224 if group_name != "leaf"
226 && group_name != "root"
227 && group_name != "halo2_outer"
228 && group_name != "halo2_wrapper"
229 && !group_name.starts_with("internal")
230 {
231 if let Some(execute_metered_stats) = execute_metered_stats {
232 total_proof_time.val += execute_metered_stats.avg.val / 1000.0;
236 total_par_proof_time.val += execute_metered_stats.avg.val / 1000.0;
237 if let Some(diff) = execute_metered_stats.avg.diff {
238 *total_proof_time
239 .diff
240 .as_mut()
241 .expect("total_proof_time.diff should be initialized") +=
242 diff / 1000.0;
243 *total_par_proof_time
244 .diff
245 .as_mut()
246 .expect("total_par_proof_time.diff should be initialized") +=
247 diff / 1000.0;
248 }
249 }
250
251 if let Some(execute_e1_stats) = execute_e1_stats {
252 total_proof_time.val += execute_e1_stats.avg.val / 1000.0;
253 total_par_proof_time.val += execute_e1_stats.avg.val / 1000.0;
254 if let Some(diff) = execute_e1_stats.avg.diff {
255 *total_proof_time
256 .diff
257 .as_mut()
258 .expect("total_proof_time.diff should be initialized") +=
259 diff / 1000.0;
260 *total_par_proof_time
261 .diff
262 .as_mut()
263 .expect("total_par_proof_time.diff should be initialized") +=
264 diff / 1000.0;
265 }
266 }
267 }
268 }
269 }
270 self.total_proof_time = total_proof_time;
271 self.total_par_proof_time = total_par_proof_time;
272 }
273
274 pub fn set_diff(&mut self, prev: &Self) {
275 for (group_name, metrics) in self.by_group.iter_mut() {
276 if let Some(prev_metrics) = prev.by_group.get(group_name) {
277 for (metric_name, stats) in metrics.iter_mut() {
278 if let Some(prev_stats) = prev_metrics.get(metric_name) {
279 stats.set_diff(prev_stats);
280 }
281 }
282 }
283 }
284 self.compute_total();
285 }
286
287 pub fn to_vec(&self) -> Vec<(String, HashMap<MetricName, Stats>)> {
288 let mut group_names: Vec<_> = self.by_group.keys().collect();
289 group_names.sort_by(|a, b| {
290 let a_wt = group_weight(a);
291 let b_wt = group_weight(b);
292 if a_wt == b_wt {
293 a.cmp(b)
294 } else {
295 a_wt.cmp(&b_wt)
296 }
297 });
298 group_names
299 .into_iter()
300 .map(|group_name| {
301 let key = group_name.clone();
302 let value = self
303 .by_group
304 .get(group_name)
305 .unwrap_or_else(|| {
306 panic!("Group '{}' should exist in by_group map", group_name)
307 })
308 .clone();
309 (key, value)
310 })
311 .collect()
312 }
313
314 pub fn to_bencher_metrics(&self) -> BencherAggregateMetrics {
315 let by_group = self
316 .by_group
317 .iter()
318 .map(|(group_name, metrics)| {
319 let metrics = metrics
320 .iter()
321 .filter(|(_, stats)| stats.avg.val.is_finite() && stats.sum.val.is_finite())
322 .flat_map(|(metric_name, stats)| {
323 [
324 (format!("{metric_name}::sum"), stats.sum.into()),
325 (
326 metric_name.clone(),
327 BencherValue {
328 value: stats.avg.val,
329 lower_value: Some(stats.min.val),
330 upper_value: Some(stats.max.val),
331 },
332 ),
333 ]
334 })
335 .collect();
336 (group_name.clone(), metrics)
337 })
338 .collect();
339 let total_proof_time = self.total_proof_time.into();
340 let total_par_proof_time = self.total_par_proof_time.into();
341 BencherAggregateMetrics {
342 by_group,
343 total_proof_time,
344 total_par_proof_time,
345 }
346 }
347
348 pub fn write_markdown(&self, writer: &mut impl Write, metric_names: &[&str]) -> Result<()> {
349 self.write_summary_markdown(writer)?;
350 writeln!(writer)?;
351
352 let metric_names = metric_names.to_vec();
353 for (group_name, summaries) in self.to_vec() {
354 writeln!(writer, "| {} |||||", group_name)?;
355 writeln!(writer, "|:---|---:|---:|---:|---:|")?;
356 writeln!(writer, "|metric|avg|sum|max|min|")?;
357 let names = if metric_names.is_empty() {
358 summaries.keys().map(|s| s.as_str()).collect()
359 } else {
360 metric_names.clone()
361 };
362 for metric_name in names {
363 let summary = summaries.get(metric_name);
364 if let Some(summary) = summary {
365 if metric_name == EXECUTE_METERED_TIME_LABEL
368 && group_name != "leaf"
369 && group_name != "root"
370 && group_name != "halo2_outer"
371 && group_name != "halo2_wrapper"
372 && !group_name.starts_with("internal")
373 {
374 writeln!(
375 writer,
376 "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
377 metric_name, summary.avg, "-", "-", "-",
378 )?;
379 } else if metric_name == EXECUTE_E1_INSN_MI_S_LABEL
380 || metric_name == EXECUTE_PREFLIGHT_INSN_MI_S_LABEL
381 || metric_name == EXECUTE_METERED_INSN_MI_S_LABEL
382 {
383 writeln!(
385 writer,
386 "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
387 metric_name, summary.avg, "-", summary.max, summary.min,
388 )?;
389 } else {
390 writeln!(
391 writer,
392 "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
393 metric_name, summary.avg, summary.sum, summary.max, summary.min,
394 )?;
395 }
396 }
397 }
398 writeln!(writer)?;
399 }
400 writeln!(writer)?;
401
402 Ok(())
403 }
404
405 fn write_summary_markdown(&self, writer: &mut impl Write) -> Result<()> {
406 writeln!(
407 writer,
408 "| Summary | Proof Time (s) | Parallel Proof Time (s) |"
409 )?;
410 writeln!(writer, "|:---|---:|---:|")?;
411 let mut rows = Vec::new();
412 for (group_name, summaries) in self.to_vec() {
413 if group_name.contains("keygen") {
414 continue;
415 }
416 let stats = summaries.get(PROOF_TIME_LABEL);
417 if stats.is_none() {
418 continue;
419 }
420 let stats = stats.unwrap_or_else(|| {
421 panic!("Missing proof time statistics for group '{}'", group_name)
422 });
423 let mut sum = stats.sum;
424 let mut max = stats.max;
425 sum.val /= 1000.0;
427 max.val /= 1000.0;
428 if let Some(diff) = &mut sum.diff {
429 *diff /= 1000.0;
430 }
431 if let Some(diff) = &mut max.diff {
432 *diff /= 1000.0;
433 }
434 rows.push((group_name, sum, max));
435 }
436 writeln!(
437 writer,
438 "| Total | {} | {} |",
439 self.total_proof_time, self.total_par_proof_time
440 )?;
441 for (group_name, proof_time, par_proof_time) in rows {
442 writeln!(writer, "| {group_name} | {proof_time} | {par_proof_time} |")?;
443 }
444 writeln!(writer)?;
445 Ok(())
446 }
447
448 pub fn name(&self) -> String {
449 self.by_group
451 .keys()
452 .find(|k| group_weight(k) == 0)
453 .unwrap_or_else(|| {
454 self.by_group
455 .keys()
456 .next()
457 .expect("by_group should contain at least one group")
458 })
459 .clone()
460 }
461}
462
463impl BenchmarkOutput {
464 pub fn insert(&mut self, name: &str, metrics: BencherAggregateMetrics) {
465 for (group_name, metrics) in metrics.by_group {
466 self.by_name
467 .entry(format!("{name}::{group_name}"))
468 .or_default()
469 .extend(metrics);
470 }
471 if let Some(e) = self.by_name.insert(
472 name.to_owned(),
473 HashMap::from_iter([
474 ("total_proof_time".to_owned(), metrics.total_proof_time),
475 (
476 "total_par_proof_time".to_owned(),
477 metrics.total_par_proof_time,
478 ),
479 ]),
480 ) {
481 panic!("Duplicate metric: {e:?}");
482 }
483 }
484}
485
// ---- Whole-proof metrics -------------------------------------------------------------

/// Total proof time of one observation, in milliseconds (reported in seconds in summaries).
pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms";
pub const MAIN_CELLS_USED_LABEL: &str = "main_cells_used";
pub const TOTAL_CELLS_USED_LABEL: &str = "total_cells_used";

// ---- Instruction counts, one per executor (checked for equality during aggregation) --

pub const EXECUTE_E1_INSNS_LABEL: &str = "execute_e1_insns";
pub const EXECUTE_METERED_INSNS_LABEL: &str = "execute_metered_insns";
pub const EXECUTE_PREFLIGHT_INSNS_LABEL: &str = "execute_preflight_insns";

// ---- Execution timings (`*_ms`) and throughput (`*_insn_mi/s`, presumably millions of
// ---- instructions per second — confirm) ---------------------------------------------

pub const EXECUTE_E1_TIME_LABEL: &str = "execute_e1_time_ms";
pub const EXECUTE_E1_INSN_MI_S_LABEL: &str = "execute_e1_insn_mi/s";
pub const EXECUTE_METERED_TIME_LABEL: &str = "execute_metered_time_ms";
pub const EXECUTE_METERED_INSN_MI_S_LABEL: &str = "execute_metered_insn_mi/s";
pub const EXECUTE_PREFLIGHT_TIME_LABEL: &str = "execute_preflight_time_ms";
pub const EXECUTE_PREFLIGHT_INSN_MI_S_LABEL: &str = "execute_preflight_insn_mi/s";

// ---- Trace generation and proving phases ---------------------------------------------

pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms";
pub const MEM_FIN_TIME_LABEL: &str = "memory_finalize_time_ms";
pub const BOUNDARY_FIN_TIME_LABEL: &str = "boundary_finalize_time_ms";
pub const MERKLE_FIN_TIME_LABEL: &str = "merkle_finalize_time_ms";
pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms";

/// Metrics shown for VM benchmarks, in display order.
pub const VM_METRIC_NAMES: &[&str] = &[
    // overall
    PROOF_TIME_LABEL,
    MAIN_CELLS_USED_LABEL,
    TOTAL_CELLS_USED_LABEL,
    // execution
    EXECUTE_E1_TIME_LABEL,
    EXECUTE_E1_INSN_MI_S_LABEL,
    EXECUTE_METERED_TIME_LABEL,
    EXECUTE_METERED_INSN_MI_S_LABEL,
    EXECUTE_PREFLIGHT_INSNS_LABEL,
    EXECUTE_PREFLIGHT_TIME_LABEL,
    EXECUTE_PREFLIGHT_INSN_MI_S_LABEL,
    // trace generation / finalization
    TRACE_GEN_TIME_LABEL,
    MEM_FIN_TIME_LABEL,
    BOUNDARY_FIN_TIME_LABEL,
    MERKLE_FIN_TIME_LABEL,
    // proving
    PROVE_EXCL_TRACE_TIME_LABEL,
    "main_trace_commit_time_ms",
    "generate_perm_trace_time_ms",
    "perm_trace_commit_time_ms",
    "quotient_poly_compute_time_ms",
    "quotient_poly_commit_time_ms",
    "pcs_opening_time_ms",
];