// openvm_stark_backend/air_builders/symbolic/symbolic_expression.rs

// Copied from uni-stark/src/symbolic_expression.rs to use Arc instead of Rc.

use core::{
    fmt::Debug,
    iter::{Product, Sum},
    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
use std::{
    hash::{Hash, Hasher},
    ptr,
    sync::Arc,
};

use p3_field::{Field, FieldAlgebra};
use serde::{Deserialize, Serialize};

use super::{dag::SymbolicExpressionNode, symbolic_variable::SymbolicVariable};

/// An expression over `SymbolicVariable`s.
///
/// Leaves are variables, field constants and the three row selectors. Inner nodes are
/// arithmetic operations whose operands are held behind `Arc`, so one sub-expression can be
/// shared by several parents without being copied. Every inner node stores a cached
/// `degree_multiple` (read back by `SymbolicExpression::degree_multiple`).
// Note: avoid deriving Hash because it will hash the entire sub-tree; the manual `Hash` impl
// in this file hashes the child `Arc` pointers instead.
#[derive(Clone, Debug, Serialize, Deserialize)]
// Replace serde's inferred generic bounds with just `F: Field`.
#[serde(bound = "F: Field")]
pub enum SymbolicExpression<F> {
    /// A symbolic variable leaf.
    Variable(SymbolicVariable<F>),
    /// First-row selector leaf (degree multiple 1); interpreted by
    /// `SymbolicEvaluator::eval_is_first_row`.
    IsFirstRow,
    /// Last-row selector leaf (degree multiple 1); interpreted by
    /// `SymbolicEvaluator::eval_is_last_row`.
    IsLastRow,
    /// Transition selector leaf (degree multiple 0); interpreted by
    /// `SymbolicEvaluator::eval_is_transition`.
    IsTransition,
    /// A field-constant leaf (degree multiple 0).
    Constant(F),
    /// `x + y`.
    Add {
        x: Arc<Self>,
        y: Arc<Self>,
        /// Cached degree multiple; the `+` operator sets it to the max of the operands'.
        degree_multiple: usize,
    },
    /// `x - y`.
    Sub {
        x: Arc<Self>,
        y: Arc<Self>,
        /// Cached degree multiple; the `-` operator sets it to the max of the operands'.
        degree_multiple: usize,
    },
    /// `-x`.
    Neg {
        x: Arc<Self>,
        /// Cached degree multiple; unary `-` copies the operand's.
        degree_multiple: usize,
    },
    /// `x * y`.
    Mul {
        x: Arc<Self>,
        y: Arc<Self>,
        /// Cached degree multiple; the `*` operator sets it to the sum of the operands'.
        degree_multiple: usize,
    },
}

50impl<F: Field> Hash for SymbolicExpression<F> {
51    fn hash<H: Hasher>(&self, state: &mut H) {
52        // First hash the discriminant of the enum
53        std::mem::discriminant(self).hash(state);
54        // Degree multiple is not necessary
55        match self {
56            Self::Variable(v) => v.hash(state),
57            Self::IsFirstRow => {}   // discriminant is enough
58            Self::IsLastRow => {}    // discriminant is enough
59            Self::IsTransition => {} // discriminant is enough
60            Self::Constant(f) => f.hash(state),
61            Self::Add { x, y, .. } => {
62                ptr::hash(&**x, state);
63                ptr::hash(&**y, state);
64            }
65            Self::Sub { x, y, .. } => {
66                ptr::hash(&**x, state);
67                ptr::hash(&**y, state);
68            }
69            Self::Neg { x, .. } => {
70                ptr::hash(&**x, state);
71            }
72            Self::Mul { x, y, .. } => {
73                ptr::hash(&**x, state);
74                ptr::hash(&**y, state);
75            }
76        }
77    }
78}
79
80// We intentionally do not compare degree_multiple in PartialEq and Eq because degree_multiple is
81// metadata used for optimization/debugging purposes but it does not change the underlying
82// expression.
83impl<F: Field> PartialEq for SymbolicExpression<F> {
84    fn eq(&self, other: &Self) -> bool {
85        // First check if the variants match
86        if std::mem::discriminant(self) != std::mem::discriminant(other) {
87            return false;
88        }
89
90        // Then check equality based on variant-specific data
91        match (self, other) {
92            (Self::Variable(v1), Self::Variable(v2)) => v1 == v2,
93            // IsFirstRow, IsLastRow, and IsTransition are all unit variants,
94            // so if the discriminants match, they're equal
95            (Self::IsFirstRow, Self::IsFirstRow) => true,
96            (Self::IsLastRow, Self::IsLastRow) => true,
97            (Self::IsTransition, Self::IsTransition) => true,
98            (Self::Constant(c1), Self::Constant(c2)) => c1 == c2,
99            // For compound expressions, compare pointers to match how Hash is implemented
100            (Self::Add { x: x1, y: y1, .. }, Self::Add { x: x2, y: y2, .. }) => {
101                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
102            }
103            (Self::Sub { x: x1, y: y1, .. }, Self::Sub { x: x2, y: y2, .. }) => {
104                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
105            }
106            (Self::Neg { x: x1, .. }, Self::Neg { x: x2, .. }) => Arc::ptr_eq(x1, x2),
107            (Self::Mul { x: x1, y: y1, .. }, Self::Mul { x: x2, y: y2, .. }) => {
108                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
109            }
110            // This should never be reached because we've already checked the discriminants
111            _ => false,
112        }
113    }
114}
115
// Marker impl. `SymbolicExpression`'s `PartialEq` compares leaves by value and compound nodes
// by child-pointer identity; this declares that relation a full equivalence, which (together
// with the matching pointer-based `Hash` impl) lets the type serve as a hash-map/set key.
impl<F: Field> Eq for SymbolicExpression<F> {}

118impl<F: Field> SymbolicExpression<F> {
119    /// Returns the multiple of `n` (the trace length) in this expression's degree.
120    pub const fn degree_multiple(&self) -> usize {
121        match self {
122            SymbolicExpression::Variable(v) => v.degree_multiple(),
123            SymbolicExpression::IsFirstRow => 1,
124            SymbolicExpression::IsLastRow => 1,
125            SymbolicExpression::IsTransition => 0,
126            SymbolicExpression::Constant(_) => 0,
127            SymbolicExpression::Add {
128                degree_multiple, ..
129            } => *degree_multiple,
130            SymbolicExpression::Sub {
131                degree_multiple, ..
132            } => *degree_multiple,
133            SymbolicExpression::Neg {
134                degree_multiple, ..
135            } => *degree_multiple,
136            SymbolicExpression::Mul {
137                degree_multiple, ..
138            } => *degree_multiple,
139        }
140    }
141}
142
143impl<F: Field> Default for SymbolicExpression<F> {
144    fn default() -> Self {
145        Self::Constant(F::ZERO)
146    }
147}
148
149impl<F: Field> From<F> for SymbolicExpression<F> {
150    fn from(value: F) -> Self {
151        Self::Constant(value)
152    }
153}
154
155impl<F: Field> FieldAlgebra for SymbolicExpression<F> {
156    type F = F;
157
158    const ZERO: Self = Self::Constant(F::ZERO);
159    const ONE: Self = Self::Constant(F::ONE);
160    const TWO: Self = Self::Constant(F::TWO);
161    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
162
163    #[inline]
164    fn from_f(f: Self::F) -> Self {
165        f.into()
166    }
167
168    fn from_bool(b: bool) -> Self {
169        Self::Constant(F::from_bool(b))
170    }
171
172    fn from_canonical_u8(n: u8) -> Self {
173        Self::Constant(F::from_canonical_u8(n))
174    }
175
176    fn from_canonical_u16(n: u16) -> Self {
177        Self::Constant(F::from_canonical_u16(n))
178    }
179
180    fn from_canonical_u32(n: u32) -> Self {
181        Self::Constant(F::from_canonical_u32(n))
182    }
183
184    fn from_canonical_u64(n: u64) -> Self {
185        Self::Constant(F::from_canonical_u64(n))
186    }
187
188    fn from_canonical_usize(n: usize) -> Self {
189        Self::Constant(F::from_canonical_usize(n))
190    }
191
192    fn from_wrapped_u32(n: u32) -> Self {
193        Self::Constant(F::from_wrapped_u32(n))
194    }
195
196    fn from_wrapped_u64(n: u64) -> Self {
197        Self::Constant(F::from_wrapped_u64(n))
198    }
199}
200
201impl<F: Field> Add for SymbolicExpression<F> {
202    type Output = Self;
203
204    fn add(self, rhs: Self) -> Self {
205        let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
206        Self::Add {
207            x: Arc::new(self),
208            y: Arc::new(rhs),
209            degree_multiple,
210        }
211    }
212}
213
214impl<F: Field> Add<F> for SymbolicExpression<F> {
215    type Output = Self;
216
217    fn add(self, rhs: F) -> Self {
218        self + Self::from(rhs)
219    }
220}
221
222impl<F: Field> AddAssign for SymbolicExpression<F> {
223    fn add_assign(&mut self, rhs: Self) {
224        *self = self.clone() + rhs;
225    }
226}
227
228impl<F: Field> AddAssign<F> for SymbolicExpression<F> {
229    fn add_assign(&mut self, rhs: F) {
230        *self += Self::from(rhs);
231    }
232}
233
234impl<F: Field> Sum for SymbolicExpression<F> {
235    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
236        iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
237    }
238}
239
240impl<F: Field> Sum<F> for SymbolicExpression<F> {
241    fn sum<I: Iterator<Item = F>>(iter: I) -> Self {
242        iter.map(|x| Self::from(x)).sum()
243    }
244}
245
246impl<F: Field> Sub for SymbolicExpression<F> {
247    type Output = Self;
248
249    fn sub(self, rhs: Self) -> Self {
250        let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
251        Self::Sub {
252            x: Arc::new(self),
253            y: Arc::new(rhs),
254            degree_multiple,
255        }
256    }
257}
258
259impl<F: Field> Sub<F> for SymbolicExpression<F> {
260    type Output = Self;
261
262    fn sub(self, rhs: F) -> Self {
263        self - Self::from(rhs)
264    }
265}
266
267impl<F: Field> SubAssign for SymbolicExpression<F> {
268    fn sub_assign(&mut self, rhs: Self) {
269        *self = self.clone() - rhs;
270    }
271}
272
273impl<F: Field> SubAssign<F> for SymbolicExpression<F> {
274    fn sub_assign(&mut self, rhs: F) {
275        *self -= Self::from(rhs);
276    }
277}
278
279impl<F: Field> Neg for SymbolicExpression<F> {
280    type Output = Self;
281
282    fn neg(self) -> Self {
283        let degree_multiple = self.degree_multiple();
284        Self::Neg {
285            x: Arc::new(self),
286            degree_multiple,
287        }
288    }
289}
290
291impl<F: Field> Mul for SymbolicExpression<F> {
292    type Output = Self;
293
294    fn mul(self, rhs: Self) -> Self {
295        #[allow(clippy::suspicious_arithmetic_impl)]
296        let degree_multiple = self.degree_multiple() + rhs.degree_multiple();
297        Self::Mul {
298            x: Arc::new(self),
299            y: Arc::new(rhs),
300            degree_multiple,
301        }
302    }
303}
304
305impl<F: Field> Mul<F> for SymbolicExpression<F> {
306    type Output = Self;
307
308    fn mul(self, rhs: F) -> Self {
309        self * Self::from(rhs)
310    }
311}
312
313impl<F: Field> MulAssign for SymbolicExpression<F> {
314    fn mul_assign(&mut self, rhs: Self) {
315        *self = self.clone() * rhs;
316    }
317}
318
319impl<F: Field> MulAssign<F> for SymbolicExpression<F> {
320    fn mul_assign(&mut self, rhs: F) {
321        *self *= Self::from(rhs);
322    }
323}
324
325impl<F: Field> Product for SymbolicExpression<F> {
326    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
327        iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
328    }
329}
330
331impl<F: Field> Product<F> for SymbolicExpression<F> {
332    fn product<I: Iterator<Item = F>>(iter: I) -> Self {
333        iter.map(|x| Self::from(x)).product()
334    }
335}
336
337pub trait SymbolicEvaluator<F, E>
338where
339    F: Field,
340    E: Add<E, Output = E> + Sub<E, Output = E> + Mul<E, Output = E> + Neg<Output = E>,
341{
342    fn eval_const(&self, c: F) -> E;
343    fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> E;
344    fn eval_is_first_row(&self) -> E;
345    fn eval_is_last_row(&self) -> E;
346    fn eval_is_transition(&self) -> E;
347
348    fn eval_expr(&self, symbolic_expr: &SymbolicExpression<F>) -> E {
349        match symbolic_expr {
350            SymbolicExpression::Variable(var) => self.eval_var(*var),
351            SymbolicExpression::Constant(c) => self.eval_const(*c),
352            SymbolicExpression::Add { x, y, .. } => self.eval_expr(x) + self.eval_expr(y),
353            SymbolicExpression::Sub { x, y, .. } => self.eval_expr(x) - self.eval_expr(y),
354            SymbolicExpression::Neg { x, .. } => -self.eval_expr(x),
355            SymbolicExpression::Mul { x, y, .. } => self.eval_expr(x) * self.eval_expr(y),
356            SymbolicExpression::IsFirstRow => self.eval_is_first_row(),
357            SymbolicExpression::IsLastRow => self.eval_is_last_row(),
358            SymbolicExpression::IsTransition => self.eval_is_transition(),
359        }
360    }
361
362    /// Assumes that `nodes` are in topological order (if B references A, then B comes after A).
363    /// Simple serial evaluation in order.
364    fn eval_nodes(&self, nodes: &[SymbolicExpressionNode<F>]) -> Vec<E>
365    where
366        E: Clone,
367    {
368        let mut exprs: Vec<E> = Vec::with_capacity(nodes.len());
369        for node in nodes {
370            let expr = match *node {
371                SymbolicExpressionNode::Variable(var) => self.eval_var(var),
372                SymbolicExpressionNode::Constant(c) => self.eval_const(c),
373                SymbolicExpressionNode::Add {
374                    left_idx,
375                    right_idx,
376                    ..
377                } => exprs[left_idx].clone() + exprs[right_idx].clone(),
378                SymbolicExpressionNode::Sub {
379                    left_idx,
380                    right_idx,
381                    ..
382                } => exprs[left_idx].clone() - exprs[right_idx].clone(),
383                SymbolicExpressionNode::Neg { idx, .. } => -exprs[idx].clone(),
384                SymbolicExpressionNode::Mul {
385                    left_idx,
386                    right_idx,
387                    ..
388                } => exprs[left_idx].clone() * exprs[right_idx].clone(),
389                SymbolicExpressionNode::IsFirstRow => self.eval_is_first_row(),
390                SymbolicExpressionNode::IsLastRow => self.eval_is_last_row(),
391                SymbolicExpressionNode::IsTransition => self.eval_is_transition(),
392            };
393            exprs.push(expr);
394        }
395        exprs
396    }
397}