openvm_stark_backend/air_builders/symbolic/
statistics.rs

1use std::{fs::File, io::Write, path::Path};
2
3use serde::{Deserialize, Serialize};
4
5use crate::air_builders::symbolic::{
6    SymbolicConstraintsDag, SymbolicExpressionDag, SymbolicExpressionNode,
7};
8
9#[derive(Debug, Default, Clone, Serialize, Deserialize)]
10pub struct AirStatistics {
11    pub air_name: String,
12    pub num_nodes: usize,
13    pub num_interactions: usize,
14    pub num_constraints: usize,
15    pub max_constraint_depth: usize,
16    pub average_constraint_depth: f64,
17    pub num_constants: usize,
18    pub num_variables: usize,
19    pub num_intermediates: usize,
20    pub max_intermediate_use: usize,
21    pub average_intermediate_use: f64,
22}
23
24#[derive(Debug, Default, Clone, Serialize, Deserialize)]
25pub struct NodeInfo {
26    pub depth: usize,
27    pub uses: usize,
28}
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
31pub struct AirStatisticsGenerator {
32    pub stats: Vec<AirStatistics>,
33}
34
35impl AirStatisticsGenerator {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    pub fn generate<F>(&mut self, name: String, dag: &SymbolicConstraintsDag<F>) {
41        let mut stats = AirStatistics {
42            air_name: name,
43            num_nodes: dag.constraints.nodes.len(),
44            num_interactions: dag.interactions.len(),
45            num_constraints: dag.constraints.constraint_idx.len(),
46            ..Default::default()
47        };
48        let mut node_info = vec![NodeInfo::default(); stats.num_nodes];
49        for (i, node) in dag.constraints.nodes.iter().enumerate() {
50            match node {
51                SymbolicExpressionNode::Variable(_) => {
52                    node_info[i].uses += 1;
53                    node_info[i].depth = 1;
54                    stats.num_variables += 1;
55                }
56                SymbolicExpressionNode::Constant(_) => {
57                    node_info[i].uses += 1;
58                    node_info[i].depth = 1;
59                    stats.num_constants += 1;
60                }
61                SymbolicExpressionNode::Add {
62                    left_idx,
63                    right_idx,
64                    ..
65                } => {
66                    node_info[i].uses += 1;
67                    node_info[*left_idx].uses += 1;
68                    node_info[*right_idx].uses += 1;
69                    node_info[i].depth =
70                        node_info[*left_idx].depth.max(node_info[*right_idx].depth) + 1;
71                    stats.num_intermediates += 1;
72                }
73                SymbolicExpressionNode::Sub {
74                    left_idx,
75                    right_idx,
76                    ..
77                } => {
78                    node_info[i].uses += 1;
79                    node_info[*left_idx].uses += 1;
80                    node_info[*right_idx].uses += 1;
81                    node_info[i].depth =
82                        node_info[*left_idx].depth.max(node_info[*right_idx].depth) + 1;
83                    stats.num_intermediates += 1;
84                }
85                SymbolicExpressionNode::Mul {
86                    left_idx,
87                    right_idx,
88                    ..
89                } => {
90                    node_info[i].uses += 1;
91                    node_info[*left_idx].uses += 1;
92                    node_info[*right_idx].uses += 1;
93                    node_info[i].depth =
94                        node_info[*left_idx].depth.max(node_info[*right_idx].depth) + 1;
95                    stats.num_intermediates += 1;
96                }
97                SymbolicExpressionNode::Neg { idx, .. } => {
98                    node_info[i].uses += 1;
99                    node_info[*idx].uses += 1;
100                    node_info[i].depth = node_info[*idx].depth + 1;
101                    stats.num_intermediates += 1;
102                }
103                _ => {
104                    node_info[i].uses += 1;
105                    node_info[i].depth = 1;
106                }
107            }
108        }
109
110        for node in node_info.iter().filter(|node| node.depth > 1) {
111            stats.max_intermediate_use = stats.max_intermediate_use.max(node.uses);
112            stats.average_intermediate_use += node.uses as f64;
113        }
114        stats.average_intermediate_use /= stats.num_intermediates as f64;
115
116        for constraint_idx in &dag.constraints.constraint_idx {
117            stats.max_constraint_depth = stats
118                .max_constraint_depth
119                .max(node_info[*constraint_idx].depth);
120            stats.average_constraint_depth += node_info[*constraint_idx].depth as f64;
121        }
122        stats.average_constraint_depth /= stats.num_constraints as f64;
123
124        self.stats.push(stats);
125    }
126
127    pub fn write_json<P: AsRef<Path>>(&self, file_path: P) -> eyre::Result<()> {
128        serde_json::to_writer_pretty(File::create(file_path)?, &self.stats)?;
129        Ok(())
130    }
131
132    pub fn write_csv<P: AsRef<Path>>(&self, file_path: P) -> eyre::Result<()> {
133        let mut file = File::create(file_path)?;
134        writeln!(
135            file,
136            "air_name,num_nodes,num_interactions,num_constraints,max_constraint_depth,average_constraint_depth,num_constants,num_variables,num_intermediates,max_intermediate_use,average_intermediate_use"
137        )?;
138        for stat in &self.stats {
139            writeln!(
140                file,
141                "{},{},{},{},{},{},{},{},{},{},{}",
142                stat.air_name,
143                stat.num_nodes,
144                stat.num_interactions,
145                stat.num_constraints,
146                stat.max_constraint_depth,
147                stat.average_constraint_depth,
148                stat.num_constants,
149                stat.num_variables,
150                stat.num_intermediates,
151                stat.max_intermediate_use,
152                stat.average_intermediate_use
153            )?;
154        }
155        Ok(())
156    }
157
158    pub fn print_dag<F: std::fmt::Debug>(dag: &SymbolicExpressionDag<F>) {
159        for (idx, node) in dag.nodes.iter().enumerate() {
160            println!("  Node {}: {:?}", idx, node);
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;

    use crate::{
        air_builders::symbolic::{
            build_symbolic_constraints_dag,
            statistics::AirStatisticsGenerator,
            symbolic_expression::SymbolicExpression,
            symbolic_variable::{Entry, SymbolicVariable},
        },
        interaction::Interaction,
    };

    type F = BabyBear;

    /// Builds a small expression tree with two constraints and one
    /// interaction, then checks every field of the generated statistics.
    #[test]
    fn test_dag_statistics() {
        // Two main-trace variables act as the leaves of the tree.
        let first_var = SymbolicExpression::Variable(SymbolicVariable::new(
            Entry::Main {
                part_index: 0,
                offset: 0,
            },
            1,
        ));
        let second_var = SymbolicExpression::Variable(SymbolicVariable::new(
            Entry::Main {
                part_index: 1,
                offset: 1,
            },
            2,
        ));

        // (first + second) * 2, and IsFirstRow + that product.
        let sum = first_var.clone() + second_var.clone();
        let doubled = sum.clone() * SymbolicExpression::Constant(F::TWO);
        let with_selector = SymbolicExpression::IsFirstRow + doubled.clone();

        let constraints = vec![with_selector, doubled];
        let interactions = vec![Interaction {
            bus_index: 0,
            message: vec![first_var, second_var],
            count: SymbolicExpression::Constant(F::ONE),
            count_weight: 1,
        }];

        let dag = build_symbolic_constraints_dag(&constraints, &interactions);
        AirStatisticsGenerator::print_dag(&dag.constraints);

        let mut generator = AirStatisticsGenerator::new();
        generator.generate("test".to_string(), &dag);
        println!("{:?}", generator.stats);

        let stats = &generator.stats[0];
        assert_eq!(stats.num_nodes, 8);
        assert_eq!(stats.num_interactions, 1);
        assert_eq!(stats.num_constraints, 2);
        assert_eq!(stats.max_constraint_depth, 4);
        assert_eq!(stats.average_constraint_depth, 3.5);
        assert_eq!(stats.num_constants, 2);
        assert_eq!(stats.num_variables, 2);
        assert_eq!(stats.num_intermediates, 3);
        assert_eq!(stats.max_intermediate_use, 2);
        assert_eq!(stats.average_intermediate_use, 5f64 / 3f64);
    }
}