openvm_stark_backend/prover/
metrics.rs1use std::fmt::Display;
2
3use itertools::zip_eq;
4use serde::{Deserialize, Serialize};
5use tracing::{debug, info};
6
7use super::hal::ProverBackend;
8use crate::{keygen::types::TraceWidth, prover::types::DeviceMultiStarkProvingKeyView};
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct TraceMetrics {
12 pub per_air: Vec<SingleTraceMetrics>,
13 pub total_cells: usize,
15 pub trace_height_inequalities: Vec<(usize, usize)>,
17}
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct SingleTraceMetrics {
21 pub air_name: String,
22 pub height: usize,
23 pub width: TraceWidth,
25 pub cells: TraceCells,
26 pub total_cells: usize,
29 pub quotient_poly_cells: usize,
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
35pub struct TraceCells {
36 pub preprocessed: Option<usize>,
37 pub cached_mains: Vec<usize>,
38 pub common_main: usize,
39 pub after_challenge: Vec<usize>,
40}
41
42impl Display for TraceMetrics {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() {
45 writeln!(
46 f,
47 "trace_height_constraint_{i} | weighted_sum = {:<10} | threshold = {:<10}",
48 format_number_with_underscores(*weighted_sum),
49 format_number_with_underscores(*threshold)
50 )?;
51 }
52 for trace_metrics in &self.per_air {
53 writeln!(f, "{}", trace_metrics)?;
54 }
55 Ok(())
56 }
57}
58
59impl Display for SingleTraceMetrics {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(
62 f,
63 "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
64 self.air_name, format_number_with_underscores(self.height), format_number_with_underscores(self.total_cells), self.width.preprocessed.unwrap_or(0),
65 format!("{:?}", self.width.main_widths()),
66 format!("{:?}",self.width.after_challenge),
67 )?;
68 Ok(())
69 }
70}
71
72pub fn trace_metrics<PB: ProverBackend>(
74 mpk: &DeviceMultiStarkProvingKeyView<PB>,
75 log_trace_heights: &[u8],
76) -> TraceMetrics {
77 let heights = log_trace_heights
78 .iter()
79 .map(|&h| 1usize << h)
80 .collect::<Vec<_>>();
81 let trace_height_inequalities = mpk
82 .trace_height_constraints
83 .iter()
84 .map(|trace_height_constraint| {
85 let weighted_sum = heights
86 .iter()
87 .enumerate()
88 .map(|(j, h)| {
89 let air_id = mpk.air_ids[j];
90 (trace_height_constraint.coefficients[air_id] as usize) * h
91 })
92 .sum::<usize>();
93 (weighted_sum, trace_height_constraint.threshold as usize)
94 })
95 .collect::<Vec<_>>();
96 let per_air: Vec<_> = zip_eq(&mpk.per_air, heights)
97 .map(|(pk, height)| {
98 let air_name = &pk.air_name;
99 let mut width = pk.vk.params.width.clone();
100 let ext_degree = PB::CHALLENGE_EXT_DEGREE as usize;
101 for w in &mut width.after_challenge {
102 *w *= ext_degree;
103 }
104 let cells = TraceCells {
105 preprocessed: width.preprocessed.map(|w| w * height),
106 cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
107 common_main: width.common_main * height,
108 after_challenge: width.after_challenge.iter().map(|w| w * height).collect(),
109 };
110 let total_cells = cells
111 .cached_mains
112 .iter()
113 .chain([&cells.common_main])
114 .chain(cells.after_challenge.iter())
115 .sum::<usize>();
116 SingleTraceMetrics {
117 air_name: air_name.to_string(),
118 height,
119 width,
120 cells,
121 total_cells,
122 quotient_poly_cells: height * (pk.vk.quotient_degree as usize) * ext_degree,
123 }
124 })
125 .collect();
126 let total_cells = per_air.iter().map(|m| m.total_cells).sum();
127 let metrics = TraceMetrics {
128 per_air,
129 total_cells,
130 trace_height_inequalities,
131 };
132 info!(
133 "total_trace_cells = {} (excluding preprocessed)",
134 format_number_with_underscores(metrics.total_cells)
135 );
136 info!(
137 "preprocessed_trace_cells = {}",
138 format_number_with_underscores(
139 metrics
140 .per_air
141 .iter()
142 .map(|m| m.cells.preprocessed.unwrap_or(0))
143 .sum::<usize>()
144 )
145 );
146 info!(
147 "main_trace_cells = {}",
148 format_number_with_underscores(
149 metrics
150 .per_air
151 .iter()
152 .map(|m| m.cells.cached_mains.iter().sum::<usize>() + m.cells.common_main)
153 .sum::<usize>()
154 )
155 );
156 info!(
157 "perm_trace_cells = {}",
158 format_number_with_underscores(
159 metrics
160 .per_air
161 .iter()
162 .map(|m| m.cells.after_challenge.iter().sum::<usize>())
163 .sum::<usize>()
164 )
165 );
166 info!(
167 "quotient_poly_cells = {}",
168 format_number_with_underscores(
169 metrics
170 .per_air
171 .iter()
172 .map(|m| m.quotient_poly_cells)
173 .sum::<usize>()
174 )
175 );
176 debug!("{}", metrics);
177 metrics
178}
179
/// Renders `n` in decimal with `_` between each group of three digits, counting from the
/// right. For example, `1234567` becomes `"1_234_567"` and `999` stays `"999"`.
pub fn format_number_with_underscores(n: usize) -> String {
    let digits = n.to_string();
    let len = digits.len();
    let mut out = String::with_capacity(len + len / 3);
    // All characters are ASCII digits, so byte offsets equal character positions.
    for (idx, ch) in digits.char_indices() {
        // Insert a separator whenever the digits still to come form whole groups of three.
        if idx > 0 && (len - idx) % 3 == 0 {
            out.push('_');
        }
        out.push(ch);
    }
    out
}
195
#[cfg(feature = "metrics")]
mod emit {
    use metrics::counter;

    use super::{SingleTraceMetrics, TraceMetrics};

    impl TraceMetrics {
        /// Publishes these metrics as absolute counters.
        ///
        /// Records a `weighted_sum`/`threshold` pair per trace height constraint, labelled
        /// with the constraint index. It then records every AIR's counters, and finally
        /// the overall `total_cells`.
        pub fn emit(&self) {
            for (i, &(weighted_sum, threshold)) in
                self.trace_height_inequalities.iter().enumerate()
            {
                let constraint_labels = [("trace_height_constraint", i.to_string())];
                counter!("weighted_sum", &constraint_labels).absolute(weighted_sum as u64);
                counter!("threshold", &constraint_labels).absolute(threshold as u64);
            }
            self.per_air.iter().for_each(SingleTraceMetrics::emit);
            counter!("total_cells").absolute(self.total_cells as u64);
        }
    }

    impl SingleTraceMetrics {
        /// Publishes this AIR's row, cell and column counts as absolute counters labelled
        /// with `air_name`.
        pub fn emit(&self) {
            let air_labels = [("air_name", self.air_name.clone())];
            let prep_cols = self.width.preprocessed.unwrap_or(0);
            let main_cols = self.width.cached_mains.iter().sum::<usize>() + self.width.common_main;
            let perm_cols = self.width.after_challenge.iter().sum::<usize>();
            counter!("rows", &air_labels).absolute(self.height as u64);
            counter!("cells", &air_labels).absolute(self.total_cells as u64);
            counter!("prep_cols", &air_labels).absolute(prep_cols as u64);
            counter!("main_cols", &air_labels).absolute(main_cols as u64);
            counter!("perm_cols", &air_labels).absolute(perm_cols as u64);
        }
    }
}