// openvm_stark_backend/prover/metrics.rs

1use std::fmt::Display;
2
3use itertools::zip_eq;
4use serde::{Deserialize, Serialize};
5use tracing::{debug, info};
6
7use super::hal::ProverBackend;
8use crate::{keygen::types::TraceWidth, prover::types::DeviceMultiStarkProvingKeyView};
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct TraceMetrics {
12    pub per_air: Vec<SingleTraceMetrics>,
13    /// Total base field cells from all traces, excludes preprocessed.
14    pub total_cells: usize,
15    /// For each trace height constraint, the (weighted sum, threshold)
16    pub trace_height_inequalities: Vec<(usize, usize)>,
17}
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct SingleTraceMetrics {
21    pub air_name: String,
22    pub height: usize,
23    /// The after challenge width is adjusted to be in terms of **base field** elements.
24    pub width: TraceWidth,
25    pub cells: TraceCells,
26    /// Omitting preprocessed trace, the total base field cells from main and after challenge
27    /// traces.
28    pub total_cells: usize,
29    /// Base field cells for evaluation of quotient polynomial on the quotient domain
30    pub quotient_poly_cells: usize,
31}
32
33/// Trace cells, counted in terms of number of **base field** elements.
34#[derive(Clone, Debug, Serialize, Deserialize)]
35pub struct TraceCells {
36    pub preprocessed: Option<usize>,
37    pub cached_mains: Vec<usize>,
38    pub common_main: usize,
39    pub after_challenge: Vec<usize>,
40}
41
42impl Display for TraceMetrics {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() {
45            writeln!(
46                f,
47                "trace_height_constraint_{i} | weighted_sum = {:<10} | threshold = {:<10}",
48                format_number_with_underscores(*weighted_sum),
49                format_number_with_underscores(*threshold)
50            )?;
51        }
52        for trace_metrics in &self.per_air {
53            writeln!(f, "{}", trace_metrics)?;
54        }
55        Ok(())
56    }
57}
58
59impl Display for SingleTraceMetrics {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(
62            f,
63            "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
64            self.air_name, format_number_with_underscores(self.height), format_number_with_underscores(self.total_cells), self.width.preprocessed.unwrap_or(0),
65            format!("{:?}", self.width.main_widths()),
66            format!("{:?}",self.width.after_challenge),
67        )?;
68        Ok(())
69    }
70}
71
72/// heights are the trace heights for each air
73pub fn trace_metrics<PB: ProverBackend>(
74    mpk: &DeviceMultiStarkProvingKeyView<PB>,
75    log_trace_heights: &[u8],
76) -> TraceMetrics {
77    let heights = log_trace_heights
78        .iter()
79        .map(|&h| 1usize << h)
80        .collect::<Vec<_>>();
81    let trace_height_inequalities = mpk
82        .trace_height_constraints
83        .iter()
84        .map(|trace_height_constraint| {
85            let weighted_sum = heights
86                .iter()
87                .enumerate()
88                .map(|(j, h)| {
89                    let air_id = mpk.air_ids[j];
90                    (trace_height_constraint.coefficients[air_id] as usize) * h
91                })
92                .sum::<usize>();
93            (weighted_sum, trace_height_constraint.threshold as usize)
94        })
95        .collect::<Vec<_>>();
96    let per_air: Vec<_> = zip_eq(&mpk.per_air, heights)
97        .map(|(pk, height)| {
98            let air_name = &pk.air_name;
99            let mut width = pk.vk.params.width.clone();
100            let ext_degree = PB::CHALLENGE_EXT_DEGREE as usize;
101            for w in &mut width.after_challenge {
102                *w *= ext_degree;
103            }
104            let cells = TraceCells {
105                preprocessed: width.preprocessed.map(|w| w * height),
106                cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
107                common_main: width.common_main * height,
108                after_challenge: width.after_challenge.iter().map(|w| w * height).collect(),
109            };
110            let total_cells = cells
111                .cached_mains
112                .iter()
113                .chain([&cells.common_main])
114                .chain(cells.after_challenge.iter())
115                .sum::<usize>();
116            SingleTraceMetrics {
117                air_name: air_name.to_string(),
118                height,
119                width,
120                cells,
121                total_cells,
122                quotient_poly_cells: height * (pk.vk.quotient_degree as usize) * ext_degree,
123            }
124        })
125        .collect();
126    let total_cells = per_air.iter().map(|m| m.total_cells).sum();
127    let metrics = TraceMetrics {
128        per_air,
129        total_cells,
130        trace_height_inequalities,
131    };
132    info!(
133        "total_trace_cells = {} (excluding preprocessed)",
134        format_number_with_underscores(metrics.total_cells)
135    );
136    info!(
137        "preprocessed_trace_cells = {}",
138        format_number_with_underscores(
139            metrics
140                .per_air
141                .iter()
142                .map(|m| m.cells.preprocessed.unwrap_or(0))
143                .sum::<usize>()
144        )
145    );
146    info!(
147        "main_trace_cells = {}",
148        format_number_with_underscores(
149            metrics
150                .per_air
151                .iter()
152                .map(|m| m.cells.cached_mains.iter().sum::<usize>() + m.cells.common_main)
153                .sum::<usize>()
154        )
155    );
156    info!(
157        "perm_trace_cells = {}",
158        format_number_with_underscores(
159            metrics
160                .per_air
161                .iter()
162                .map(|m| m.cells.after_challenge.iter().sum::<usize>())
163                .sum::<usize>()
164        )
165    );
166    info!(
167        "quotient_poly_cells = {}",
168        format_number_with_underscores(
169            metrics
170                .per_air
171                .iter()
172                .map(|m| m.quotient_poly_cells)
173                .sum::<usize>()
174        )
175    );
176    debug!("{}", metrics);
177    metrics
178}
179
/// Renders `n` in decimal with an underscore separating each group of three digits, counted
/// from the least significant digit (e.g. `1234567` becomes `"1_234_567"`).
pub fn format_number_with_underscores(n: usize) -> String {
    let digits = n.to_string();
    let len = digits.len();
    let mut grouped = String::with_capacity(len + len / 3);
    for (idx, ch) in digits.chars().enumerate() {
        // A separator goes before this digit when the digits remaining (including it) form
        // a whole number of groups of three, except at the very start.
        if idx > 0 && (len - idx) % 3 == 0 {
            grouped.push('_');
        }
        grouped.push(ch);
    }
    grouped
}
195
#[cfg(feature = "metrics")]
mod emit {
    use metrics::counter;

    use super::{SingleTraceMetrics, TraceMetrics};

    impl TraceMetrics {
        /// Publishes these metrics as absolute `metrics` counters:
        /// a `weighted_sum` / `threshold` pair per trace height constraint (labelled by the
        /// constraint index), the per-AIR counters, and an unlabelled `total_cells`.
        pub fn emit(&self) {
            let inequalities = self.trace_height_inequalities.iter().enumerate();
            for (idx, &(weighted_sum, threshold)) in inequalities {
                let labels = [("trace_height_constraint", idx.to_string())];
                counter!("weighted_sum", &labels).absolute(weighted_sum as u64);
                counter!("threshold", &labels).absolute(threshold as u64);
            }
            self.per_air.iter().for_each(SingleTraceMetrics::emit);
            counter!("total_cells").absolute(self.total_cells as u64);
        }
    }

    impl SingleTraceMetrics {
        /// Publishes this AIR's row count, total cells and column widths as absolute
        /// `metrics` counters labelled with `air_name`.
        pub fn emit(&self) {
            let labels = [("air_name", self.air_name.clone())];
            let width = &self.width;
            let prep_cols = width.preprocessed.unwrap_or(0);
            let main_cols = width.cached_mains.iter().sum::<usize>() + width.common_main;
            let perm_cols = width.after_challenge.iter().sum::<usize>();

            counter!("rows", &labels).absolute(self.height as u64);
            counter!("cells", &labels).absolute(self.total_cells as u64);
            counter!("prep_cols", &labels).absolute(prep_cols as u64);
            counter!("main_cols", &labels).absolute(main_cols as u64);
            counter!("perm_cols", &labels).absolute(perm_cols as u64);
        }
    }
}