openvm_stark_backend/air_builders/symbolic/
symbolic_expression.rs

1// Copied from uni-stark/src/symbolic_expression.rs to use Arc instead of Rc.
2
3use core::{
4    fmt::Debug,
5    iter::{Product, Sum},
6    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
7};
8use std::{
9    hash::{Hash, Hasher},
10    ptr,
11    sync::Arc,
12};
13
14use p3_field::{Algebra, Field, PrimeCharacteristicRing};
15use serde::{Deserialize, Serialize};
16
17use super::{dag::SymbolicExpressionNode, symbolic_variable::SymbolicVariable};
18
/// An expression over `SymbolicVariable`s.
///
/// Leaves are variables, constants and the three row-selector markers; interior nodes are
/// `Add`, `Sub`, `Neg` and `Mul`, whose operands are shared through `Arc`.
///
/// `Hash`, `PartialEq` and `Eq` are implemented by hand for this type: compound nodes are
/// hashed and compared by the *addresses* of their operands, not by their structure, and
/// `degree_multiple` is ignored by both.
// Note: avoid deriving Hash because it will hash the entire sub-tree
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub enum SymbolicExpression<F> {
    /// A leaf wrapping a single symbolic variable.
    Variable(SymbolicVariable<F>),
    /// Marker leaf evaluated through `SymbolicEvaluator::eval_is_first_row`.
    IsFirstRow,
    /// Marker leaf evaluated through `SymbolicEvaluator::eval_is_last_row`.
    IsLastRow,
    /// Marker leaf evaluated through `SymbolicEvaluator::eval_is_transition`.
    IsTransition,
    /// A leaf holding a constant field element.
    Constant(F),
    /// The sum `x + y`.
    Add {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached value returned by `degree_multiple()` for this node.
        degree_multiple: usize,
    },
    /// The difference `x - y`.
    Sub {
        /// Left operand (minuend).
        x: Arc<Self>,
        /// Right operand (subtrahend).
        y: Arc<Self>,
        /// Cached value returned by `degree_multiple()` for this node.
        degree_multiple: usize,
    },
    /// The negation `-x`.
    Neg {
        /// The operand being negated.
        x: Arc<Self>,
        /// Cached value returned by `degree_multiple()` for this node.
        degree_multiple: usize,
    },
    /// The product `x * y`.
    Mul {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached value returned by `degree_multiple()` for this node.
        degree_multiple: usize,
    },
}
49
50impl<F: Field> Hash for SymbolicExpression<F> {
51    fn hash<H: Hasher>(&self, state: &mut H) {
52        // First hash the discriminant of the enum
53        std::mem::discriminant(self).hash(state);
54        // Degree multiple is not necessary
55        match self {
56            Self::Variable(v) => v.hash(state),
57            Self::IsFirstRow => {}   // discriminant is enough
58            Self::IsLastRow => {}    // discriminant is enough
59            Self::IsTransition => {} // discriminant is enough
60            Self::Constant(f) => f.hash(state),
61            Self::Add { x, y, .. } => {
62                ptr::hash(&**x, state);
63                ptr::hash(&**y, state);
64            }
65            Self::Sub { x, y, .. } => {
66                ptr::hash(&**x, state);
67                ptr::hash(&**y, state);
68            }
69            Self::Neg { x, .. } => {
70                ptr::hash(&**x, state);
71            }
72            Self::Mul { x, y, .. } => {
73                ptr::hash(&**x, state);
74                ptr::hash(&**y, state);
75            }
76        }
77    }
78}
79
80// We intentionally do not compare degree_multiple in PartialEq and Eq because degree_multiple is
81// metadata used for optimization/debugging purposes but it does not change the underlying
82// expression.
83impl<F: Field> PartialEq for SymbolicExpression<F> {
84    fn eq(&self, other: &Self) -> bool {
85        // First check if the variants match
86        if std::mem::discriminant(self) != std::mem::discriminant(other) {
87            return false;
88        }
89
90        // Then check equality based on variant-specific data
91        match (self, other) {
92            (Self::Variable(v1), Self::Variable(v2)) => v1 == v2,
93            // IsFirstRow, IsLastRow, and IsTransition are all unit variants,
94            // so if the discriminants match, they're equal
95            (Self::IsFirstRow, Self::IsFirstRow) => true,
96            (Self::IsLastRow, Self::IsLastRow) => true,
97            (Self::IsTransition, Self::IsTransition) => true,
98            (Self::Constant(c1), Self::Constant(c2)) => c1 == c2,
99            // For compound expressions, compare pointers to match how Hash is implemented
100            (Self::Add { x: x1, y: y1, .. }, Self::Add { x: x2, y: y2, .. }) => {
101                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
102            }
103            (Self::Sub { x: x1, y: y1, .. }, Self::Sub { x: x2, y: y2, .. }) => {
104                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
105            }
106            (Self::Neg { x: x1, .. }, Self::Neg { x: x2, .. }) => Arc::ptr_eq(x1, x2),
107            (Self::Mul { x: x1, y: y1, .. }, Self::Mul { x: x2, y: y2, .. }) => {
108                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
109            }
110            // This should never be reached because we've already checked the discriminants
111            _ => false,
112        }
113    }
114}
115
116impl<F: Field> Eq for SymbolicExpression<F> {}
117
118impl<F: Field> SymbolicExpression<F> {
119    /// Returns the multiple of `n` (the trace length) in this expression's degree.
120    pub const fn degree_multiple(&self) -> usize {
121        match self {
122            SymbolicExpression::Variable(v) => v.degree_multiple(),
123            SymbolicExpression::IsFirstRow => 1,
124            SymbolicExpression::IsLastRow => 1,
125            SymbolicExpression::IsTransition => 0,
126            SymbolicExpression::Constant(_) => 0,
127            SymbolicExpression::Add {
128                degree_multiple, ..
129            } => *degree_multiple,
130            SymbolicExpression::Sub {
131                degree_multiple, ..
132            } => *degree_multiple,
133            SymbolicExpression::Neg {
134                degree_multiple, ..
135            } => *degree_multiple,
136            SymbolicExpression::Mul {
137                degree_multiple, ..
138            } => *degree_multiple,
139        }
140    }
141}
142
143impl<F: Field> Default for SymbolicExpression<F> {
144    fn default() -> Self {
145        Self::Constant(F::ZERO)
146    }
147}
148
149impl<F: Field> From<F> for SymbolicExpression<F> {
150    fn from(value: F) -> Self {
151        Self::Constant(value)
152    }
153}
154
155impl<F: Field> PrimeCharacteristicRing for SymbolicExpression<F> {
156    type PrimeSubfield = F::PrimeSubfield;
157
158    const ZERO: Self = Self::Constant(F::ZERO);
159    const ONE: Self = Self::Constant(F::ONE);
160    const TWO: Self = Self::Constant(F::TWO);
161    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
162
163    #[inline]
164    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
165        F::from_prime_subfield(f).into()
166    }
167}
168
169impl<F: Field> Add for SymbolicExpression<F> {
170    type Output = Self;
171
172    fn add(self, rhs: Self) -> Self {
173        let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
174        Self::Add {
175            x: Arc::new(self),
176            y: Arc::new(rhs),
177            degree_multiple,
178        }
179    }
180}
181
182impl<F: Field> Add<F> for SymbolicExpression<F> {
183    type Output = Self;
184
185    fn add(self, rhs: F) -> Self {
186        self + Self::from(rhs)
187    }
188}
189
190impl<F: Field> AddAssign for SymbolicExpression<F> {
191    fn add_assign(&mut self, rhs: Self) {
192        *self = self.clone() + rhs;
193    }
194}
195
196impl<F: Field> AddAssign<SymbolicVariable<F>> for SymbolicExpression<F> {
197    fn add_assign(&mut self, rhs: SymbolicVariable<F>) {
198        *self += SymbolicExpression::from(rhs);
199    }
200}
201
202impl<F: Field> AddAssign<F> for SymbolicExpression<F> {
203    fn add_assign(&mut self, rhs: F) {
204        *self += Self::from(rhs);
205    }
206}
207
208impl<F: Field> Sum for SymbolicExpression<F> {
209    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
210        iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
211    }
212}
213
214impl<F: Field> Sum<F> for SymbolicExpression<F> {
215    fn sum<I: Iterator<Item = F>>(iter: I) -> Self {
216        iter.map(|x| Self::from(x)).sum()
217    }
218}
219
220impl<F: Field> Sub for SymbolicExpression<F> {
221    type Output = Self;
222
223    fn sub(self, rhs: Self) -> Self {
224        let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
225        Self::Sub {
226            x: Arc::new(self),
227            y: Arc::new(rhs),
228            degree_multiple,
229        }
230    }
231}
232
233impl<F: Field> Sub<F> for SymbolicExpression<F> {
234    type Output = Self;
235
236    fn sub(self, rhs: F) -> Self {
237        self - Self::from(rhs)
238    }
239}
240
241impl<F: Field> SubAssign for SymbolicExpression<F> {
242    fn sub_assign(&mut self, rhs: Self) {
243        *self = self.clone() - rhs;
244    }
245}
246
247impl<F: Field> SubAssign<F> for SymbolicExpression<F> {
248    fn sub_assign(&mut self, rhs: F) {
249        *self -= Self::from(rhs);
250    }
251}
252
253impl<F: Field> SubAssign<SymbolicVariable<F>> for SymbolicExpression<F> {
254    fn sub_assign(&mut self, rhs: SymbolicVariable<F>) {
255        *self -= SymbolicExpression::from(rhs);
256    }
257}
258
259impl<F: Field> Neg for SymbolicExpression<F> {
260    type Output = Self;
261
262    fn neg(self) -> Self {
263        let degree_multiple = self.degree_multiple();
264        Self::Neg {
265            x: Arc::new(self),
266            degree_multiple,
267        }
268    }
269}
270
271impl<F: Field> Mul for SymbolicExpression<F> {
272    type Output = Self;
273
274    fn mul(self, rhs: Self) -> Self {
275        #[allow(clippy::suspicious_arithmetic_impl)]
276        let degree_multiple = self.degree_multiple() + rhs.degree_multiple();
277        Self::Mul {
278            x: Arc::new(self),
279            y: Arc::new(rhs),
280            degree_multiple,
281        }
282    }
283}
284
285impl<F: Field> Mul<F> for SymbolicExpression<F> {
286    type Output = Self;
287
288    fn mul(self, rhs: F) -> Self {
289        self * Self::from(rhs)
290    }
291}
292
293impl<F: Field> MulAssign for SymbolicExpression<F> {
294    fn mul_assign(&mut self, rhs: Self) {
295        *self = self.clone() * rhs;
296    }
297}
298
299impl<F: Field> MulAssign<F> for SymbolicExpression<F> {
300    fn mul_assign(&mut self, rhs: F) {
301        *self *= Self::from(rhs);
302    }
303}
304
305impl<F: Field> MulAssign<SymbolicVariable<F>> for SymbolicExpression<F> {
306    fn mul_assign(&mut self, rhs: SymbolicVariable<F>) {
307        *self *= SymbolicExpression::from(rhs);
308    }
309}
310
311impl<F: Field> Product for SymbolicExpression<F> {
312    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
313        iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
314    }
315}
316
317impl<F: Field> Product<F> for SymbolicExpression<F> {
318    fn product<I: Iterator<Item = F>>(iter: I) -> Self {
319        iter.map(|x| Self::from(x)).product()
320    }
321}
322
// Marker impls with empty bodies: they declare that expressions form an algebra over the
// field `F` and over `SymbolicVariable<F>`.
impl<F: Field> Algebra<F> for SymbolicExpression<F> {}
impl<F: Field> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> {}
325
326pub trait SymbolicEvaluator<F, E>
327where
328    F: Field,
329    E: Add<E, Output = E> + Sub<E, Output = E> + Mul<E, Output = E> + Neg<Output = E>,
330{
331    fn eval_const(&self, c: F) -> E;
332    fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> E;
333    fn eval_is_first_row(&self) -> E;
334    fn eval_is_last_row(&self) -> E;
335    fn eval_is_transition(&self) -> E;
336
337    fn eval_expr(&self, symbolic_expr: &SymbolicExpression<F>) -> E {
338        match symbolic_expr {
339            SymbolicExpression::Variable(var) => self.eval_var(*var),
340            SymbolicExpression::Constant(c) => self.eval_const(*c),
341            SymbolicExpression::Add { x, y, .. } => self.eval_expr(x) + self.eval_expr(y),
342            SymbolicExpression::Sub { x, y, .. } => self.eval_expr(x) - self.eval_expr(y),
343            SymbolicExpression::Neg { x, .. } => -self.eval_expr(x),
344            SymbolicExpression::Mul { x, y, .. } => self.eval_expr(x) * self.eval_expr(y),
345            SymbolicExpression::IsFirstRow => self.eval_is_first_row(),
346            SymbolicExpression::IsLastRow => self.eval_is_last_row(),
347            SymbolicExpression::IsTransition => self.eval_is_transition(),
348        }
349    }
350
351    /// Assumes that `nodes` are in topological order (if B references A, then B comes after A).
352    /// Simple serial evaluation in order.
353    fn eval_nodes(&self, nodes: &[SymbolicExpressionNode<F>]) -> Vec<E>
354    where
355        E: Clone,
356    {
357        let mut exprs: Vec<E> = Vec::with_capacity(nodes.len());
358        for node in nodes {
359            let expr = match *node {
360                SymbolicExpressionNode::Variable(var) => self.eval_var(var),
361                SymbolicExpressionNode::Constant(c) => self.eval_const(c),
362                SymbolicExpressionNode::Add {
363                    left_idx,
364                    right_idx,
365                    ..
366                } => exprs[left_idx].clone() + exprs[right_idx].clone(),
367                SymbolicExpressionNode::Sub {
368                    left_idx,
369                    right_idx,
370                    ..
371                } => exprs[left_idx].clone() - exprs[right_idx].clone(),
372                SymbolicExpressionNode::Neg { idx, .. } => -exprs[idx].clone(),
373                SymbolicExpressionNode::Mul {
374                    left_idx,
375                    right_idx,
376                    ..
377                } => exprs[left_idx].clone() * exprs[right_idx].clone(),
378                SymbolicExpressionNode::IsFirstRow => self.eval_is_first_row(),
379                SymbolicExpressionNode::IsLastRow => self.eval_is_last_row(),
380                SymbolicExpressionNode::IsTransition => self.eval_is_transition(),
381            };
382            exprs.push(expr);
383        }
384        exprs
385    }
386}