openvm_stark_backend/air_builders/symbolic/
statistics.rs1use std::{fs::File, io::Write, path::Path};
2
3use serde::{Deserialize, Serialize};
4
5use crate::air_builders::symbolic::{
6 SymbolicConstraintsDag, SymbolicExpressionDag, SymbolicExpressionNode,
7};
8
9#[derive(Debug, Default, Clone, Serialize, Deserialize)]
10pub struct AirStatistics {
11 pub air_name: String,
12 pub num_nodes: usize,
13 pub num_interactions: usize,
14 pub num_constraints: usize,
15 pub max_constraint_depth: usize,
16 pub average_constraint_depth: f64,
17 pub num_constants: usize,
18 pub num_variables: usize,
19 pub num_intermediates: usize,
20 pub max_intermediate_use: usize,
21 pub average_intermediate_use: f64,
22}
23
24#[derive(Debug, Default, Clone, Serialize, Deserialize)]
25pub struct NodeInfo {
26 pub depth: usize,
27 pub uses: usize,
28}
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
31pub struct AirStatisticsGenerator {
32 pub stats: Vec<AirStatistics>,
33}
34
35impl AirStatisticsGenerator {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn generate<F>(&mut self, name: String, dag: &SymbolicConstraintsDag<F>) {
41 let mut stats = AirStatistics {
42 air_name: name,
43 num_nodes: dag.constraints.nodes.len(),
44 num_interactions: dag.interactions.len(),
45 num_constraints: dag.constraints.constraint_idx.len(),
46 ..Default::default()
47 };
48 let mut node_info = vec![NodeInfo::default(); stats.num_nodes];
49 for (i, node) in dag.constraints.nodes.iter().enumerate() {
50 match node {
51 SymbolicExpressionNode::Variable(_) => {
52 node_info[i].uses += 1;
53 node_info[i].depth = 1;
54 stats.num_variables += 1;
55 }
56 SymbolicExpressionNode::Constant(_) => {
57 node_info[i].uses += 1;
58 node_info[i].depth = 1;
59 stats.num_constants += 1;
60 }
61 SymbolicExpressionNode::Add {
62 left_idx,
63 right_idx,
64 ..
65 } => {
66 node_info[i].uses += 1;
67 node_info[*left_idx].uses += 1;
68 node_info[*right_idx].uses += 1;
69 node_info[i].depth =
70 node_info[*left_idx].depth.max(node_info[*right_idx].depth) + 1;
71 stats.num_intermediates += 1;
72 }
73 SymbolicExpressionNode::Sub {
74 left_idx,
75 right_idx,
76 ..
77 } => {
78 node_info[i].uses += 1;
79 node_info[*left_idx].uses += 1;
80 node_info[*right_idx].uses += 1;
81 node_info[i].depth =
82 node_info[*left_idx].depth.max(node_info[*right_idx].depth) + 1;
83 stats.num_intermediates += 1;
84 }
85 SymbolicExpressionNode::Mul {
86 left_idx,
87 right_idx,
88 ..
89 } => {
90 node_info[i].uses += 1;
91 node_info[*left_idx].uses += 1;
92 node_info[*right_idx].uses += 1;
93 node_info[i].depth =
94 node_info[*left_idx].depth.max(node_info[*right_idx].depth) + 1;
95 stats.num_intermediates += 1;
96 }
97 SymbolicExpressionNode::Neg { idx, .. } => {
98 node_info[i].uses += 1;
99 node_info[*idx].uses += 1;
100 node_info[i].depth = node_info[*idx].depth + 1;
101 stats.num_intermediates += 1;
102 }
103 _ => {
104 node_info[i].uses += 1;
105 node_info[i].depth = 1;
106 }
107 }
108 }
109
110 for node in node_info.iter().filter(|node| node.depth > 1) {
111 stats.max_intermediate_use = stats.max_intermediate_use.max(node.uses);
112 stats.average_intermediate_use += node.uses as f64;
113 }
114 stats.average_intermediate_use /= stats.num_intermediates as f64;
115
116 for constraint_idx in &dag.constraints.constraint_idx {
117 stats.max_constraint_depth = stats
118 .max_constraint_depth
119 .max(node_info[*constraint_idx].depth);
120 stats.average_constraint_depth += node_info[*constraint_idx].depth as f64;
121 }
122 stats.average_constraint_depth /= stats.num_constraints as f64;
123
124 self.stats.push(stats);
125 }
126
127 pub fn write_json<P: AsRef<Path>>(&self, file_path: P) -> eyre::Result<()> {
128 serde_json::to_writer_pretty(File::create(file_path)?, &self.stats)?;
129 Ok(())
130 }
131
132 pub fn write_csv<P: AsRef<Path>>(&self, file_path: P) -> eyre::Result<()> {
133 let mut file = File::create(file_path)?;
134 writeln!(
135 file,
136 "air_name,num_nodes,num_interactions,num_constraints,max_constraint_depth,average_constraint_depth,num_constants,num_variables,num_intermediates,max_intermediate_use,average_intermediate_use"
137 )?;
138 for stat in &self.stats {
139 writeln!(
140 file,
141 "{},{},{},{},{},{},{},{},{},{},{}",
142 stat.air_name,
143 stat.num_nodes,
144 stat.num_interactions,
145 stat.num_constraints,
146 stat.max_constraint_depth,
147 stat.average_constraint_depth,
148 stat.num_constants,
149 stat.num_variables,
150 stat.num_intermediates,
151 stat.max_intermediate_use,
152 stat.average_intermediate_use
153 )?;
154 }
155 Ok(())
156 }
157
158 pub fn print_dag<F: std::fmt::Debug>(dag: &SymbolicExpressionDag<F>) {
159 for (idx, node) in dag.nodes.iter().enumerate() {
160 println!(" Node {}: {:?}", idx, node);
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;

    use crate::{
        air_builders::symbolic::{
            build_symbolic_constraints_dag,
            statistics::AirStatisticsGenerator,
            symbolic_expression::SymbolicExpression,
            symbolic_variable::{Entry, SymbolicVariable},
        },
        interaction::Interaction,
    };

    type F = BabyBear;

    /// Checks the statistics of a small hand-built DAG.
    ///
    /// Constraints: `IsFirstRow + (x + y) * 2` and `(x + y) * 2`.
    /// Interaction: one interaction with message `[x, y]` and count `1`.
    #[test]
    fn test_dag_statistics() {
        // Builds a `Variable` expression for an `Entry::Main` entry.
        let main_var = |part_index, offset, index| {
            SymbolicExpression::<F>::Variable(SymbolicVariable::new(
                Entry::Main { part_index, offset },
                index,
            ))
        };

        let x = main_var(0, 0, 1);
        let y = main_var(1, 1, 2);
        let sum = x.clone() + y.clone();
        let doubled = sum * SymbolicExpression::Constant(F::TWO);
        let with_selector = SymbolicExpression::IsFirstRow + doubled.clone();

        let constraints = vec![with_selector, doubled];
        let interactions = vec![Interaction {
            bus_index: 0,
            message: vec![x, y],
            count: SymbolicExpression::Constant(F::ONE),
            count_weight: 1,
        }];

        let dag = build_symbolic_constraints_dag(&constraints, &interactions);
        AirStatisticsGenerator::print_dag(&dag.constraints);

        let mut generator = AirStatisticsGenerator::new();
        generator.generate("test".to_string(), &dag);
        println!("{:?}", generator.stats);

        let stats = &generator.stats[0];
        assert_eq!(stats.num_nodes, 8);
        assert_eq!(stats.num_interactions, 1);
        assert_eq!(stats.num_constraints, 2);
        // Root depths: the selector sum is 4 deep and the product 3 deep.
        assert_eq!(stats.max_constraint_depth, 4);
        assert_eq!(stats.average_constraint_depth, 3.5);
        assert_eq!(stats.num_constants, 2);
        assert_eq!(stats.num_variables, 2);
        assert_eq!(stats.num_intermediates, 3);
        // Use counts of the three intermediates: 2 (x + y), 2 (product), 1 (selector sum).
        assert_eq!(stats.max_intermediate_use, 2);
        assert_eq!(stats.average_intermediate_use, 5f64 / 3f64);
    }
}