// openvm_stark_backend/prover/metrics.rs
1use std::fmt::Display;
2
3use itertools::zip_eq;
4use serde::{Deserialize, Serialize};
5use tracing::info;
6
7use super::{hal::ProverBackend, types::DeviceMultiStarkProvingKey};
8use crate::keygen::types::TraceWidth;
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct TraceMetrics {
12 pub per_air: Vec<SingleTraceMetrics>,
13 pub total_cells: usize,
15 pub trace_height_inequalities: Vec<(usize, usize)>,
17}
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct SingleTraceMetrics {
21 pub air_name: String,
22 pub height: usize,
23 pub width: TraceWidth,
25 pub cells: TraceCells,
26 pub total_cells: usize,
29}
30
31#[derive(Clone, Debug, Serialize, Deserialize)]
33pub struct TraceCells {
34 pub preprocessed: Option<usize>,
35 pub cached_mains: Vec<usize>,
36 pub common_main: usize,
37 pub after_challenge: Vec<usize>,
38}
39
40impl Display for TraceMetrics {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 writeln!(
43 f,
44 "total_trace_cells = {} (excluding preprocessed)",
45 format_number_with_underscores(self.total_cells)
46 )?;
47 writeln!(
48 f,
49 "preprocessed_trace_cells = {}",
50 format_number_with_underscores(
51 self.per_air
52 .iter()
53 .map(|m| m.cells.preprocessed.unwrap_or(0))
54 .sum::<usize>()
55 )
56 )?;
57 writeln!(
58 f,
59 "main_trace_cells = {}",
60 format_number_with_underscores(
61 self.per_air
62 .iter()
63 .map(|m| m.cells.cached_mains.iter().sum::<usize>() + m.cells.common_main)
64 .sum::<usize>()
65 )
66 )?;
67 writeln!(
68 f,
69 "perm_trace_cells = {}",
70 format_number_with_underscores(
71 self.per_air
72 .iter()
73 .map(|m| m.cells.after_challenge.iter().sum::<usize>())
74 .sum::<usize>()
75 )
76 )?;
77 for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() {
78 writeln!(
79 f,
80 "trace_height_constraint_{i} | weighted_sum = {:<10} | threshold = {:<10}",
81 format_number_with_underscores(*weighted_sum),
82 format_number_with_underscores(*threshold)
83 )?;
84 }
85 for trace_metrics in &self.per_air {
86 writeln!(f, "{}", trace_metrics)?;
87 }
88 Ok(())
89 }
90}
91
92impl Display for SingleTraceMetrics {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(
95 f,
96 "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
97 self.air_name, format_number_with_underscores(self.height), format_number_with_underscores(self.total_cells), self.width.preprocessed.unwrap_or(0),
98 format!("{:?}", self.width.main_widths()),
99 format!("{:?}",self.width.after_challenge),
100 )?;
101 Ok(())
102 }
103}
104
105pub fn trace_metrics<PB: ProverBackend>(
107 mpk: &DeviceMultiStarkProvingKey<PB>,
108 log_trace_heights: &[u8],
109) -> TraceMetrics {
110 let heights = log_trace_heights
111 .iter()
112 .map(|&h| 1usize << h)
113 .collect::<Vec<_>>();
114 let trace_height_inequalities = mpk
115 .trace_height_constraints
116 .iter()
117 .map(|trace_height_constraint| {
118 let weighted_sum = heights
119 .iter()
120 .enumerate()
121 .map(|(j, h)| {
122 let air_id = mpk.air_ids[j];
123 (trace_height_constraint.coefficients[air_id] as usize) * h
124 })
125 .sum::<usize>();
126 (weighted_sum, trace_height_constraint.threshold as usize)
127 })
128 .collect::<Vec<_>>();
129 let per_air: Vec<_> = zip_eq(&mpk.per_air, heights)
130 .map(|(pk, height)| {
131 let air_name = pk.air_name;
132 let mut width = pk.vk.params.width.clone();
133 let ext_degree = PB::CHALLENGE_EXT_DEGREE as usize;
134 for w in &mut width.after_challenge {
135 *w *= ext_degree;
136 }
137 let cells = TraceCells {
138 preprocessed: width.preprocessed.map(|w| w * height),
139 cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
140 common_main: width.common_main * height,
141 after_challenge: width.after_challenge.iter().map(|w| w * height).collect(),
142 };
143 let total_cells = cells
144 .cached_mains
145 .iter()
146 .chain([&cells.common_main])
147 .chain(cells.after_challenge.iter())
148 .sum::<usize>();
149 SingleTraceMetrics {
150 air_name: air_name.to_string(),
151 height,
152 width,
153 cells,
154 total_cells,
155 }
156 })
157 .collect();
158 let total_cells = per_air.iter().map(|m| m.total_cells).sum();
159 let metrics = TraceMetrics {
160 per_air,
161 total_cells,
162 trace_height_inequalities,
163 };
164 info!("{}", metrics);
165 metrics
166}
167
/// Formats `n` in decimal with an underscore between each group of three digits, counting
/// from the right. For example, `1234567` becomes `"1_234_567"` and `999` stays `"999"`.
pub fn format_number_with_underscores(n: usize) -> String {
    let digits = n.to_string();
    let len = digits.len();
    // At most one separator per three digits.
    let mut out = String::with_capacity(len + len / 3);
    for (idx, digit) in digits.chars().enumerate() {
        // A separator precedes this digit when a whole number of 3-digit groups follows it.
        let remaining = len - idx;
        if idx != 0 && remaining % 3 == 0 {
            out.push('_');
        }
        out.push(digit);
    }
    out
}
183
#[cfg(feature = "bench-metrics")]
mod emit {
    use metrics::counter;

    use super::{SingleTraceMetrics, TraceMetrics};

    impl TraceMetrics {
        /// Records these metrics as absolute `metrics` counters. Emits `weighted_sum` and
        /// `threshold` per trace height constraint, then every AIR's counters, then
        /// `total_cells`.
        pub fn emit(&self) {
            for (i, &(weighted_sum, threshold)) in
                self.trace_height_inequalities.iter().enumerate()
            {
                let labels = [("trace_height_constraint", i.to_string())];
                counter!("weighted_sum", &labels).absolute(weighted_sum as u64);
                counter!("threshold", &labels).absolute(threshold as u64);
            }
            self.per_air.iter().for_each(SingleTraceMetrics::emit);
            counter!("total_cells").absolute(self.total_cells as u64);
        }
    }

    impl SingleTraceMetrics {
        /// Records this AIR's rows, cells and column counts as absolute counters labelled
        /// with `air_name`.
        pub fn emit(&self) {
            let labels = [("air_name", self.air_name.clone())];
            let width = &self.width;
            let prep_cols = width.preprocessed.unwrap_or(0);
            let main_cols = width.cached_mains.iter().sum::<usize>() + width.common_main;
            let perm_cols: usize = width.after_challenge.iter().sum();

            counter!("rows", &labels).absolute(self.height as u64);
            counter!("cells", &labels).absolute(self.total_cells as u64);
            counter!("prep_cols", &labels).absolute(prep_cols as u64);
            counter!("main_cols", &labels).absolute(main_cols as u64);
            counter!("perm_cols", &labels).absolute(perm_cols as u64);
        }
    }
}