// openvm_stark_backend/prover/metrics.rs

use std::fmt::Display;

use itertools::zip_eq;
use serde::{Deserialize, Serialize};
use tracing::info;

use super::{hal::ProverBackend, types::DeviceMultiStarkProvingKey};
use crate::keygen::types::TraceWidth;

10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct TraceMetrics {
12    pub per_air: Vec<SingleTraceMetrics>,
13    /// Total base field cells from all traces, excludes preprocessed.
14    pub total_cells: usize,
15    /// For each trace height constraint, the (weighted sum, threshold)
16    pub trace_height_inequalities: Vec<(usize, usize)>,
17}
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct SingleTraceMetrics {
21    pub air_name: String,
22    pub height: usize,
23    /// The after challenge width is adjusted to be in terms of **base field** elements.
24    pub width: TraceWidth,
25    pub cells: TraceCells,
26    /// Omitting preprocessed trace, the total base field cells from main and after challenge
27    /// traces.
28    pub total_cells: usize,
29}
30
31/// Trace cells, counted in terms of number of **base field** elements.
32#[derive(Clone, Debug, Serialize, Deserialize)]
33pub struct TraceCells {
34    pub preprocessed: Option<usize>,
35    pub cached_mains: Vec<usize>,
36    pub common_main: usize,
37    pub after_challenge: Vec<usize>,
38}
39
40impl Display for TraceMetrics {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        writeln!(
43            f,
44            "total_trace_cells = {} (excluding preprocessed)",
45            format_number_with_underscores(self.total_cells)
46        )?;
47        writeln!(
48            f,
49            "preprocessed_trace_cells = {}",
50            format_number_with_underscores(
51                self.per_air
52                    .iter()
53                    .map(|m| m.cells.preprocessed.unwrap_or(0))
54                    .sum::<usize>()
55            )
56        )?;
57        writeln!(
58            f,
59            "main_trace_cells = {}",
60            format_number_with_underscores(
61                self.per_air
62                    .iter()
63                    .map(|m| m.cells.cached_mains.iter().sum::<usize>() + m.cells.common_main)
64                    .sum::<usize>()
65            )
66        )?;
67        writeln!(
68            f,
69            "perm_trace_cells = {}",
70            format_number_with_underscores(
71                self.per_air
72                    .iter()
73                    .map(|m| m.cells.after_challenge.iter().sum::<usize>())
74                    .sum::<usize>()
75            )
76        )?;
77        for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() {
78            writeln!(
79                f,
80                "trace_height_constraint_{i} | weighted_sum = {:<10} | threshold = {:<10}",
81                format_number_with_underscores(*weighted_sum),
82                format_number_with_underscores(*threshold)
83            )?;
84        }
85        for trace_metrics in &self.per_air {
86            writeln!(f, "{}", trace_metrics)?;
87        }
88        Ok(())
89    }
90}
91
92impl Display for SingleTraceMetrics {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        write!(
95            f,
96            "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
97            self.air_name, format_number_with_underscores(self.height), format_number_with_underscores(self.total_cells), self.width.preprocessed.unwrap_or(0),
98            format!("{:?}", self.width.main_widths()),
99            format!("{:?}",self.width.after_challenge),
100        )?;
101        Ok(())
102    }
103}
104
105/// heights are the trace heights for each air
106pub fn trace_metrics<PB: ProverBackend>(
107    mpk: &DeviceMultiStarkProvingKey<PB>,
108    log_trace_heights: &[u8],
109) -> TraceMetrics {
110    let heights = log_trace_heights
111        .iter()
112        .map(|&h| 1usize << h)
113        .collect::<Vec<_>>();
114    let trace_height_inequalities = mpk
115        .trace_height_constraints
116        .iter()
117        .map(|trace_height_constraint| {
118            let weighted_sum = heights
119                .iter()
120                .enumerate()
121                .map(|(j, h)| {
122                    let air_id = mpk.air_ids[j];
123                    (trace_height_constraint.coefficients[air_id] as usize) * h
124                })
125                .sum::<usize>();
126            (weighted_sum, trace_height_constraint.threshold as usize)
127        })
128        .collect::<Vec<_>>();
129    let per_air: Vec<_> = zip_eq(&mpk.per_air, heights)
130        .map(|(pk, height)| {
131            let air_name = pk.air_name;
132            let mut width = pk.vk.params.width.clone();
133            let ext_degree = PB::CHALLENGE_EXT_DEGREE as usize;
134            for w in &mut width.after_challenge {
135                *w *= ext_degree;
136            }
137            let cells = TraceCells {
138                preprocessed: width.preprocessed.map(|w| w * height),
139                cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
140                common_main: width.common_main * height,
141                after_challenge: width.after_challenge.iter().map(|w| w * height).collect(),
142            };
143            let total_cells = cells
144                .cached_mains
145                .iter()
146                .chain([&cells.common_main])
147                .chain(cells.after_challenge.iter())
148                .sum::<usize>();
149            SingleTraceMetrics {
150                air_name: air_name.to_string(),
151                height,
152                width,
153                cells,
154                total_cells,
155            }
156        })
157        .collect();
158    let total_cells = per_air.iter().map(|m| m.total_cells).sum();
159    let metrics = TraceMetrics {
160        per_air,
161        total_cells,
162        trace_height_inequalities,
163    };
164    info!("{}", metrics);
165    metrics
166}
167
/// Formats `n` in decimal with an underscore between every group of three digits, counting
/// from the least significant digit (e.g. `1234567` becomes `"1_234_567"`).
pub fn format_number_with_underscores(n: usize) -> String {
    let digits = n.to_string();
    let len = digits.len();
    let mut out = String::with_capacity(len + len / 3);
    for (idx, ch) in digits.char_indices() {
        // A separator precedes this digit when the count of digits from here to the end is a
        // positive multiple of three (and this is not the leading digit).
        if idx > 0 && (len - idx) % 3 == 0 {
            out.push('_');
        }
        out.push(ch);
    }
    out
}

#[cfg(feature = "bench-metrics")]
mod emit {
    use metrics::counter;

    use super::{SingleTraceMetrics, TraceMetrics};

    impl TraceMetrics {
        /// Publishes these metrics as `metrics` counters: a `weighted_sum`/`threshold` pair per
        /// trace height constraint (labelled by constraint index), then every AIR's counters,
        /// then `total_cells`.
        pub fn emit(&self) {
            for (i, &(weighted_sum, threshold)) in
                self.trace_height_inequalities.iter().enumerate()
            {
                let labels = [("trace_height_constraint", i.to_string())];
                counter!("weighted_sum", &labels).absolute(weighted_sum as u64);
                counter!("threshold", &labels).absolute(threshold as u64);
            }
            self.per_air.iter().for_each(SingleTraceMetrics::emit);
            counter!("total_cells").absolute(self.total_cells as u64);
        }
    }

    impl SingleTraceMetrics {
        /// Publishes this AIR's row, cell and column counts as `metrics` counters labelled by
        /// `air_name`.
        pub fn emit(&self) {
            let width = &self.width;
            // Column totals: all main traces (cached + common) and all after-challenge traces.
            let main_cols = width.cached_mains.iter().sum::<usize>() + width.common_main;
            let perm_cols = width.after_challenge.iter().sum::<usize>();

            let labels = [("air_name", self.air_name.clone())];
            counter!("rows", &labels).absolute(self.height as u64);
            counter!("cells", &labels).absolute(self.total_cells as u64);
            counter!("prep_cols", &labels).absolute(width.preprocessed.unwrap_or(0) as u64);
            counter!("main_cols", &labels).absolute(main_cols as u64);
            counter!("perm_cols", &labels).absolute(perm_cols as u64);
        }
    }
}