use std::{collections::HashMap, io::Write};

use eyre::Result;
use serde::{Deserialize, Serialize};

use crate::types::{BencherValue, BenchmarkOutput, Labels, MdTableCell, MetricDb};

type MetricName = String;
type MetricsByName = HashMap<MetricName, Vec<(f64, Labels)>>;

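/// Metrics from a [`MetricDb`] partitioned by the value of a chosen group label.
/// Entries without that label are kept in `ungrouped`.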
#[derive(Clone, Debug, Default)]
pub struct GroupedMetrics {
    pub by_group: HashMap<String, MetricsByName>,
    pub ungrouped: MetricsByName,
}

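/// Per-group summary [`Stats`] for each metric, plus proof-time totals derived from them.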
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AggregateMetrics {
    #[serde(flatten)]
    pub by_group: HashMap<String, HashMap<MetricName, Stats>>,
    pub total_proof_time: MdTableCell,
    pub total_par_proof_time: MdTableCell,
    #[serde(skip)]
    pub bounded_par_by_group: HashMap<String, MdTableCell>,
}

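/// The same aggregates flattened into [`BencherValue`]s for benchmark output.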
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BencherAggregateMetrics {
    #[serde(flatten)]
    pub by_group: HashMap<String, HashMap<String, BencherValue>>,
    pub total_proof_time: BencherValue,
    pub total_par_proof_time: BencherValue,
}

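/// Summary statistics (sum, max, min, avg) accumulated over a metric's samples.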
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Stats {
    pub sum: MdTableCell,
    pub max: MdTableCell,
    pub min: MdTableCell,
    pub avg: MdTableCell,
    #[serde(skip)]
    pub count: usize,
}

impl Default for Stats {
    fn default() -> Self {
        Self::new()
    }
}

impl Stats {
    pub fn new() -> Self {
        Self {
            sum: MdTableCell::default(),
            max: MdTableCell::default(),
            min: MdTableCell::new(f64::MAX, None),
            avg: MdTableCell::default(),
            count: 0,
        }
    }
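    /// Adds a sample, updating the running sum, max, min, and count.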
    pub fn push(&mut self, value: f64) {
        self.sum.val += value;
        self.count += 1;
        if value > self.max.val {
            self.max.val = value;
        }
        if value < self.min.val {
            self.min.val = value;
        }
    }

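    /// Computes the average; call after all samples are pushed. Panics if no samples were pushed.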
    pub fn finalize(&mut self) {
        assert!(self.count != 0);
        self.avg.val = self.sum.val / self.count as f64;
    }

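    /// Records the difference from a previous run's stats in each cell's `diff` field.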
    pub fn set_diff(&mut self, prev: &Self) {
        self.sum.diff = Some(self.sum.val - prev.sum.val);
        self.max.diff = Some(self.max.val - prev.max.val);
        self.min.diff = Some(self.min.val - prev.min.val);
        self.avg.diff = Some(self.avg.val - prev.avg.val);
    }
}

impl GroupedMetrics {
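    /// Partitions the flat metrics in `db` by the value of `group_label_name`,
    /// dropping that label from the stored label sets; metrics without the label
    /// go into `ungrouped`.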
    pub fn new(db: &MetricDb, group_label_name: &str) -> Result<Self> {
        let mut by_group = HashMap::<String, MetricsByName>::new();
        let mut ungrouped = MetricsByName::new();
        for (labels, metrics) in db.flat_dict.iter() {
            let group_name = labels.get(group_label_name);
            if let Some(group_name) = group_name {
                let group_entry = by_group.entry(group_name.to_string()).or_default();
                let mut labels = labels.clone();
                labels.remove(group_label_name);
                for metric in metrics {
                    group_entry
                        .entry(metric.name.clone())
                        .or_default()
                        .push((metric.value, labels.clone()));
                }
            } else {
                for metric in metrics {
                    ungrouped
                        .entry(metric.name.clone())
                        .or_default()
                        .push((metric.value, labels.clone()));
                }
            }
        }
        Ok(Self {
            by_group,
            ungrouped,
        })
    }

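    /// Asserts that the e1, metered, and preflight executors report the same total
    /// instruction count whenever more than one of them is present.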
    fn validate_instruction_counts(group_summaries: &HashMap<MetricName, Stats>) {
        let e1_insns = group_summaries.get(EXECUTE_E1_INSNS_LABEL);
        let metered_insns = group_summaries.get(EXECUTE_METERED_INSNS_LABEL);
        let preflight_insns = group_summaries.get(EXECUTE_PREFLIGHT_INSNS_LABEL);

        if let (Some(e1_insns), Some(preflight_insns)) = (e1_insns, preflight_insns) {
            assert_eq!(e1_insns.sum.val as u64, preflight_insns.sum.val as u64);
        }
        if let (Some(e1_insns), Some(metered_insns)) = (e1_insns, metered_insns) {
            assert_eq!(e1_insns.sum.val as u64, metered_insns.sum.val as u64);
        }
        if let (Some(metered_insns), Some(preflight_insns)) = (metered_insns, preflight_insns) {
            assert_eq!(metered_insns.sum.val as u64, preflight_insns.sum.val as u64);
        }
    }

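    /// Summarizes every metric in every group into [`Stats`], validates instruction
    /// counts for non-keygen groups, and computes proof-time totals, including the
    /// bounded parallel times for `num_parallel` provers.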
    pub fn aggregate(&self, num_parallel: usize) -> AggregateMetrics {
        let by_group: HashMap<String, _> = self
            .by_group
            .iter()
            .map(|(group_name, metrics)| {
                let group_summaries: HashMap<MetricName, Stats> = metrics
                    .iter()
                    .map(|(metric_name, metrics)| {
                        let mut summary = Stats::new();
                        for (value, _) in metrics {
                            summary.push(*value);
                        }
                        summary.finalize();
                        (metric_name.clone(), summary)
                    })
                    .collect();

                if !group_name.contains("keygen") {
                    Self::validate_instruction_counts(&group_summaries);
                }

                (group_name.clone(), group_summaries)
            })
            .collect();
        let mut metrics = AggregateMetrics {
            by_group,
            ..Default::default()
        };
        metrics.compute_total();
        metrics.bounded_par_by_group = self
            .compute_bounded_par_times(num_parallel, &metrics.by_group)
            .into_iter()
            .map(|(k, v)| (k, MdTableCell::new(v, Some(0.0))))
            .collect();

        metrics
    }

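    /// Estimates, per group, the wall-clock time with at most `num_parallel` concurrent
    /// provers: average execution time (for app proof groups) plus the makespan of a
    /// round-robin schedule of the group's proof times, in seconds.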
    fn compute_bounded_par_times(
        &self,
        num_parallel: usize,
        stats_by_group: &HashMap<String, HashMap<MetricName, Stats>>,
    ) -> HashMap<String, f64> {
        let mut per_group = HashMap::new();

        for (group_name, metrics) in &self.by_group {
            if group_name.contains("keygen") {
                continue;
            }

            let mut group_time = 0.0;

            if is_app_proof_group(group_name) {
                if let Some(stats) = stats_by_group.get(group_name) {
                    if let Some(metered) = stats.get(EXECUTE_METERED_TIME_LABEL) {
                        group_time += metered.avg.val / 1000.0;
                    }
                    if let Some(e1) = stats.get(EXECUTE_E1_TIME_LABEL) {
                        group_time += e1.avg.val / 1000.0;
                    }
                }
            }

            if let Some(proof_times) = metrics.get(PROOF_TIME_LABEL) {
                let times_s: Vec<f64> = proof_times.iter().map(|(ms, _)| ms / 1000.0).collect();
                group_time += schedule_parallel(&times_s, num_parallel);
            }

            per_group.insert(group_name.clone(), group_time);
        }

        per_group
    }
}

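/// Returns the makespan of assigning `proof_times` round-robin to `num_parallel` slots,
/// i.e. the largest per-slot sum, or 0.0 if there are no proofs or no slots.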
fn schedule_parallel(proof_times: &[f64], num_parallel: usize) -> f64 {
    if proof_times.is_empty() || num_parallel == 0 {
        return 0.0;
    }

    let mut slot_times = vec![0.0_f64; num_parallel];
    for (i, duration) in proof_times.iter().enumerate() {
        slot_times[i % num_parallel] += duration;
    }
    slot_times.iter().cloned().fold(0.0_f64, f64::max)
}

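/// Returns true for app-level proof groups, i.e. anything that is not a
/// leaf/internal/root/halo2 aggregation group.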
fn is_app_proof_group(name: &str) -> bool {
    name != "leaf"
        && name != "root"
        && name != "halo2_outer"
        && name != "halo2_wrapper"
        && !name.starts_with("internal")
}

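/// Sort weight used to order groups in reports: app proof groups first (weight 0),
/// then leaf, internal, root, halo2_outer, halo2_wrapper, with keygen groups last.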
pub(crate) fn group_weight(name: &str) -> usize {
    let label_prefix = ["leaf", "internal", "root", "halo2_outer", "halo2_wrapper"];
    if name.contains("keygen") {
        return label_prefix.len() + 1;
    }
    for (i, prefix) in label_prefix.iter().enumerate().rev() {
        if name.starts_with(prefix) {
            return i + 1;
        }
    }
    0
}

impl AggregateMetrics {
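    /// Recomputes `total_proof_time` (sum of per-group proof time) and
    /// `total_par_proof_time` (sum of per-group max proof time), in seconds, skipping
    /// keygen groups and adding average execution time for app proof groups.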
    pub fn compute_total(&mut self) {
        let mut total_proof_time = MdTableCell::new(0.0, Some(0.0));
        let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0));
        for (group_name, metrics) in &self.by_group {
            let stats = metrics.get(PROOF_TIME_LABEL);
            let execute_metered_stats = metrics.get(EXECUTE_METERED_TIME_LABEL);
            let execute_e1_stats = metrics.get(EXECUTE_E1_TIME_LABEL);
            if stats.is_none() {
                continue;
            }
            let stats = stats.unwrap_or_else(|| {
                panic!("Missing proof time statistics for group '{group_name}'")
            });
            let mut sum = stats.sum;
            let mut max = stats.max;
            sum.val /= 1000.0;
            max.val /= 1000.0;
            if let Some(diff) = &mut sum.diff {
                *diff /= 1000.0;
            }
            if let Some(diff) = &mut max.diff {
                *diff /= 1000.0;
            }
            if !group_name.contains("keygen") {
                total_proof_time.val += sum.val;
                *total_proof_time
                    .diff
                    .as_mut()
                    .expect("total_proof_time.diff should be initialized") +=
                    sum.diff.unwrap_or(0.0);
                total_par_proof_time.val += max.val;
                *total_par_proof_time
                    .diff
                    .as_mut()
                    .expect("total_par_proof_time.diff should be initialized") +=
                    max.diff.unwrap_or(0.0);

                if is_app_proof_group(group_name) {
                    if let Some(execute_metered_stats) = execute_metered_stats {
                        total_proof_time.val += execute_metered_stats.avg.val / 1000.0;
                        total_par_proof_time.val += execute_metered_stats.avg.val / 1000.0;
                        if let Some(diff) = execute_metered_stats.avg.diff {
                            *total_proof_time
                                .diff
                                .as_mut()
                                .expect("total_proof_time.diff should be initialized") +=
                                diff / 1000.0;
                            *total_par_proof_time
                                .diff
                                .as_mut()
                                .expect("total_par_proof_time.diff should be initialized") +=
                                diff / 1000.0;
                        }
                    }

                    if let Some(execute_e1_stats) = execute_e1_stats {
                        total_proof_time.val += execute_e1_stats.avg.val / 1000.0;
                        total_par_proof_time.val += execute_e1_stats.avg.val / 1000.0;
                        if let Some(diff) = execute_e1_stats.avg.diff {
                            *total_proof_time
                                .diff
                                .as_mut()
                                .expect("total_proof_time.diff should be initialized") +=
                                diff / 1000.0;
                            *total_par_proof_time
                                .diff
                                .as_mut()
                                .expect("total_par_proof_time.diff should be initialized") +=
                                diff / 1000.0;
                        }
                    }
                }
            }
        }
        self.total_proof_time = total_proof_time;
        self.total_par_proof_time = total_par_proof_time;
    }

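    /// Sets per-metric diffs against a previous run and recomputes the totals so their
    /// diffs reflect the comparison.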
    pub fn set_diff(&mut self, prev: &Self) {
        for (group_name, metrics) in self.by_group.iter_mut() {
            if let Some(prev_metrics) = prev.by_group.get(group_name) {
                for (metric_name, stats) in metrics.iter_mut() {
                    if let Some(prev_stats) = prev_metrics.get(metric_name) {
                        stats.set_diff(prev_stats);
                    }
                }
            }
        }
        self.compute_total();
    }

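    /// Returns the groups as a vector sorted by `group_weight`, breaking ties alphabetically.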
    pub fn to_vec(&self) -> Vec<(String, HashMap<MetricName, Stats>)> {
        let mut group_names: Vec<_> = self.by_group.keys().collect();
        group_names.sort_by(|a, b| {
            let a_wt = group_weight(a);
            let b_wt = group_weight(b);
            if a_wt == b_wt {
                a.cmp(b)
            } else {
                a_wt.cmp(&b_wt)
            }
        });
        group_names
            .into_iter()
            .map(|group_name| {
                let key = group_name.clone();
                let value = self
                    .by_group
                    .get(group_name)
                    .unwrap_or_else(|| panic!("Group '{group_name}' should exist in by_group map"))
                    .clone();
                (key, value)
            })
            .collect()
    }

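    /// Flattens the aggregates into [`BencherValue`]s: each metric is exported as its
    /// average with min/max bounds plus a `<metric>::sum` entry; metrics with a
    /// non-finite avg or sum are skipped.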
    pub fn to_bencher_metrics(&self) -> BencherAggregateMetrics {
        let by_group = self
            .by_group
            .iter()
            .map(|(group_name, metrics)| {
                let metrics = metrics
                    .iter()
                    .filter(|(_, stats)| stats.avg.val.is_finite() && stats.sum.val.is_finite())
                    .flat_map(|(metric_name, stats)| {
                        [
                            (format!("{metric_name}::sum"), stats.sum.into()),
                            (
                                metric_name.clone(),
                                BencherValue {
                                    value: stats.avg.val,
                                    lower_value: Some(stats.min.val),
                                    upper_value: Some(stats.max.val),
                                },
                            ),
                        ]
                    })
                    .collect();
                (group_name.clone(), metrics)
            })
            .collect();
        let total_proof_time = self.total_proof_time.into();
        let total_par_proof_time = self.total_par_proof_time.into();
        BencherAggregateMetrics {
            by_group,
            total_proof_time,
            total_par_proof_time,
        }
    }

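    /// Writes the proof-time summary followed by one markdown stats table per group.
    /// If `metric_names` is empty, every metric in the group is listed; execution
    /// throughput metrics omit the sum column and app-group metered time only shows avg.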
    pub fn write_markdown(
        &self,
        writer: &mut impl Write,
        metric_names: &[&str],
        num_parallel: usize,
    ) -> Result<()> {
        self.write_summary_markdown(writer, num_parallel)?;
        writeln!(writer)?;

        let metric_names = metric_names.to_vec();
        for (group_name, summaries) in self.to_vec() {
            writeln!(writer, "| {group_name} |||||")?;
            writeln!(writer, "|:---|---:|---:|---:|---:|")?;
            writeln!(writer, "|metric|avg|sum|max|min|")?;
            let names = if metric_names.is_empty() {
                summaries.keys().map(|s| s.as_str()).collect()
            } else {
                metric_names.clone()
            };
            for metric_name in names {
                let summary = summaries.get(metric_name);
                if let Some(summary) = summary {
                    if metric_name == EXECUTE_METERED_TIME_LABEL && is_app_proof_group(&group_name)
                    {
                        writeln!(
                            writer,
                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
                            metric_name, summary.avg, "-", "-", "-",
                        )?;
                    } else if metric_name == EXECUTE_E1_INSN_MI_S_LABEL
                        || metric_name == EXECUTE_PREFLIGHT_INSN_MI_S_LABEL
                        || metric_name == EXECUTE_METERED_INSN_MI_S_LABEL
                    {
                        writeln!(
                            writer,
                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
                            metric_name, summary.avg, "-", summary.max, summary.min,
                        )?;
                    } else {
                        writeln!(
                            writer,
                            "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
                            metric_name, summary.avg, summary.sum, summary.max, summary.min,
                        )?;
                    }
                }
            }
            writeln!(writer)?;
        }
        writeln!(writer)?;

        Ok(())
    }

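    /// Writes the summary table: total and per-group proof time, parallel proof time,
    /// and the bounded parallel proof time for `num_parallel` provers, all in seconds.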
    fn write_summary_markdown(&self, writer: &mut impl Write, num_parallel: usize) -> Result<()> {
        writeln!(
            writer,
            "| Summary | Proof Time (s) | Parallel Proof Time (s) | Parallel Proof Time ({} provers) (s) |",
            num_parallel
        )?;
        writeln!(writer, "|:---|---:|---:|---:|")?;
        let mut rows = Vec::new();
        for (group_name, summaries) in self.to_vec() {
            if group_name.contains("keygen") {
                continue;
            }
            let stats = summaries.get(PROOF_TIME_LABEL);
            if stats.is_none() {
                continue;
            }
            let stats = stats.unwrap_or_else(|| {
                panic!("Missing proof time statistics for group '{group_name}'")
            });
            let mut sum = stats.sum;
            let mut max = stats.max;
            sum.val /= 1000.0;
            max.val /= 1000.0;
            if let Some(diff) = &mut sum.diff {
                *diff /= 1000.0;
            }
            if let Some(diff) = &mut max.diff {
                *diff /= 1000.0;
            }
            if is_app_proof_group(&group_name) {
                if let Some(metered) = summaries.get(EXECUTE_METERED_TIME_LABEL) {
                    sum.val += metered.avg.val / 1000.0;
                    max.val += metered.avg.val / 1000.0;
                }
                if let Some(e1) = summaries.get(EXECUTE_E1_TIME_LABEL) {
                    sum.val += e1.avg.val / 1000.0;
                    max.val += e1.avg.val / 1000.0;
                }
            }
            rows.push((group_name, sum, max));
        }
        let total_bounded: f64 = self.bounded_par_by_group.values().map(|v| v.val).sum();
        writeln!(
            writer,
            "| Total | {} | {} | {:.2} |",
            self.total_proof_time, self.total_par_proof_time, total_bounded
        )?;
        for (group_name, proof_time, par_proof_time) in rows {
            let bounded = self
                .bounded_par_by_group
                .get(&group_name)
                .map(|v| v.to_string())
                .unwrap_or_else(|| "-".to_string());
            writeln!(
                writer,
                "| {group_name} | {proof_time} | {par_proof_time} | {bounded} |"
            )?;
        }
        writeln!(writer)?;
        Ok(())
    }

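    /// Returns a group name with `group_weight` 0 (an app proof group), falling back to
    /// an arbitrary group if none exists.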
    pub fn name(&self) -> String {
        self.by_group
            .keys()
            .find(|k| group_weight(k) == 0)
            .unwrap_or_else(|| {
                self.by_group
                    .keys()
                    .next()
                    .expect("by_group should contain at least one group")
            })
            .clone()
    }
}

impl BenchmarkOutput {
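    /// Merges a benchmark's metrics into the output: per-group metrics are stored under
    /// `"{name}::{group_name}"` and the totals under `name` itself. Panics if `name`
    /// was already present.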
    pub fn insert(&mut self, name: &str, metrics: BencherAggregateMetrics) {
        for (group_name, metrics) in metrics.by_group {
            self.by_name
                .entry(format!("{name}::{group_name}"))
                .or_default()
                .extend(metrics);
        }
        if let Some(e) = self.by_name.insert(
            name.to_owned(),
            HashMap::from_iter([
                ("total_proof_time".to_owned(), metrics.total_proof_time),
                (
                    "total_par_proof_time".to_owned(),
                    metrics.total_par_proof_time,
                ),
            ]),
        ) {
            panic!("Duplicate metric: {e:?}");
        }
    }
}

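// Metric name labels used to look up and report specific metrics.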
pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms";
pub const MAIN_CELLS_USED_LABEL: &str = "main_cells_used";
pub const TOTAL_CELLS_USED_LABEL: &str = "total_cells_used";
pub const EXECUTE_E1_INSNS_LABEL: &str = "execute_e1_insns";
pub const EXECUTE_METERED_INSNS_LABEL: &str = "execute_metered_insns";
pub const EXECUTE_PREFLIGHT_INSNS_LABEL: &str = "execute_preflight_insns";
pub const EXECUTE_E1_TIME_LABEL: &str = "execute_e1_time_ms";
pub const EXECUTE_E1_INSN_MI_S_LABEL: &str = "execute_e1_insn_mi/s";
pub const EXECUTE_METERED_TIME_LABEL: &str = "execute_metered_time_ms";
pub const EXECUTE_METERED_INSN_MI_S_LABEL: &str = "execute_metered_insn_mi/s";
pub const EXECUTE_PREFLIGHT_TIME_LABEL: &str = "execute_preflight_time_ms";
pub const EXECUTE_PREFLIGHT_INSN_MI_S_LABEL: &str = "execute_preflight_insn_mi/s";
pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms";
pub const MEM_FIN_TIME_LABEL: &str = "memory_finalize_time_ms";
pub const BOUNDARY_FIN_TIME_LABEL: &str = "boundary_finalize_time_ms";
pub const MERKLE_FIN_TIME_LABEL: &str = "merkle_finalize_time_ms";
pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms";

pub const VM_METRIC_NAMES: &[&str] = &[
    PROOF_TIME_LABEL,
    MAIN_CELLS_USED_LABEL,
    TOTAL_CELLS_USED_LABEL,
    EXECUTE_E1_TIME_LABEL,
    EXECUTE_E1_INSN_MI_S_LABEL,
    EXECUTE_METERED_TIME_LABEL,
    EXECUTE_METERED_INSN_MI_S_LABEL,
    EXECUTE_PREFLIGHT_INSNS_LABEL,
    EXECUTE_PREFLIGHT_TIME_LABEL,
    EXECUTE_PREFLIGHT_INSN_MI_S_LABEL,
    TRACE_GEN_TIME_LABEL,
    MEM_FIN_TIME_LABEL,
    BOUNDARY_FIN_TIME_LABEL,
    MERKLE_FIN_TIME_LABEL,
    PROVE_EXCL_TRACE_TIME_LABEL,
    "main_trace_commit_time_ms",
    "generate_perm_trace_time_ms",
    "perm_trace_commit_time_ms",
    "quotient_poly_compute_time_ms",
    "quotient_poly_commit_time_ms",
    "pcs_opening_time_ms",
];