openvm_stark_backend/prover/
metrics.rs1use std::fmt::Display;
2
3use itertools::zip_eq;
4use serde::{Deserialize, Serialize};
5use tracing::{debug, info};
6
7use super::hal::ProverBackend;
8use crate::{keygen::types::TraceWidth, prover::types::DeviceMultiStarkProvingKeyView};
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct TraceMetrics {
12 pub per_air: Vec<SingleTraceMetrics>,
13 pub total_cells: usize,
15 pub trace_height_inequalities: Vec<(usize, usize)>,
17}
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct SingleTraceMetrics {
21 pub air_name: String,
22 pub air_id: usize,
23 pub height: usize,
24 pub width: TraceWidth,
26 pub cells: TraceCells,
27 pub total_cells: usize,
30 pub quotient_poly_cells: usize,
32}
33
34#[derive(Clone, Debug, Serialize, Deserialize)]
36pub struct TraceCells {
37 pub preprocessed: Option<usize>,
38 pub cached_mains: Vec<usize>,
39 pub common_main: usize,
40 pub after_challenge: Vec<usize>,
41}
42
43impl Display for TraceMetrics {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() {
46 writeln!(
47 f,
48 "trace_height_constraint_{i} | weighted_sum = {:<10} | threshold = {:<10}",
49 format_number_with_underscores(*weighted_sum),
50 format_number_with_underscores(*threshold)
51 )?;
52 }
53 for trace_metrics in &self.per_air {
54 writeln!(f, "{}", trace_metrics)?;
55 }
56 Ok(())
57 }
58}
59
60impl Display for SingleTraceMetrics {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(
63 f,
64 "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
65 self.air_name, format_number_with_underscores(self.height), format_number_with_underscores(self.total_cells), self.width.preprocessed.unwrap_or(0),
66 format!("{:?}", self.width.main_widths()),
67 format!("{:?}",self.width.after_challenge),
68 )?;
69 Ok(())
70 }
71}
72
73pub fn trace_metrics<PB: ProverBackend>(
75 mpk: &DeviceMultiStarkProvingKeyView<PB>,
76 log_trace_heights: &[u8],
77) -> TraceMetrics {
78 let heights = log_trace_heights
79 .iter()
80 .map(|&h| 1usize << h)
81 .collect::<Vec<_>>();
82 let trace_height_inequalities = mpk
83 .trace_height_constraints
84 .iter()
85 .map(|trace_height_constraint| {
86 let weighted_sum = heights
87 .iter()
88 .enumerate()
89 .map(|(j, h)| {
90 let air_id = mpk.air_ids[j];
91 (trace_height_constraint.coefficients[air_id] as usize) * h
92 })
93 .sum::<usize>();
94 (weighted_sum, trace_height_constraint.threshold as usize)
95 })
96 .collect::<Vec<_>>();
97 let per_air: Vec<_> = zip_eq(mpk.air_ids.iter().copied(), zip_eq(&mpk.per_air, heights))
98 .map(|(air_id, (pk, height))| {
99 let air_name = &pk.air_name;
100 let mut width = pk.vk.params.width.clone();
101 let ext_degree = PB::CHALLENGE_EXT_DEGREE as usize;
102 for w in &mut width.after_challenge {
103 *w *= ext_degree;
104 }
105 let cells = TraceCells {
106 preprocessed: width.preprocessed.map(|w| w * height),
107 cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
108 common_main: width.common_main * height,
109 after_challenge: width.after_challenge.iter().map(|w| w * height).collect(),
110 };
111 let total_cells = cells
112 .cached_mains
113 .iter()
114 .chain([&cells.common_main])
115 .chain(cells.after_challenge.iter())
116 .sum::<usize>();
117 SingleTraceMetrics {
118 air_name: air_name.to_string(),
119 air_id,
120 height,
121 width,
122 cells,
123 total_cells,
124 quotient_poly_cells: height * (pk.vk.quotient_degree as usize) * ext_degree,
125 }
126 })
127 .collect();
128 let total_cells = per_air.iter().map(|m| m.total_cells).sum();
129 let metrics = TraceMetrics {
130 per_air,
131 total_cells,
132 trace_height_inequalities,
133 };
134 info!(
135 "total_trace_cells = {} (excluding preprocessed)",
136 format_number_with_underscores(metrics.total_cells)
137 );
138 info!(
139 "preprocessed_trace_cells = {}",
140 format_number_with_underscores(
141 metrics
142 .per_air
143 .iter()
144 .map(|m| m.cells.preprocessed.unwrap_or(0))
145 .sum::<usize>()
146 )
147 );
148 info!(
149 "main_trace_cells = {}",
150 format_number_with_underscores(
151 metrics
152 .per_air
153 .iter()
154 .map(|m| m.cells.cached_mains.iter().sum::<usize>() + m.cells.common_main)
155 .sum::<usize>()
156 )
157 );
158 info!(
159 "perm_trace_cells = {}",
160 format_number_with_underscores(
161 metrics
162 .per_air
163 .iter()
164 .map(|m| m.cells.after_challenge.iter().sum::<usize>())
165 .sum::<usize>()
166 )
167 );
168 info!(
169 "quotient_poly_cells = {}",
170 format_number_with_underscores(
171 metrics
172 .per_air
173 .iter()
174 .map(|m| m.quotient_poly_cells)
175 .sum::<usize>()
176 )
177 );
178 debug!("{}", metrics);
179 metrics
180}
181
/// Formats `n` in decimal with an underscore between every group of three digits, counted from
/// the least significant digit.
///
/// For example, `1234567` becomes `"1_234_567"` and `999` stays `"999"`.
pub fn format_number_with_underscores(n: usize) -> String {
    let digits = n.to_string();
    // `usize::to_string` always yields at least one digit, so `len - 1` cannot underflow.
    let len = digits.len();
    let mut result = String::with_capacity(len + (len - 1) / 3);

    for (i, c) in digits.chars().enumerate() {
        // A separator goes before any digit that has a multiple of three digits remaining
        // (itself included), except before the leading digit.
        if i > 0 && (len - i) % 3 == 0 {
            result.push('_');
        }
        result.push(c);
    }

    result
}
197
/// Export of the trace metrics through the `metrics` crate's counters.
#[cfg(feature = "metrics")]
mod emit {
    use metrics::counter;

    use super::{SingleTraceMetrics, TraceMetrics};

    impl TraceMetrics {
        /// Emits one `weighted_sum`/`threshold` counter pair per trace height constraint, then
        /// the per-AIR counters, then the overall `total_cells` counter.
        pub fn emit(&self) {
            for (idx, &(weighted_sum, threshold)) in
                self.trace_height_inequalities.iter().enumerate()
            {
                let labels = [("trace_height_constraint", idx.to_string())];
                counter!("weighted_sum", &labels).absolute(weighted_sum as u64);
                counter!("threshold", &labels).absolute(threshold as u64);
            }
            self.per_air.iter().for_each(SingleTraceMetrics::emit);
            counter!("total_cells").absolute(self.total_cells as u64);
        }
    }

    impl SingleTraceMetrics {
        /// Emits the row, cell and column counters for this AIR, labelled with its name and id.
        pub fn emit(&self) {
            let labels = [
                ("air_name", self.air_name.clone()),
                ("air_id", self.air_id.to_string()),
            ];
            let prep_cols = self.width.preprocessed.unwrap_or(0);
            // Main columns are the cached main widths plus the common main width.
            let main_cols = self.width.cached_mains.iter().sum::<usize>() + self.width.common_main;
            let perm_cols = self.width.after_challenge.iter().sum::<usize>();

            counter!("rows", &labels).absolute(self.height as u64);
            counter!("cells", &labels).absolute(self.total_cells as u64);
            counter!("prep_cols", &labels).absolute(prep_cols as u64);
            counter!("main_cols", &labels).absolute(main_cols as u64);
            counter!("perm_cols", &labels).absolute(perm_cols as u64);
        }
    }
}