1use std::sync::Arc;
2
3use p3_field::Field;
4use rustc_hash::FxHashMap;
5use serde::{Deserialize, Serialize};
6
7use super::SymbolicConstraints;
8use crate::{
9 air_builders::symbolic::{
10 symbolic_expression::SymbolicExpression, symbolic_variable::SymbolicVariable,
11 },
12 interaction::{Interaction, SymbolicInteraction},
13};
14
/// One node of a [`SymbolicExpressionDag`].
///
/// Leaf variants carry their payload directly. Operator variants name their operands by
/// index into the owning DAG's `nodes` vector instead of by pointer, which keeps the whole
/// DAG a flat, serializable `Vec`.
// NOTE(review): `repr(C)` fixes the in-memory layout (including variant discriminants),
// presumably so nodes can be shared with non-Rust consumers — confirm before reordering
// variants or changing field types.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
#[repr(C)]
pub enum SymbolicExpressionNode<F> {
    /// A symbolic variable, copied from `SymbolicExpression::Variable`.
    Variable(SymbolicVariable<F>),
    /// Leaf mirroring `SymbolicExpression::IsFirstRow`.
    IsFirstRow,
    /// Leaf mirroring `SymbolicExpression::IsLastRow`.
    IsLastRow,
    /// Leaf mirroring `SymbolicExpression::IsTransition`.
    IsTransition,
    /// A constant value, copied from `SymbolicExpression::Constant`.
    Constant(F),
    /// `nodes[left_idx] + nodes[right_idx]`.
    Add {
        left_idx: usize,
        right_idx: usize,
        /// Copied unchanged from the source expression's `degree_multiple`.
        degree_multiple: usize,
    },
    /// `nodes[left_idx] - nodes[right_idx]`.
    Sub {
        left_idx: usize,
        right_idx: usize,
        /// Copied unchanged from the source expression's `degree_multiple`.
        degree_multiple: usize,
    },
    /// `-nodes[idx]`.
    Neg {
        idx: usize,
        /// Copied unchanged from the source expression's `degree_multiple`.
        degree_multiple: usize,
    },
    /// `nodes[left_idx] * nodes[right_idx]`.
    Mul {
        left_idx: usize,
        right_idx: usize,
        /// Copied unchanged from the source expression's `degree_multiple`.
        degree_multiple: usize,
    },
}
47
/// A collection of symbolic expressions stored as a DAG in one flat vector of nodes.
///
/// Operands are referenced by index, so a subexpression reached from several roots can be
/// represented by a single shared node.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
#[repr(C)]
pub struct SymbolicExpressionDag<F> {
    /// Nodes in topological order: every operand index stored in a node is expected to be
    /// smaller than that node's own index. `to_symbolic_expressions` relies on this and
    /// panics on a forward or out-of-range reference.
    pub nodes: Vec<SymbolicExpressionNode<F>>,
    /// Indices into `nodes` of the root node of each constraint expression. The builder in
    /// this module stores them sorted ascending (not in input order).
    // NOTE(review): presumably each root is a constraint asserted to equal zero — confirm
    // against the evaluating callers.
    pub constraint_idx: Vec<usize>,
}
57
58impl<F> SymbolicExpressionDag<F> {
59 pub fn max_rotation(&self) -> usize {
60 let mut rotation = 0;
61 for node in &self.nodes {
62 if let SymbolicExpressionNode::Variable(var) = node {
63 rotation = rotation.max(var.entry.offset().unwrap_or(0));
64 }
65 }
66 rotation
67 }
68
69 pub fn num_constraints(&self) -> usize {
70 self.constraint_idx.len()
71 }
72}
73
74#[derive(Clone, Debug, Serialize, Deserialize)]
75#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
76#[repr(C)]
77pub struct SymbolicConstraintsDag<F> {
78 pub constraints: SymbolicExpressionDag<F>,
82 pub interactions: Vec<Interaction<usize>>,
92}
93
94pub(crate) fn build_symbolic_constraints_dag<F: Field>(
95 constraints: &[SymbolicExpression<F>],
96 interactions: &[SymbolicInteraction<F>],
97) -> SymbolicConstraintsDag<F> {
98 let mut expr_to_idx = FxHashMap::default();
99 let mut nodes = Vec::new();
100 let mut constraint_idx: Vec<usize> = constraints
101 .iter()
102 .map(|expr| topological_sort_symbolic_expr(expr, &mut expr_to_idx, &mut nodes))
103 .collect();
104 constraint_idx.sort();
105 let interactions: Vec<Interaction<usize>> = interactions
106 .iter()
107 .map(|interaction| {
108 let fields: Vec<usize> = interaction
109 .message
110 .iter()
111 .map(|field_expr| {
112 topological_sort_symbolic_expr(field_expr, &mut expr_to_idx, &mut nodes)
113 })
114 .collect();
115 let count =
116 topological_sort_symbolic_expr(&interaction.count, &mut expr_to_idx, &mut nodes);
117 Interaction {
118 message: fields,
119 count,
120 bus_index: interaction.bus_index,
121 count_weight: interaction.count_weight,
122 }
123 })
124 .collect();
125 let constraints = SymbolicExpressionDag {
129 nodes,
130 constraint_idx,
131 };
132 SymbolicConstraintsDag {
133 constraints,
134 interactions,
135 }
136}
137
138fn topological_sort_symbolic_expr<'a, F: Field>(
141 expr: &'a SymbolicExpression<F>,
142 expr_to_idx: &mut FxHashMap<&'a SymbolicExpression<F>, usize>,
143 nodes: &mut Vec<SymbolicExpressionNode<F>>,
144) -> usize {
145 if let Some(&idx) = expr_to_idx.get(expr) {
146 return idx;
147 }
148 let node = match expr {
149 SymbolicExpression::Variable(var) => SymbolicExpressionNode::Variable(*var),
150 SymbolicExpression::IsFirstRow => SymbolicExpressionNode::IsFirstRow,
151 SymbolicExpression::IsLastRow => SymbolicExpressionNode::IsLastRow,
152 SymbolicExpression::IsTransition => SymbolicExpressionNode::IsTransition,
153 SymbolicExpression::Constant(cons) => SymbolicExpressionNode::Constant(*cons),
154 SymbolicExpression::Add {
155 x,
156 y,
157 degree_multiple,
158 } => {
159 let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
160 let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes);
161 SymbolicExpressionNode::Add {
162 left_idx,
163 right_idx,
164 degree_multiple: *degree_multiple,
165 }
166 }
167 SymbolicExpression::Sub {
168 x,
169 y,
170 degree_multiple,
171 } => {
172 let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
173 let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes);
174 SymbolicExpressionNode::Sub {
175 left_idx,
176 right_idx,
177 degree_multiple: *degree_multiple,
178 }
179 }
180 SymbolicExpression::Neg { x, degree_multiple } => {
181 let idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
182 SymbolicExpressionNode::Neg {
183 idx,
184 degree_multiple: *degree_multiple,
185 }
186 }
187 SymbolicExpression::Mul {
188 x,
189 y,
190 degree_multiple,
191 } => {
192 let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes);
196 let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes);
197 SymbolicExpressionNode::Mul {
198 left_idx,
199 right_idx,
200 degree_multiple: *degree_multiple,
201 }
202 }
203 };
204
205 let idx = nodes.len();
206 nodes.push(node);
207 expr_to_idx.insert(expr, idx);
208 idx
209}
210
211impl<F: Field> SymbolicExpressionDag<F> {
212 fn to_symbolic_expressions(&self) -> Vec<Arc<SymbolicExpression<F>>> {
215 let mut exprs: Vec<Arc<SymbolicExpression<_>>> = Vec::with_capacity(self.nodes.len());
216 for node in &self.nodes {
217 let expr = match *node {
218 SymbolicExpressionNode::Variable(var) => SymbolicExpression::Variable(var),
219 SymbolicExpressionNode::IsFirstRow => SymbolicExpression::IsFirstRow,
220 SymbolicExpressionNode::IsLastRow => SymbolicExpression::IsLastRow,
221 SymbolicExpressionNode::IsTransition => SymbolicExpression::IsTransition,
222 SymbolicExpressionNode::Constant(f) => SymbolicExpression::Constant(f),
223 SymbolicExpressionNode::Add {
224 left_idx,
225 right_idx,
226 degree_multiple,
227 } => SymbolicExpression::Add {
228 x: exprs[left_idx].clone(),
229 y: exprs[right_idx].clone(),
230 degree_multiple,
231 },
232 SymbolicExpressionNode::Sub {
233 left_idx,
234 right_idx,
235 degree_multiple,
236 } => SymbolicExpression::Sub {
237 x: exprs[left_idx].clone(),
238 y: exprs[right_idx].clone(),
239 degree_multiple,
240 },
241 SymbolicExpressionNode::Neg {
242 idx,
243 degree_multiple,
244 } => SymbolicExpression::Neg {
245 x: exprs[idx].clone(),
246 degree_multiple,
247 },
248 SymbolicExpressionNode::Mul {
249 left_idx,
250 right_idx,
251 degree_multiple,
252 } => SymbolicExpression::Mul {
253 x: exprs[left_idx].clone(),
254 y: exprs[right_idx].clone(),
255 degree_multiple,
256 },
257 };
258 exprs.push(Arc::new(expr));
259 }
260 exprs
261 }
262}
263
264impl<'a, F: Field> From<&'a SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
266 fn from(dag: &'a SymbolicConstraintsDag<F>) -> Self {
267 let exprs = dag.constraints.to_symbolic_expressions();
268 let constraints = dag
269 .constraints
270 .constraint_idx
271 .iter()
272 .map(|&idx| exprs[idx].as_ref().clone())
273 .collect::<Vec<_>>();
274 let interactions = dag
275 .interactions
276 .iter()
277 .map(|interaction| {
278 let fields = interaction
279 .message
280 .iter()
281 .map(|&idx| exprs[idx].as_ref().clone())
282 .collect();
283 let count = exprs[interaction.count].as_ref().clone();
284 Interaction {
285 message: fields,
286 count,
287 bus_index: interaction.bus_index,
288 count_weight: interaction.count_weight,
289 }
290 })
291 .collect::<Vec<_>>();
292 SymbolicConstraints {
293 constraints,
294 interactions,
295 }
296 }
297}
298
299impl<F: Field> From<SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
300 fn from(dag: SymbolicConstraintsDag<F>) -> Self {
301 (&dag).into()
302 }
303}
304
305impl<F: Field> From<SymbolicConstraints<F>> for SymbolicConstraintsDag<F> {
306 fn from(sc: SymbolicConstraints<F>) -> Self {
307 build_symbolic_constraints_dag(&sc.constraints, &sc.interactions)
308 }
309}
310
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;

    use crate::{
        air_builders::symbolic::{
            dag::{build_symbolic_constraints_dag, SymbolicExpressionDag, SymbolicExpressionNode},
            symbolic_expression::SymbolicExpression,
            symbolic_variable::{Entry, SymbolicVariable},
        },
        interaction::Interaction,
    };

    type F = BabyBear;

    /// Pins the exact node layout produced by `build_symbolic_constraints_dag`. It checks:
    /// - post-order numbering (operands before parents);
    /// - reuse of cached nodes;
    /// - interaction fields pointing into the same node vector as the constraints.
    ///
    /// The expected indices depend on the exact construction order of the expressions
    /// below, so reordering or sharing them changes the expected DAG.
    #[test]
    fn test_symbolic_constraints_dag() {
        // Constant one times a main-trace variable (part 1, offset 2, index 3). Clones of
        // `expr` are used below and all resolve to a single node (index 8).
        let expr = SymbolicExpression::Constant(F::ONE)
            * SymbolicVariable::new(
                Entry::Main {
                    part_index: 1,
                    offset: 2,
                },
                3,
            );
        // The first constraint builds `IsFirstRow * IsLastRow` twice, independently.
        let constraints = vec![
            SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow
                + SymbolicExpression::Constant(F::ONE)
                + SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow
                + expr.clone(),
            expr.clone() * expr.clone(),
        ];
        let interactions = vec![Interaction {
            bus_index: 0,
            message: vec![expr.clone(), SymbolicExpression::Constant(F::TWO)],
            count: SymbolicExpression::Constant(F::ONE),
            count_weight: 1,
        }];
        let dag = build_symbolic_constraints_dag(&constraints, &interactions);
        assert_eq!(
            dag.constraints,
            SymbolicExpressionDag::<F> {
                nodes: vec![
                    // 0
                    SymbolicExpressionNode::IsFirstRow,
                    // 1
                    SymbolicExpressionNode::IsLastRow,
                    // 2: first `IsFirstRow * IsLastRow`
                    SymbolicExpressionNode::Mul {
                        left_idx: 0,
                        right_idx: 1,
                        degree_multiple: 2
                    },
                    // 3: `Constant(ONE)`, later reused by `expr` and by the interaction count
                    SymbolicExpressionNode::Constant(F::ONE),
                    // 4: node 2 + node 3
                    SymbolicExpressionNode::Add {
                        left_idx: 2,
                        right_idx: 3,
                        degree_multiple: 2
                    },
                    // 5: the second, independently built `IsFirstRow * IsLastRow`. Its
                    // leaves are shared (0, 1), but the product itself gets a new node even
                    // though it is structurally equal to node 2: the cache does not
                    // unify separately constructed subtrees.
                    SymbolicExpressionNode::Mul {
                        left_idx: 0,
                        right_idx: 1,
                        degree_multiple: 2
                    },
                    // 6: node 4 + node 5
                    SymbolicExpressionNode::Add {
                        left_idx: 4,
                        right_idx: 5,
                        degree_multiple: 2
                    },
                    // 7: the main-trace variable
                    SymbolicExpressionNode::Variable(SymbolicVariable::new(
                        Entry::Main {
                            part_index: 1,
                            offset: 2
                        },
                        3
                    )),
                    // 8: `expr`, shared by every clone
                    SymbolicExpressionNode::Mul {
                        left_idx: 3,
                        right_idx: 7,
                        degree_multiple: 1
                    },
                    // 9: root of the first constraint
                    SymbolicExpressionNode::Add {
                        left_idx: 6,
                        right_idx: 8,
                        degree_multiple: 2
                    },
                    // 10: root of the second constraint, `expr * expr`
                    SymbolicExpressionNode::Mul {
                        left_idx: 8,
                        right_idx: 8,
                        degree_multiple: 2
                    },
                    // 11: `Constant(TWO)` from the interaction message
                    SymbolicExpressionNode::Constant(F::TWO),
                ],
                constraint_idx: vec![9, 10],
            }
        );
        // Interaction fields are node indices into the same `nodes` vector.
        assert_eq!(
            dag.interactions,
            vec![Interaction {
                bus_index: 0,
                message: vec![8, 11],
                count: 3,
                count_weight: 1,
            }]
        );
    }
}