1use std::{collections::HashMap, io::Write};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5
6use crate::types::{BencherValue, BenchmarkOutput, Labels, MdTableCell, MetricDb};
7
8type MetricName = String;
9type MetricsByName = HashMap<MetricName, Vec<(f64, Labels)>>;
10
11#[derive(Clone, Debug, Default)]
12pub struct GroupedMetrics {
13 pub by_group: HashMap<String, MetricsByName>,
15 pub ungrouped: MetricsByName,
16}
17
18#[derive(Clone, Debug, Default, Serialize, Deserialize)]
19pub struct AggregateMetrics {
20 #[serde(flatten)]
22 pub by_group: HashMap<String, HashMap<MetricName, Stats>>,
23 pub total_proof_time: MdTableCell,
25 pub total_par_proof_time: MdTableCell,
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct BencherAggregateMetrics {
31 #[serde(flatten)]
32 pub by_group: HashMap<String, HashMap<String, BencherValue>>,
33 pub total_proof_time: BencherValue,
35 pub total_par_proof_time: BencherValue,
37}
38
39#[derive(Clone, Debug, Serialize, Deserialize)]
40pub struct Stats {
41 pub sum: MdTableCell,
42 pub max: MdTableCell,
43 pub min: MdTableCell,
44 pub avg: MdTableCell,
45 #[serde(skip)]
46 pub count: usize,
47}
48
49impl Default for Stats {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl Stats {
56 pub fn new() -> Self {
57 Self {
58 sum: MdTableCell::default(),
59 max: MdTableCell::default(),
60 min: MdTableCell::new(f64::MAX, None),
61 avg: MdTableCell::default(),
62 count: 0,
63 }
64 }
65 pub fn push(&mut self, value: f64) {
66 self.sum.val += value;
67 self.count += 1;
68 if value > self.max.val {
69 self.max.val = value;
70 }
71 if value < self.min.val {
72 self.min.val = value;
73 }
74 }
75
76 pub fn finalize(&mut self) {
77 assert!(self.count != 0);
78 self.avg.val = self.sum.val / self.count as f64;
79 }
80
81 pub fn set_diff(&mut self, prev: &Self) {
82 self.sum.diff = Some(self.sum.val - prev.sum.val);
83 self.max.diff = Some(self.max.val - prev.max.val);
84 self.min.diff = Some(self.min.val - prev.min.val);
85 self.avg.diff = Some(self.avg.val - prev.avg.val);
86 }
87}
88
89impl GroupedMetrics {
90 pub fn new(db: &MetricDb, group_label_name: &str) -> Result<Self> {
91 let mut by_group = HashMap::<String, MetricsByName>::new();
92 let mut ungrouped = MetricsByName::new();
93 for (labels, metrics) in db.flat_dict.iter() {
94 let group_name = labels.get(group_label_name);
95 if let Some(group_name) = group_name {
96 let group_entry = by_group.entry(group_name.to_string()).or_default();
97 let mut labels = labels.clone();
98 labels.remove(group_label_name);
99 for metric in metrics {
100 group_entry
101 .entry(metric.name.clone())
102 .or_default()
103 .push((metric.value, labels.clone()));
104 }
105 } else {
106 for metric in metrics {
107 ungrouped
108 .entry(metric.name.clone())
109 .or_default()
110 .push((metric.value, labels.clone()));
111 }
112 }
113 }
114 Ok(Self {
115 by_group,
116 ungrouped,
117 })
118 }
119
120 pub fn aggregate(&self) -> AggregateMetrics {
121 let by_group: HashMap<String, _> = self
122 .by_group
123 .iter()
124 .map(|(group_name, metrics)| {
125 let group_summaries: HashMap<MetricName, Stats> = metrics
126 .iter()
127 .map(|(metric_name, metrics)| {
128 let mut summary = Stats::new();
129 for (value, _) in metrics {
130 summary.push(*value);
131 }
132 summary.finalize();
133 (metric_name.clone(), summary)
134 })
135 .collect();
136 (group_name.clone(), group_summaries)
137 })
138 .collect();
139 let mut metrics = AggregateMetrics {
140 by_group,
141 ..Default::default()
142 };
143 metrics.compute_total();
144 metrics
145 }
146}
147
/// Sort weight of an aggregation group name; smaller weights are listed first in reports.
///
/// Returns `6` when the name contains `"keygen"`. Otherwise a name starting with `leaf`,
/// `internal`, `root`, `halo2_outer` or `halo2_wrapper` gets `1` through `5` respectively,
/// and any other name gets `0`.
pub(crate) fn group_weight(name: &str) -> usize {
    const ORDERED_PREFIXES: [&str; 5] = ["leaf", "internal", "root", "halo2_outer", "halo2_wrapper"];

    if name.contains("keygen") {
        return ORDERED_PREFIXES.len() + 1;
    }
    // Search from the back so the last matching prefix wins, as in a reversed scan.
    ORDERED_PREFIXES
        .iter()
        .rposition(|prefix| name.starts_with(prefix))
        .map_or(0, |idx| idx + 1)
}
161
162impl AggregateMetrics {
163 pub fn compute_total(&mut self) {
164 let mut total_proof_time = MdTableCell::new(0.0, Some(0.0));
165 let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0));
166 for (group_name, metrics) in &self.by_group {
167 let stats = metrics.get(PROOF_TIME_LABEL);
168 if stats.is_none() {
169 continue;
170 }
171 let stats = stats.unwrap();
172 let mut sum = stats.sum;
173 let mut max = stats.max;
174 sum.val /= 1000.0;
176 max.val /= 1000.0;
177 if let Some(diff) = &mut sum.diff {
178 *diff /= 1000.0;
179 }
180 if let Some(diff) = &mut max.diff {
181 *diff /= 1000.0;
182 }
183 if !group_name.contains("keygen") {
184 total_proof_time.val += sum.val;
186 *total_proof_time.diff.as_mut().unwrap() += sum.diff.unwrap_or(0.0);
187 total_par_proof_time.val += max.val;
188 *total_par_proof_time.diff.as_mut().unwrap() += max.diff.unwrap_or(0.0);
189 }
190 }
191 self.total_proof_time = total_proof_time;
192 self.total_par_proof_time = total_par_proof_time;
193 }
194
195 pub fn set_diff(&mut self, prev: &Self) {
196 for (group_name, metrics) in self.by_group.iter_mut() {
197 if let Some(prev_metrics) = prev.by_group.get(group_name) {
198 for (metric_name, stats) in metrics.iter_mut() {
199 if let Some(prev_stats) = prev_metrics.get(metric_name) {
200 stats.set_diff(prev_stats);
201 }
202 }
203 }
204 }
205 self.compute_total();
206 }
207
208 pub fn to_vec(&self) -> Vec<(String, HashMap<MetricName, Stats>)> {
209 let mut group_names: Vec<_> = self.by_group.keys().collect();
210 group_names.sort_by(|a, b| {
211 let a_wt = group_weight(a);
212 let b_wt = group_weight(b);
213 if a_wt == b_wt {
214 a.cmp(b)
215 } else {
216 a_wt.cmp(&b_wt)
217 }
218 });
219 group_names
220 .into_iter()
221 .map(|group_name| {
222 let key = group_name.clone();
223 let value = self.by_group.get(group_name).unwrap().clone();
224 (key, value)
225 })
226 .collect()
227 }
228
229 pub fn to_bencher_metrics(&self) -> BencherAggregateMetrics {
230 let by_group = self
231 .by_group
232 .iter()
233 .map(|(group_name, metrics)| {
234 let metrics = metrics
235 .iter()
236 .flat_map(|(metric_name, stats)| {
237 [
238 (format!("{metric_name}::sum"), stats.sum.into()),
239 (
240 metric_name.clone(),
241 BencherValue {
242 value: stats.avg.val,
243 lower_value: Some(stats.min.val),
244 upper_value: Some(stats.max.val),
245 },
246 ),
247 ]
248 })
249 .collect();
250 (group_name.clone(), metrics)
251 })
252 .collect();
253 let total_proof_time = self.total_proof_time.into();
254 let total_par_proof_time = self.total_par_proof_time.into();
255 BencherAggregateMetrics {
256 by_group,
257 total_proof_time,
258 total_par_proof_time,
259 }
260 }
261
262 pub fn write_markdown(&self, writer: &mut impl Write, metric_names: &[&str]) -> Result<()> {
263 self.write_summary_markdown(writer)?;
264 writeln!(writer)?;
265
266 let metric_names = metric_names.to_vec();
267 for (group_name, summaries) in self.to_vec() {
268 writeln!(writer, "| {} |||||", group_name)?;
269 writeln!(writer, "|:---|---:|---:|---:|---:|")?;
270 writeln!(writer, "|metric|avg|sum|max|min|")?;
271 let names = if metric_names.is_empty() {
272 summaries.keys().map(|s| s.as_str()).collect()
273 } else {
274 metric_names.clone()
275 };
276 for metric_name in names {
277 let summary = summaries.get(metric_name);
278 if let Some(summary) = summary {
279 writeln!(
280 writer,
281 "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |",
282 metric_name, summary.avg, summary.sum, summary.max, summary.min,
283 )?;
284 }
285 }
286 writeln!(writer)?;
287 }
288 writeln!(writer)?;
289
290 Ok(())
291 }
292
293 fn write_summary_markdown(&self, writer: &mut impl Write) -> Result<()> {
294 writeln!(
295 writer,
296 "| Summary | Proof Time (s) | Parallel Proof Time (s) |"
297 )?;
298 writeln!(writer, "|:---|---:|---:|")?;
299 let mut rows = Vec::new();
300 for (group_name, summaries) in self.to_vec() {
301 let stats = summaries.get(PROOF_TIME_LABEL);
302 if stats.is_none() {
303 continue;
304 }
305 let stats = stats.unwrap();
306 let mut sum = stats.sum;
307 let mut max = stats.max;
308 sum.val /= 1000.0;
310 max.val /= 1000.0;
311 if let Some(diff) = &mut sum.diff {
312 *diff /= 1000.0;
313 }
314 if let Some(diff) = &mut max.diff {
315 *diff /= 1000.0;
316 }
317 rows.push((group_name, sum, max));
318 }
319 writeln!(
320 writer,
321 "| Total | {} | {} |",
322 self.total_proof_time, self.total_par_proof_time
323 )?;
324 for (group_name, proof_time, par_proof_time) in rows {
325 writeln!(writer, "| {group_name} | {proof_time} | {par_proof_time} |")?;
326 }
327 writeln!(writer)?;
328 Ok(())
329 }
330
331 pub fn name(&self) -> String {
332 self.by_group
334 .keys()
335 .find(|k| group_weight(k) == 0)
336 .unwrap_or_else(|| self.by_group.keys().next().unwrap())
337 .clone()
338 }
339}
340
341impl BenchmarkOutput {
342 pub fn insert(&mut self, name: &str, metrics: BencherAggregateMetrics) {
343 for (group_name, metrics) in metrics.by_group {
344 self.by_name
345 .entry(format!("{name}::{group_name}"))
346 .or_default()
347 .extend(metrics);
348 }
349 if let Some(e) = self.by_name.insert(
350 name.to_owned(),
351 HashMap::from_iter([
352 ("total_proof_time".to_owned(), metrics.total_proof_time),
353 (
354 "total_par_proof_time".to_owned(),
355 metrics.total_par_proof_time,
356 ),
357 ]),
358 ) {
359 panic!("Duplicate metric: {e:?}");
360 }
361 }
362}
363
/// Metric name of a group's proof time, in milliseconds.
/// The summary table converts it to seconds.
pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms";
/// Metric name `main_cells_used`.
pub const CELLS_USED_LABEL: &str = "main_cells_used";
/// Metric name `total_cycles`.
pub const CYCLES_LABEL: &str = "total_cycles";
/// Metric name `execute_time_ms` (the `_ms` suffix denotes milliseconds).
pub const EXECUTE_TIME_LABEL: &str = "execute_time_ms";
/// Metric name `trace_gen_time_ms` (the `_ms` suffix denotes milliseconds).
pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms";
/// Metric name `stark_prove_excluding_trace_time_ms` (the `_ms` suffix denotes milliseconds).
pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms";

/// Metric names reported for VM benchmarks, in display order.
/// The slice can be passed as the `metric_names` argument of `AggregateMetrics::write_markdown`.
pub const VM_METRIC_NAMES: &[&str] = &[
    // Headline metrics, named by the constants above.
    PROOF_TIME_LABEL,
    CELLS_USED_LABEL,
    CYCLES_LABEL,
    EXECUTE_TIME_LABEL,
    TRACE_GEN_TIME_LABEL,
    PROVE_EXCL_TRACE_TIME_LABEL,
    // Finer-grained timings, in milliseconds.
    "main_trace_commit_time_ms",
    "generate_perm_trace_time_ms",
    "perm_trace_commit_time_ms",
    "quotient_poly_compute_time_ms",
    "quotient_poly_commit_time_ms",
    "pcs_opening_time_ms",
];
384];