openvm_stark_backend/air_builders/symbolic/dag.rs

1use std::sync::Arc;
2
3use p3_field::Field;
4use rustc_hash::FxHashMap;
5use serde::{Deserialize, Serialize};
6
7use super::SymbolicConstraints;
8use crate::{
9    air_builders::symbolic::{
10        symbolic_expression::SymbolicExpression, symbolic_variable::SymbolicVariable,
11    },
12    interaction::{Interaction, SymbolicInteraction},
13};
14
/// A node in a symbolic expression DAG.
///
/// Mirrors [`SymbolicExpression`], with each `Arc` child replaced by the index of the
/// child's node in [`SymbolicExpressionDag::nodes`]. Referring to children by index
/// rather than by pointer is what makes this type serializable and deserializable.
///
/// NOTE(review): with `#[repr(C)]` and serde's default enum encoding, the variant order
/// is presumably part of the stored format — avoid reordering variants.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
#[repr(C)]
pub enum SymbolicExpressionNode<F> {
    /// Leaf copied from [`SymbolicExpression::Variable`].
    Variable(SymbolicVariable<F>),
    /// Leaf copied from [`SymbolicExpression::IsFirstRow`].
    IsFirstRow,
    /// Leaf copied from [`SymbolicExpression::IsLastRow`].
    IsLastRow,
    /// Leaf copied from [`SymbolicExpression::IsTransition`].
    IsTransition,
    /// Leaf copied from [`SymbolicExpression::Constant`].
    Constant(F),
    /// Mirrors `SymbolicExpression::Add { x, y, .. }`: `left_idx` and `right_idx` are
    /// the node indices of `x` and `y`.
    Add {
        left_idx: usize,
        right_idx: usize,
        /// Copied unchanged from the source expression.
        degree_multiple: usize,
    },
    /// Mirrors `SymbolicExpression::Sub { x, y, .. }`: `left_idx` and `right_idx` are
    /// the node indices of `x` and `y`.
    Sub {
        left_idx: usize,
        right_idx: usize,
        /// Copied unchanged from the source expression.
        degree_multiple: usize,
    },
    /// Mirrors `SymbolicExpression::Neg { x, .. }`: `idx` is the node index of `x`.
    Neg {
        idx: usize,
        /// Copied unchanged from the source expression.
        degree_multiple: usize,
    },
    /// Mirrors `SymbolicExpression::Mul { x, y, .. }`: `left_idx` and `right_idx` are
    /// the node indices of `x` and `y`. For a square whose operands share one `Arc`,
    /// `left_idx == right_idx`.
    Mul {
        left_idx: usize,
        right_idx: usize,
        /// Copied unchanged from the source expression.
        degree_multiple: usize,
    },
}
47
/// A collection of symbolic expressions stored as a DAG of [`SymbolicExpressionNode`]s.
///
/// A subexpression that is reached more than once may be stored as a single node and
/// referenced by index from every parent.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
#[repr(C)]
pub struct SymbolicExpressionDag<F> {
    /// Nodes in **topological** order: every child index stored in a node is expected
    /// to be smaller than the index of that node itself.
    pub nodes: Vec<SymbolicExpressionNode<F>>,
    /// Node indices of the expressions that are asserted to equal zero.
    /// `build_symbolic_constraints_dag` produces these in ascending order.
    pub constraint_idx: Vec<usize>,
}
57
58impl<F> SymbolicExpressionDag<F> {
59    pub fn max_rotation(&self) -> usize {
60        let mut rotation = 0;
61        for node in &self.nodes {
62            if let SymbolicExpressionNode::Variable(var) = node {
63                rotation = rotation.max(var.entry.offset().unwrap_or(0));
64            }
65        }
66        rotation
67    }
68
69    pub fn num_constraints(&self) -> usize {
70        self.constraint_idx.len()
71    }
72}
73
/// Symbolic constraints and interactions whose expressions are all stored as nodes of
/// a single shared [`SymbolicExpressionDag`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
#[repr(C)]
pub struct SymbolicConstraintsDag<F> {
    /// DAG with all symbolic expressions as nodes.
    /// A subset of the nodes represents all constraints that will be
    /// included in the quotient polynomial via DEEP-ALI.
    pub constraints: SymbolicExpressionDag<F>,
    /// List of all interactions, where expressions in the interactions
    /// are referenced by node index (`usize`) into `constraints.nodes`.
    ///
    /// This is used by the prover for after-challenge trace generation,
    /// and some partial information may be used by the verifier.
    ///
    /// **However**, any contributions to the quotient polynomial from
    /// logup are already included in `constraints` and do not need to
    /// be separately calculated from `interactions`.
    pub interactions: Vec<Interaction<usize>>,
}
93
94pub(crate) fn build_symbolic_constraints_dag<F: Field>(
95    constraints: &[SymbolicExpression<F>],
96    interactions: &[SymbolicInteraction<F>],
97) -> SymbolicConstraintsDag<F> {
98    let mut expr_to_idx = FxHashMap::default();
99    let mut nodes = Vec::new();
100    let mut constraint_idx: Vec<usize> = constraints
101        .iter()
102        .map(|expr| topological_sort_symbolic_expr(expr, &mut expr_to_idx, &mut nodes))
103        .collect();
104    constraint_idx.sort();
105    let interactions: Vec<Interaction<usize>> = interactions
106        .iter()
107        .map(|interaction| {
108            let fields: Vec<usize> = interaction
109                .message
110                .iter()
111                .map(|field_expr| {
112                    topological_sort_symbolic_expr(field_expr, &mut expr_to_idx, &mut nodes)
113                })
114                .collect();
115            let count =
116                topological_sort_symbolic_expr(&interaction.count, &mut expr_to_idx, &mut nodes);
117            Interaction {
118                message: fields,
119                count,
120                bus_index: interaction.bus_index,
121                count_weight: interaction.count_weight,
122            }
123        })
124        .collect();
125    // Note[jpw]: there could be few nodes created after `constraint_idx` is built
126    // from `interactions` even though constraints already contain all interactions.
127    // This should be marginal and is not optimized for now.
128    let constraints = SymbolicExpressionDag {
129        nodes,
130        constraint_idx,
131    };
132    SymbolicConstraintsDag {
133        constraints,
134        interactions,
135    }
136}
137
138/// `expr_to_idx` is a cache so that the `Arc<_>` references within symbolic expressions get
139/// mapped to the same node ID if their underlying references are the same.
140fn topological_sort_symbolic_expr<'a, F: Field>(
141    expr: &'a SymbolicExpression<F>,
142    expr_to_idx: &mut FxHashMap<&'a SymbolicExpression<F>, usize>,
143    nodes: &mut Vec<SymbolicExpressionNode<F>>,
144) -> usize {
145    if let Some(&idx) = expr_to_idx.get(expr) {
146        return idx;
147    }
148    let node = match expr {
149        SymbolicExpression::Variable(var) => SymbolicExpressionNode::Variable(*var),
150        SymbolicExpression::IsFirstRow => SymbolicExpressionNode::IsFirstRow,
151        SymbolicExpression::IsLastRow => SymbolicExpressionNode::IsLastRow,
152        SymbolicExpression::IsTransition => SymbolicExpressionNode::IsTransition,
153        SymbolicExpression::Constant(cons) => SymbolicExpressionNode::Constant(*cons),
154        SymbolicExpression::Add {
155            x,
156            y,
157            degree_multiple,
158        } => {
159            let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
160            let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes);
161            SymbolicExpressionNode::Add {
162                left_idx,
163                right_idx,
164                degree_multiple: *degree_multiple,
165            }
166        }
167        SymbolicExpression::Sub {
168            x,
169            y,
170            degree_multiple,
171        } => {
172            let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
173            let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes);
174            SymbolicExpressionNode::Sub {
175                left_idx,
176                right_idx,
177                degree_multiple: *degree_multiple,
178            }
179        }
180        SymbolicExpression::Neg { x, degree_multiple } => {
181            let idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
182            SymbolicExpressionNode::Neg {
183                idx,
184                degree_multiple: *degree_multiple,
185            }
186        }
187        SymbolicExpression::Mul {
188            x,
189            y,
190            degree_multiple,
191        } => {
192            // An important case to remember: square will have Arc::as_ptr(&x) == Arc::as_ptr(&y)
193            // The `expr_to_id` will ensure only one topological sort is done to prevent exponential
194            // behavior.
195            let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
196            let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes);
197            SymbolicExpressionNode::Mul {
198                left_idx,
199                right_idx,
200                degree_multiple: *degree_multiple,
201            }
202        }
203    };
204
205    let idx = nodes.len();
206    nodes.push(node);
207    expr_to_idx.insert(expr, idx);
208    idx
209}
210
211impl<F: Field> SymbolicExpressionDag<F> {
212    /// Convert each node to a [`SymbolicExpression<F>`] reference and return
213    /// the full list.
214    fn to_symbolic_expressions(&self) -> Vec<Arc<SymbolicExpression<F>>> {
215        let mut exprs: Vec<Arc<SymbolicExpression<_>>> = Vec::with_capacity(self.nodes.len());
216        for node in &self.nodes {
217            let expr = match *node {
218                SymbolicExpressionNode::Variable(var) => SymbolicExpression::Variable(var),
219                SymbolicExpressionNode::IsFirstRow => SymbolicExpression::IsFirstRow,
220                SymbolicExpressionNode::IsLastRow => SymbolicExpression::IsLastRow,
221                SymbolicExpressionNode::IsTransition => SymbolicExpression::IsTransition,
222                SymbolicExpressionNode::Constant(f) => SymbolicExpression::Constant(f),
223                SymbolicExpressionNode::Add {
224                    left_idx,
225                    right_idx,
226                    degree_multiple,
227                } => SymbolicExpression::Add {
228                    x: exprs[left_idx].clone(),
229                    y: exprs[right_idx].clone(),
230                    degree_multiple,
231                },
232                SymbolicExpressionNode::Sub {
233                    left_idx,
234                    right_idx,
235                    degree_multiple,
236                } => SymbolicExpression::Sub {
237                    x: exprs[left_idx].clone(),
238                    y: exprs[right_idx].clone(),
239                    degree_multiple,
240                },
241                SymbolicExpressionNode::Neg {
242                    idx,
243                    degree_multiple,
244                } => SymbolicExpression::Neg {
245                    x: exprs[idx].clone(),
246                    degree_multiple,
247                },
248                SymbolicExpressionNode::Mul {
249                    left_idx,
250                    right_idx,
251                    degree_multiple,
252                } => SymbolicExpression::Mul {
253                    x: exprs[left_idx].clone(),
254                    y: exprs[right_idx].clone(),
255                    degree_multiple,
256                },
257            };
258            exprs.push(Arc::new(expr));
259        }
260        exprs
261    }
262}
263
264// TEMPORARY conversions until we switch main interfaces to use SymbolicConstraintsDag
265impl<'a, F: Field> From<&'a SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
266    fn from(dag: &'a SymbolicConstraintsDag<F>) -> Self {
267        let exprs = dag.constraints.to_symbolic_expressions();
268        let constraints = dag
269            .constraints
270            .constraint_idx
271            .iter()
272            .map(|&idx| exprs[idx].as_ref().clone())
273            .collect::<Vec<_>>();
274        let interactions = dag
275            .interactions
276            .iter()
277            .map(|interaction| {
278                let fields = interaction
279                    .message
280                    .iter()
281                    .map(|&idx| exprs[idx].as_ref().clone())
282                    .collect();
283                let count = exprs[interaction.count].as_ref().clone();
284                Interaction {
285                    message: fields,
286                    count,
287                    bus_index: interaction.bus_index,
288                    count_weight: interaction.count_weight,
289                }
290            })
291            .collect::<Vec<_>>();
292        SymbolicConstraints {
293            constraints,
294            interactions,
295        }
296    }
297}
298
299impl<F: Field> From<SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
300    fn from(dag: SymbolicConstraintsDag<F>) -> Self {
301        (&dag).into()
302    }
303}
304
305impl<F: Field> From<SymbolicConstraints<F>> for SymbolicConstraintsDag<F> {
306    fn from(sc: SymbolicConstraints<F>) -> Self {
307        build_symbolic_constraints_dag(&sc.constraints, &sc.interactions)
308    }
309}
310
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;

    use crate::{
        air_builders::symbolic::{
            dag::{build_symbolic_constraints_dag, SymbolicExpressionDag, SymbolicExpressionNode},
            symbolic_expression::SymbolicExpression,
            symbolic_variable::{Entry, SymbolicVariable},
        },
        interaction::Interaction,
    };

    type F = BabyBear;

    /// Pins the exact node list, constraint indices, and interaction indices produced by
    /// `build_symbolic_constraints_dag` for a small hand-built input. This includes which
    /// repeated subexpressions are merged into one node and which are not.
    #[test]
    fn test_symbolic_constraints_dag() {
        // `expr` = Constant(1) * <variable>. Its clones share the same `Arc` children, so
        // every occurrence of `expr` below resolves to the same node (node 8).
        let expr = SymbolicExpression::Constant(F::ONE)
            * SymbolicVariable::new(
                Entry::Main {
                    part_index: 1,
                    offset: 2,
                },
                3,
            );
        let constraints = vec![
            SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow
                + SymbolicExpression::Constant(F::ONE)
                + SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow
                + expr.clone(),
            expr.clone() * expr.clone(),
        ];
        let interactions = vec![Interaction {
            bus_index: 0,
            message: vec![expr.clone(), SymbolicExpression::Constant(F::TWO)],
            count: SymbolicExpression::Constant(F::ONE),
            count_weight: 1,
        }];
        let dag = build_symbolic_constraints_dag(&constraints, &interactions);
        assert_eq!(
            dag.constraints,
            SymbolicExpressionDag::<F> {
                nodes: vec![
                    // 0, 1: leaves; reused wherever they appear.
                    SymbolicExpressionNode::IsFirstRow,
                    SymbolicExpressionNode::IsLastRow,
                    // 2: the first `IsFirstRow * IsLastRow`.
                    SymbolicExpressionNode::Mul {
                        left_idx: 0,
                        right_idx: 1,
                        degree_multiple: 2
                    },
                    // 3: `Constant(F::ONE)`; also reused inside `expr` and as the interaction count.
                    SymbolicExpressionNode::Constant(F::ONE),
                    // 4: node 2 + node 3.
                    SymbolicExpressionNode::Add {
                        left_idx: 2,
                        right_idx: 3,
                        degree_multiple: 2
                    },
                    // 5: Currently topological sort does not detect all subgraph isomorphisms.
                    // The second `IsFirstRow * IsLastRow` wraps freshly allocated `Arc`s, so its
                    // pointer-based hash differs and it becomes a new node. This happens even
                    // though it is structurally identical to node 2 and its leaves still
                    // resolve to nodes 0 and 1.
                    SymbolicExpressionNode::Mul {
                        left_idx: 0,
                        right_idx: 1,
                        degree_multiple: 2
                    },
                    // 6: node 4 + node 5.
                    SymbolicExpressionNode::Add {
                        left_idx: 4,
                        right_idx: 5,
                        degree_multiple: 2
                    },
                    // 7: the variable inside `expr`.
                    SymbolicExpressionNode::Variable(SymbolicVariable::new(
                        Entry::Main {
                            part_index: 1,
                            offset: 2
                        },
                        3
                    )),
                    // 8: `expr` itself.
                    SymbolicExpressionNode::Mul {
                        left_idx: 3,
                        right_idx: 7,
                        degree_multiple: 1
                    },
                    // 9: root of the first constraint.
                    SymbolicExpressionNode::Add {
                        left_idx: 6,
                        right_idx: 8,
                        degree_multiple: 2
                    },
                    // 10: root of the second constraint, `expr * expr` (both operands are node 8).
                    SymbolicExpressionNode::Mul {
                        left_idx: 8,
                        right_idx: 8,
                        degree_multiple: 2
                    },
                    // 11: added while processing the interaction message, after the constraints.
                    SymbolicExpressionNode::Constant(F::TWO),
                ],
                constraint_idx: vec![9, 10],
            }
        );
        // Interaction expressions are referenced by index into the same node list.
        assert_eq!(
            dag.interactions,
            vec![Interaction {
                bus_index: 0,
                message: vec![8, 11],
                count: 3,
                count_weight: 1,
            }]
        );
    }
}