openvm_stark_backend/prover/
metrics.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
use std::fmt::Display;

use itertools::Itertools;
use p3_field::AbstractExtensionField;
use serde::{Deserialize, Serialize};

use crate::{
    config::{StarkGenericConfig, Val},
    keygen::types::{StarkProvingKey, TraceWidth},
};

/// Trace size metrics aggregated over a set of AIRs, as produced by [`trace_metrics`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TraceMetrics {
    /// One entry per AIR. When built by [`trace_metrics`], the entries are in the same order as
    /// the proving keys (and heights) passed in.
    pub per_air: Vec<SingleTraceMetrics>,
    /// Total base field cells from all traces, excludes preprocessed.
    ///
    /// When built by [`trace_metrics`], this is the sum of `total_cells` over `per_air`.
    pub total_cells: usize,
}

/// Trace dimensions and cell counts for a single AIR.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SingleTraceMetrics {
    /// Name of the AIR, copied from its proving key by [`trace_metrics`].
    pub air_name: String,
    /// Trace height (number of rows) used for this AIR.
    pub height: usize,
    /// The after challenge width is adjusted to be in terms of **base field** elements.
    pub width: TraceWidth,
    /// Per-trace cell counts; each entry is the corresponding `width` entry times `height`.
    pub cells: TraceCells,
    /// Omitting preprocessed trace, the total base field cells from main and after challenge
    /// traces.
    pub total_cells: usize,
}

/// Trace cells, counted in terms of number of **base field** elements.
///
/// When built by [`trace_metrics`], every count is the matching trace width multiplied by the
/// trace height.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TraceCells {
    /// Preprocessed trace cells; `None` when the width has no preprocessed component.
    pub preprocessed: Option<usize>,
    /// Cells of each cached main trace, one entry per cached main width.
    pub cached_mains: Vec<usize>,
    /// Cells of the common main trace.
    pub common_main: usize,
    /// Cells of each after-challenge trace. Each extension field column is counted as `D` base
    /// field columns, where `D` is the extension degree.
    pub after_challenge: Vec<usize>,
}

impl Display for TraceMetrics {
    /// Writes a header line with the aggregate cell count, then one line per AIR using the
    /// `Display` impl of [`SingleTraceMetrics`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = format_number_with_underscores(self.total_cells);
        writeln!(f, "Total Cells: {} (excluding preprocessed)", total)?;
        self.per_air
            .iter()
            .try_for_each(|air_metrics| writeln!(f, "{}", air_metrics))
    }
}

impl Display for SingleTraceMetrics {
    /// Writes a single, column-aligned summary row for this AIR (no trailing newline).
    ///
    /// Column widths are list-valued, so they are rendered with `Debug` into strings first;
    /// this lets the padding specifiers apply to the whole list rather than to each element.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rows = format_number_with_underscores(self.height);
        let cells = format_number_with_underscores(self.total_cells);
        let prep_cols = self.width.preprocessed.unwrap_or(0);
        let main_cols = format!("{:?}", self.width.main_widths());
        let perm_cols = format!("{:?}", self.width.after_challenge);
        write!(
            f,
            "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
            self.air_name, rows, cells, prep_cols, main_cols, perm_cols,
        )
    }
}

/// heights are the trace heights for each air
pub fn trace_metrics<SC: StarkGenericConfig>(
    pk: &[&StarkProvingKey<SC>],
    heights: &[usize],
) -> TraceMetrics {
    let per_air: Vec<_> = pk
        .iter()
        .zip_eq(heights)
        .map(|(pk, &height)| {
            let air_name = pk.air_name.clone();
            let mut width = pk.vk.params.width.clone();
            let ext_degree = <SC::Challenge as AbstractExtensionField<Val<SC>>>::D;
            for w in &mut width.after_challenge {
                *w *= ext_degree;
            }
            let cells = TraceCells {
                preprocessed: width.preprocessed.map(|w| w * height),
                cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
                common_main: width.common_main * height,
                after_challenge: width.after_challenge.iter().map(|w| w * height).collect(),
            };
            let total_cells = cells
                .cached_mains
                .iter()
                .chain([&cells.common_main])
                .chain(cells.after_challenge.iter())
                .sum::<usize>();
            SingleTraceMetrics {
                air_name,
                height,
                width,
                cells,
                total_cells,
            }
        })
        .collect();
    let total_cells = per_air.iter().map(|m| m.total_cells).sum();
    TraceMetrics {
        per_air,
        total_cells,
    }
}

/// Renders `n` in decimal, inserting `_` between groups of three digits counted from the least
/// significant digit (e.g. `1234567` becomes `"1_234_567"`, `999` stays `"999"`).
pub fn format_number_with_underscores(n: usize) -> String {
    let digits = n.to_string();
    let len = digits.len();
    let mut out = String::with_capacity(len + len / 3);
    for (idx, ch) in digits.char_indices() {
        // A separator precedes every digit that starts a new three-digit group from the right.
        if idx > 0 && (len - idx) % 3 == 0 {
            out.push('_');
        }
        out.push(ch);
    }
    out
}

#[cfg(feature = "bench-metrics")]
mod emit {
    use metrics::counter;

    use super::{SingleTraceMetrics, TraceMetrics};

    impl TraceMetrics {
        /// Emits the counters of every entry in `per_air` (in order), followed by the
        /// aggregate `total_cells` counter.
        pub fn emit(&self) {
            self.per_air.iter().for_each(SingleTraceMetrics::emit);
            counter!("total_cells").absolute(self.total_cells as u64);
        }
    }

    impl SingleTraceMetrics {
        /// Emits this AIR's row, cell and column-count counters, each labelled with
        /// `air_name`. Column counts are in base field elements, matching `self.width`.
        pub fn emit(&self) {
            let labels = [("air_name", self.air_name.clone())];
            let prep_cols = self.width.preprocessed.unwrap_or(0);
            let main_cols = self.width.common_main + self.width.cached_mains.iter().sum::<usize>();
            let perm_cols: usize = self.width.after_challenge.iter().sum();

            counter!("rows", &labels).absolute(self.height as u64);
            counter!("cells", &labels).absolute(self.total_cells as u64);
            counter!("prep_cols", &labels).absolute(prep_cols as u64);
            counter!("main_cols", &labels).absolute(main_cols as u64);
            counter!("perm_cols", &labels).absolute(perm_cols as u64);
        }
    }
}