use std::sync::Arc;
use p3_field::Field;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::SymbolicConstraints;
use crate::{
air_builders::symbolic::{
symbolic_expression::SymbolicExpression, symbolic_variable::SymbolicVariable,
},
interaction::{Interaction, SymbolicInteraction},
};
/// A single node of a [`SymbolicExpressionDag`].
///
/// Each variant mirrors the same-named variant of [`SymbolicExpression`]. Operands are
/// not stored as `Arc` pointers; they are stored as indices into the owning DAG's
/// `nodes` vector.
// NOTE(review): `#[repr(C)]` presumably pins the in-memory layout for a consumer outside
// this module (FFI / device code?) — confirm before reordering variants or fields, which
// would also change the serde representation.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound = "F: Field")]
#[repr(C)]
pub enum SymbolicExpressionNode<F> {
    /// Leaf: a symbolic variable, copied from `SymbolicExpression::Variable`.
    Variable(SymbolicVariable<F>),
    /// Leaf mirroring `SymbolicExpression::IsFirstRow`.
    IsFirstRow,
    /// Leaf mirroring `SymbolicExpression::IsLastRow`.
    IsLastRow,
    /// Leaf mirroring `SymbolicExpression::IsTransition`.
    IsTransition,
    /// Leaf: a field constant.
    Constant(F),
    /// Mirrors `SymbolicExpression::Add`. `left_idx` and `right_idx` index the owning
    /// DAG's `nodes`; `degree_multiple` is copied from the source expression.
    Add {
        left_idx: usize,
        right_idx: usize,
        degree_multiple: usize,
    },
    /// Mirrors `SymbolicExpression::Sub`; fields as for `Add`.
    Sub {
        left_idx: usize,
        right_idx: usize,
        degree_multiple: usize,
    },
    /// Mirrors `SymbolicExpression::Neg`. `idx` is the operand's index in `nodes`.
    Neg {
        idx: usize,
        degree_multiple: usize,
    },
    /// Mirrors `SymbolicExpression::Mul`; fields as for `Add`.
    Mul {
        left_idx: usize,
        right_idx: usize,
        degree_multiple: usize,
    },
}
/// A collection of symbolic expressions flattened into one DAG, so that sub-expressions
/// recognized as shared are emitted only once.
///
/// When built by `build_symbolic_constraints_dag`, `nodes` is in topological order: every
/// operand index of a node is strictly smaller than the node's own index.
/// `to_symbolic_expressions` relies on that ordering and indexes operands directly. It
/// therefore panics on a DAG that violates it, e.g. a hand-built or deserialized one.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound = "F: Field")]
pub struct SymbolicExpressionDag<F> {
    /// All nodes, children before parents.
    pub(crate) nodes: Vec<SymbolicExpressionNode<F>>,
    /// For each constraint, in input order, the index in `nodes` of its root.
    pub(crate) constraint_idx: Vec<usize>,
}
/// DAG form of `SymbolicConstraints`: constraints and interactions share a single node
/// list, stored in `constraints.nodes`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct SymbolicConstraintsDag<F> {
    /// Shared node storage plus the root index of each constraint.
    pub constraints: SymbolicExpressionDag<F>,
    /// Interactions whose `fields` and `count` are indices into `constraints.nodes`
    /// rather than expression trees.
    pub interactions: Vec<Interaction<usize>>,
}
/// Flattens constraint and interaction expression trees into a single
/// [`SymbolicConstraintsDag`] whose node list is shared by all of them.
///
/// Expressions are visited in this order:
/// 1. every constraint, in order;
/// 2. for each interaction, its fields in order, then its count.
///
/// New node indices are therefore assigned in that order. A sub-expression that
/// `topological_sort_symbolic_expr`'s lookup cache has already seen reuses its node.
pub(crate) fn build_symbolic_constraints_dag<F: Field>(
    constraints: &[SymbolicExpression<F>],
    interactions: &[SymbolicInteraction<F>],
) -> SymbolicConstraintsDag<F> {
    // Maps every already-emitted expression to its node index.
    let mut seen = FxHashMap::default();
    let mut dag_nodes = Vec::new();

    let mut constraint_idx = Vec::with_capacity(constraints.len());
    for constraint in constraints {
        let root = topological_sort_symbolic_expr(constraint, &mut seen, &mut dag_nodes);
        constraint_idx.push(root);
    }

    let mut dag_interactions = Vec::with_capacity(interactions.len());
    for interaction in interactions {
        let mut fields = Vec::with_capacity(interaction.fields.len());
        for field_expr in &interaction.fields {
            fields.push(topological_sort_symbolic_expr(field_expr, &mut seen, &mut dag_nodes));
        }
        let count = topological_sort_symbolic_expr(&interaction.count, &mut seen, &mut dag_nodes);
        dag_interactions.push(Interaction {
            fields,
            count,
            bus_index: interaction.bus_index,
            interaction_type: interaction.interaction_type,
        });
    }

    SymbolicConstraintsDag {
        constraints: SymbolicExpressionDag {
            nodes: dag_nodes,
            constraint_idx,
        },
        interactions: dag_interactions,
    }
}
/// Appends `expr`, plus any of its sub-expressions not already emitted, to `nodes` in
/// topological order. Returns the index of the node that represents `expr`.
///
/// Children are emitted before their parent, and the left operand's whole subtree (`x`)
/// before the right operand's (`y`). Every operand index is therefore strictly smaller
/// than the index of the node that uses it.
///
/// `expr_to_idx` caches the node index of every emitted expression. An expression that
/// compares equal to a cached key reuses that index instead of creating a new node.
/// NOTE(review): how much sharing this finds depends on `SymbolicExpression`'s
/// `Hash`/`Eq` impls (defined elsewhere). The unit test in this file shows two separately
/// built `IsFirstRow * IsLastRow` products receiving distinct nodes.
///
/// The traversal uses an explicit work stack rather than recursion, so very deep
/// expressions (e.g. long left-folded sums) cannot overflow the call stack.
fn topological_sort_symbolic_expr<'a, F: Field>(
    expr: &'a SymbolicExpression<F>,
    expr_to_idx: &mut FxHashMap<&'a SymbolicExpression<F>, usize>,
    nodes: &mut Vec<SymbolicExpressionNode<F>>,
) -> usize {
    if let Some(&idx) = expr_to_idx.get(expr) {
        return idx;
    }

    // Work items are `(expression, children_ready)`. An interior expression is first
    // popped with `children_ready == false`, which schedules its operands. It is popped
    // again with `children_ready == true` once all operands have been assigned indices.
    let mut stack: Vec<(&'a SymbolicExpression<F>, bool)> = vec![(expr, false)];
    while let Some((current, children_ready)) = stack.pop() {
        if expr_to_idx.contains_key(current) {
            // Already emitted (shared sub-expression): reuse the existing node.
            continue;
        }
        if !children_ready {
            match current {
                SymbolicExpression::Add { x, y, .. }
                | SymbolicExpression::Sub { x, y, .. }
                | SymbolicExpression::Mul { x, y, .. } => {
                    stack.push((current, true));
                    // `y` is pushed before `x` so that `x` is popped, and its whole
                    // subtree emitted, first — the same order as a recursive
                    // left-to-right depth-first walk.
                    stack.push((&**y, false));
                    stack.push((&**x, false));
                    continue;
                }
                SymbolicExpression::Neg { x, .. } => {
                    stack.push((current, true));
                    stack.push((&**x, false));
                    continue;
                }
                // Leaves have no operands and are emitted immediately below.
                SymbolicExpression::Variable(_)
                | SymbolicExpression::IsFirstRow
                | SymbolicExpression::IsLastRow
                | SymbolicExpression::IsTransition
                | SymbolicExpression::Constant(_) => {}
            }
        }

        // Every operand of `current` has been emitted, so its index is in the cache.
        let node = match current {
            SymbolicExpression::Variable(var) => SymbolicExpressionNode::Variable(*var),
            SymbolicExpression::IsFirstRow => SymbolicExpressionNode::IsFirstRow,
            SymbolicExpression::IsLastRow => SymbolicExpressionNode::IsLastRow,
            SymbolicExpression::IsTransition => SymbolicExpressionNode::IsTransition,
            SymbolicExpression::Constant(cons) => SymbolicExpressionNode::Constant(*cons),
            SymbolicExpression::Add {
                x,
                y,
                degree_multiple,
            } => SymbolicExpressionNode::Add {
                left_idx: expr_to_idx[&**x],
                right_idx: expr_to_idx[&**y],
                degree_multiple: *degree_multiple,
            },
            SymbolicExpression::Sub {
                x,
                y,
                degree_multiple,
            } => SymbolicExpressionNode::Sub {
                left_idx: expr_to_idx[&**x],
                right_idx: expr_to_idx[&**y],
                degree_multiple: *degree_multiple,
            },
            SymbolicExpression::Neg { x, degree_multiple } => SymbolicExpressionNode::Neg {
                idx: expr_to_idx[&**x],
                degree_multiple: *degree_multiple,
            },
            SymbolicExpression::Mul {
                x,
                y,
                degree_multiple,
            } => SymbolicExpressionNode::Mul {
                left_idx: expr_to_idx[&**x],
                right_idx: expr_to_idx[&**y],
                degree_multiple: *degree_multiple,
            },
        };
        let idx = nodes.len();
        nodes.push(node);
        expr_to_idx.insert(current, idx);
    }

    // `expr` was the bottom-most work item, so it is in the cache by now.
    expr_to_idx[expr]
}
impl<F: Field> SymbolicExpressionDag<F> {
    /// Materializes every DAG node as a shared [`SymbolicExpression`].
    ///
    /// The result is index-aligned with `self.nodes`: element `i` is the expression for
    /// node `i`. Operand indices are resolved against elements already produced, so each
    /// operand index must be smaller than the index of the node that references it. This
    /// holds for DAGs from `build_symbolic_constraints_dag`; otherwise the out-of-bounds
    /// lookup panics.
    fn to_symbolic_expressions(&self) -> Vec<Arc<SymbolicExpression<F>>> {
        let mut built: Vec<Arc<SymbolicExpression<F>>> = Vec::with_capacity(self.nodes.len());
        for node in self.nodes.iter() {
            // Shares (rather than deep-copies) an operand that was already built.
            let operand = |i: usize| -> Arc<SymbolicExpression<F>> { Arc::clone(&built[i]) };
            let expr = match node {
                SymbolicExpressionNode::Variable(var) => SymbolicExpression::Variable(*var),
                SymbolicExpressionNode::IsFirstRow => SymbolicExpression::IsFirstRow,
                SymbolicExpressionNode::IsLastRow => SymbolicExpression::IsLastRow,
                SymbolicExpressionNode::IsTransition => SymbolicExpression::IsTransition,
                SymbolicExpressionNode::Constant(value) => SymbolicExpression::Constant(*value),
                SymbolicExpressionNode::Add {
                    left_idx,
                    right_idx,
                    degree_multiple,
                } => SymbolicExpression::Add {
                    x: operand(*left_idx),
                    y: operand(*right_idx),
                    degree_multiple: *degree_multiple,
                },
                SymbolicExpressionNode::Sub {
                    left_idx,
                    right_idx,
                    degree_multiple,
                } => SymbolicExpression::Sub {
                    x: operand(*left_idx),
                    y: operand(*right_idx),
                    degree_multiple: *degree_multiple,
                },
                SymbolicExpressionNode::Neg {
                    idx,
                    degree_multiple,
                } => SymbolicExpression::Neg {
                    x: operand(*idx),
                    degree_multiple: *degree_multiple,
                },
                SymbolicExpressionNode::Mul {
                    left_idx,
                    right_idx,
                    degree_multiple,
                } => SymbolicExpression::Mul {
                    x: operand(*left_idx),
                    y: operand(*right_idx),
                    degree_multiple: *degree_multiple,
                },
            };
            built.push(Arc::new(expr));
        }
        built
    }
}
impl<F: Field> From<SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
fn from(dag: SymbolicConstraintsDag<F>) -> Self {
let exprs = dag.constraints.to_symbolic_expressions();
let constraints = dag
.constraints
.constraint_idx
.into_iter()
.map(|idx| exprs[idx].as_ref().clone())
.collect::<Vec<_>>();
let interactions = dag
.interactions
.into_iter()
.map(|interaction| {
let fields = interaction
.fields
.into_iter()
.map(|idx| exprs[idx].as_ref().clone())
.collect();
let count = exprs[interaction.count].as_ref().clone();
Interaction {
fields,
count,
bus_index: interaction.bus_index,
interaction_type: interaction.interaction_type,
}
})
.collect::<Vec<_>>();
SymbolicConstraints {
constraints,
interactions,
}
}
}
impl<F: Field> From<SymbolicConstraints<F>> for SymbolicConstraintsDag<F> {
    /// Flattens owned expression trees into a shared-node DAG via
    /// `build_symbolic_constraints_dag`.
    fn from(sc: SymbolicConstraints<F>) -> Self {
        let (constraints, interactions) = (&sc.constraints, &sc.interactions);
        build_symbolic_constraints_dag(constraints, interactions)
    }
}
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;

    use crate::{
        air_builders::symbolic::{
            dag::{build_symbolic_constraints_dag, SymbolicExpressionDag, SymbolicExpressionNode},
            symbolic_expression::SymbolicExpression,
            symbolic_variable::{Entry, SymbolicVariable},
            SymbolicConstraints,
        },
        interaction::{Interaction, InteractionType},
    };

    type F = BabyBear;

    /// Pins the exact node layout that `build_symbolic_constraints_dag` produces for two
    /// constraints and one interaction. Then checks that `SymbolicConstraints` survives a
    /// JSON round trip.
    #[test]
    fn test_symbolic_constraints_dag() {
        // `expr` = 1 * (main variable: part 1, offset 2, index 3). Clones of it appear in
        // both constraints and in the interaction, and all of them map to node 8 below.
        let expr = SymbolicExpression::Constant(F::ONE)
            * SymbolicVariable::new(
                Entry::Main {
                    part_index: 1,
                    offset: 2,
                },
                3,
            );
        let constraints = vec![
            SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow
                + SymbolicExpression::Constant(F::ONE)
                + SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow
                + expr.clone(),
            expr.clone() * expr.clone(),
        ];
        let interactions = vec![Interaction {
            bus_index: 0,
            fields: vec![expr.clone(), SymbolicExpression::Constant(F::TWO)],
            count: SymbolicExpression::Constant(F::ONE),
            interaction_type: InteractionType::Send,
        }];
        let dag = build_symbolic_constraints_dag(&constraints, &interactions);
        assert_eq!(
            dag.constraints,
            SymbolicExpressionDag::<F> {
                nodes: vec![
                    // 0, 1: leaves of the first `IsFirstRow * IsLastRow` product.
                    SymbolicExpressionNode::IsFirstRow,
                    SymbolicExpressionNode::IsLastRow,
                    // 2: node 0 * node 1.
                    SymbolicExpressionNode::Mul {
                        left_idx: 0,
                        right_idx: 1,
                        degree_multiple: 2
                    },
                    // 3: Constant(ONE). Later reused as `expr`'s left operand and as the
                    // interaction count.
                    SymbolicExpressionNode::Constant(F::ONE),
                    // 4: node 2 + node 3.
                    SymbolicExpressionNode::Add {
                        left_idx: 2,
                        right_idx: 3,
                        degree_multiple: 2
                    },
                    // 5: the same product as node 2, yet it gets its own node.
                    // NOTE(review): presumably `SymbolicExpression`'s Hash/Eq compare
                    // child `Arc`s by pointer, and this product was built from separate
                    // leaf values. Confirm in symbolic_expression.rs.
                    SymbolicExpressionNode::Mul {
                        left_idx: 0,
                        right_idx: 1,
                        degree_multiple: 2
                    },
                    // 6: node 4 + node 5.
                    SymbolicExpressionNode::Add {
                        left_idx: 4,
                        right_idx: 5,
                        degree_multiple: 2
                    },
                    // 7: the main variable inside `expr`.
                    SymbolicExpressionNode::Variable(SymbolicVariable::new(
                        Entry::Main {
                            part_index: 1,
                            offset: 2
                        },
                        3
                    )),
                    // 8: `expr` = node 3 * node 7.
                    SymbolicExpressionNode::Mul {
                        left_idx: 3,
                        right_idx: 7,
                        degree_multiple: 1
                    },
                    // 9: root of constraint 0.
                    SymbolicExpressionNode::Add {
                        left_idx: 6,
                        right_idx: 8,
                        degree_multiple: 2
                    },
                    // 10: root of constraint 1 (`expr * expr`).
                    SymbolicExpressionNode::Mul {
                        left_idx: 8,
                        right_idx: 8,
                        degree_multiple: 2
                    },
                    // 11: Constant(TWO), used only by the interaction.
                    SymbolicExpressionNode::Constant(F::TWO),
                ],
                constraint_idx: vec![9, 10],
            }
        );
        // Interaction fields/count become node indices: `expr` -> 8, TWO -> 11, ONE -> 3.
        assert_eq!(
            dag.interactions,
            vec![Interaction {
                bus_index: 0,
                fields: vec![8, 11],
                count: 3,
                interaction_type: InteractionType::Send,
            }]
        );
        // The JSON round trip of the original constraints must compare equal.
        let sc = SymbolicConstraints {
            constraints,
            interactions,
        };
        let ser_str = serde_json::to_string(&sc).unwrap();
        let new_sc: SymbolicConstraints<_> = serde_json::from_str(&ser_str).unwrap();
        assert_eq!(sc, new_sc);
    }
}