openvm_stark_backend/prover/
metrics.rs

1use std::fmt::Display;
2
3use itertools::zip_eq;
4use serde::{Deserialize, Serialize};
5use tracing::{debug, info};
6
7use super::hal::ProverBackend;
8use crate::{keygen::types::TraceWidth, prover::types::DeviceMultiStarkProvingKeyView};
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct TraceMetrics {
12    pub per_air: Vec<SingleTraceMetrics>,
13    /// Total base field cells from all traces, excludes preprocessed.
14    pub total_cells: usize,
15    /// For each trace height constraint, the (weighted sum, threshold)
16    pub trace_height_inequalities: Vec<(usize, usize)>,
17}
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct SingleTraceMetrics {
21    pub air_name: String,
22    pub air_id: usize,
23    pub height: usize,
24    /// The after challenge width is adjusted to be in terms of **base field** elements.
25    pub width: TraceWidth,
26    pub cells: TraceCells,
27    /// Omitting preprocessed trace, the total base field cells from main and after challenge
28    /// traces.
29    pub total_cells: usize,
30    /// Base field cells for evaluation of quotient polynomial on the quotient domain
31    pub quotient_poly_cells: usize,
32}
33
34/// Trace cells, counted in terms of number of **base field** elements.
35#[derive(Clone, Debug, Serialize, Deserialize)]
36pub struct TraceCells {
37    pub preprocessed: Option<usize>,
38    pub cached_mains: Vec<usize>,
39    pub common_main: usize,
40    pub after_challenge: Vec<usize>,
41}
42
43impl Display for TraceMetrics {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() {
46            writeln!(
47                f,
48                "trace_height_constraint_{i} | weighted_sum = {:<10} | threshold = {:<10}",
49                format_number_with_underscores(*weighted_sum),
50                format_number_with_underscores(*threshold)
51            )?;
52        }
53        for trace_metrics in &self.per_air {
54            writeln!(f, "{}", trace_metrics)?;
55        }
56        Ok(())
57    }
58}
59
60impl Display for SingleTraceMetrics {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(
63            f,
64            "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
65            self.air_name, format_number_with_underscores(self.height), format_number_with_underscores(self.total_cells), self.width.preprocessed.unwrap_or(0),
66            format!("{:?}", self.width.main_widths()),
67            format!("{:?}",self.width.after_challenge),
68        )?;
69        Ok(())
70    }
71}
72
73/// heights are the trace heights for each air
74pub fn trace_metrics<PB: ProverBackend>(
75    mpk: &DeviceMultiStarkProvingKeyView<PB>,
76    log_trace_heights: &[u8],
77) -> TraceMetrics {
78    let heights = log_trace_heights
79        .iter()
80        .map(|&h| 1usize << h)
81        .collect::<Vec<_>>();
82    let trace_height_inequalities = mpk
83        .trace_height_constraints
84        .iter()
85        .map(|trace_height_constraint| {
86            let weighted_sum = heights
87                .iter()
88                .enumerate()
89                .map(|(j, h)| {
90                    let air_id = mpk.air_ids[j];
91                    (trace_height_constraint.coefficients[air_id] as usize) * h
92                })
93                .sum::<usize>();
94            (weighted_sum, trace_height_constraint.threshold as usize)
95        })
96        .collect::<Vec<_>>();
97    let per_air: Vec<_> = zip_eq(mpk.air_ids.iter().copied(), zip_eq(&mpk.per_air, heights))
98        .map(|(air_id, (pk, height))| {
99            let air_name = &pk.air_name;
100            let mut width = pk.vk.params.width.clone();
101            let ext_degree = PB::CHALLENGE_EXT_DEGREE as usize;
102            for w in &mut width.after_challenge {
103                *w *= ext_degree;
104            }
105            let cells = TraceCells {
106                preprocessed: width.preprocessed.map(|w| w * height),
107                cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
108                common_main: width.common_main * height,
109                after_challenge: width.after_challenge.iter().map(|w| w * height).collect(),
110            };
111            let total_cells = cells
112                .cached_mains
113                .iter()
114                .chain([&cells.common_main])
115                .chain(cells.after_challenge.iter())
116                .sum::<usize>();
117            SingleTraceMetrics {
118                air_name: air_name.to_string(),
119                air_id,
120                height,
121                width,
122                cells,
123                total_cells,
124                quotient_poly_cells: height * (pk.vk.quotient_degree as usize) * ext_degree,
125            }
126        })
127        .collect();
128    let total_cells = per_air.iter().map(|m| m.total_cells).sum();
129    let metrics = TraceMetrics {
130        per_air,
131        total_cells,
132        trace_height_inequalities,
133    };
134    info!(
135        "total_trace_cells = {} (excluding preprocessed)",
136        format_number_with_underscores(metrics.total_cells)
137    );
138    info!(
139        "preprocessed_trace_cells = {}",
140        format_number_with_underscores(
141            metrics
142                .per_air
143                .iter()
144                .map(|m| m.cells.preprocessed.unwrap_or(0))
145                .sum::<usize>()
146        )
147    );
148    info!(
149        "main_trace_cells = {}",
150        format_number_with_underscores(
151            metrics
152                .per_air
153                .iter()
154                .map(|m| m.cells.cached_mains.iter().sum::<usize>() + m.cells.common_main)
155                .sum::<usize>()
156        )
157    );
158    info!(
159        "perm_trace_cells = {}",
160        format_number_with_underscores(
161            metrics
162                .per_air
163                .iter()
164                .map(|m| m.cells.after_challenge.iter().sum::<usize>())
165                .sum::<usize>()
166        )
167    );
168    info!(
169        "quotient_poly_cells = {}",
170        format_number_with_underscores(
171            metrics
172                .per_air
173                .iter()
174                .map(|m| m.quotient_poly_cells)
175                .sum::<usize>()
176        )
177    );
178    debug!("{}", metrics);
179    metrics
180}
181
/// Renders `n` in decimal with `_` between groups of three digits, counting from the right.
/// For example, `1234567` becomes `"1_234_567"` and `999` stays `"999"`.
pub fn format_number_with_underscores(n: usize) -> String {
    let digits = n.to_string();
    let len = digits.len();
    let mut grouped = String::with_capacity(len + len / 3);
    for (pos, digit) in digits.chars().enumerate() {
        // A separator precedes every digit (other than the first) that begins a full group of
        // three when counting from the right-hand end.
        if pos != 0 && (len - pos) % 3 == 0 {
            grouped.push('_');
        }
        grouped.push(digit);
    }
    grouped
}
197
#[cfg(feature = "metrics")]
mod emit {
    use metrics::counter;

    use super::{SingleTraceMetrics, TraceMetrics};

    impl TraceMetrics {
        /// Records the metrics as absolute counters in three steps:
        /// 1. A `weighted_sum`/`threshold` pair for each trace height constraint.
        /// 2. Every per-AIR metric.
        /// 3. The aggregate `total_cells`.
        pub fn emit(&self) {
            let constraints = self.trace_height_inequalities.iter().enumerate();
            for (i, &(weighted_sum, threshold)) in constraints {
                let labels = [("trace_height_constraint", i.to_string())];
                counter!("weighted_sum", &labels).absolute(weighted_sum as u64);
                counter!("threshold", &labels).absolute(threshold as u64);
            }
            self.per_air.iter().for_each(SingleTraceMetrics::emit);
            counter!("total_cells").absolute(self.total_cells as u64);
        }
    }

    impl SingleTraceMetrics {
        /// Records this AIR's rows, total cells and column counts as absolute counters, labelled
        /// by AIR name and id.
        pub fn emit(&self) {
            let labels = [
                ("air_name", self.air_name.clone()),
                ("air_id", self.air_id.to_string()),
            ];
            let prep_cols = self.width.preprocessed.unwrap_or(0);
            let main_cols = self.width.cached_mains.iter().sum::<usize>() + self.width.common_main;
            let perm_cols = self.width.after_challenge.iter().sum::<usize>();
            counter!("rows", &labels).absolute(self.height as u64);
            counter!("cells", &labels).absolute(self.total_cells as u64);
            counter!("prep_cols", &labels).absolute(prep_cols as u64);
            counter!("main_cols", &labels).absolute(main_cols as u64);
            counter!("perm_cols", &labels).absolute(perm_cols as u64);
        }
    }
}