openvm_stark_backend/air_builders/symbolic/
symbolic_expression.rs

1// Copied from uni-stark/src/symbolic_expression.rs to use Arc instead of Rc.
2
3use core::{
4    fmt::Debug,
5    iter::{Product, Sum},
6    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
7};
8use std::{
9    hash::{Hash, Hasher},
10    ptr,
11    sync::Arc,
12};
13
14use p3_field::{Field, FieldAlgebra};
15use serde::{Deserialize, Serialize};
16
17use super::{dag::SymbolicExpressionNode, symbolic_variable::SymbolicVariable};
18
19/// An expression over `SymbolicVariable`s.
20// Note: avoid deriving Hash because it will hash the entire sub-tree
21#[derive(Clone, Debug, Serialize, Deserialize)]
22#[serde(bound = "F: Field")]
23pub enum SymbolicExpression<F> {
24    Variable(SymbolicVariable<F>),
25    IsFirstRow,
26    IsLastRow,
27    IsTransition,
28    Constant(F),
29    Add {
30        x: Arc<Self>,
31        y: Arc<Self>,
32        degree_multiple: usize,
33    },
34    Sub {
35        x: Arc<Self>,
36        y: Arc<Self>,
37        degree_multiple: usize,
38    },
39    Neg {
40        x: Arc<Self>,
41        degree_multiple: usize,
42    },
43    Mul {
44        x: Arc<Self>,
45        y: Arc<Self>,
46        degree_multiple: usize,
47    },
48}
49
50impl<F: Field> Hash for SymbolicExpression<F> {
51    fn hash<H: Hasher>(&self, state: &mut H) {
52        // First hash the discriminant of the enum
53        std::mem::discriminant(self).hash(state);
54        // Degree multiple is not necessary
55        match self {
56            Self::Variable(v) => v.hash(state),
57            Self::IsFirstRow => {}   // discriminant is enough
58            Self::IsLastRow => {}    // discriminant is enough
59            Self::IsTransition => {} // discriminant is enough
60            Self::Constant(f) => f.hash(state),
61            Self::Add { x, y, .. } => {
62                ptr::hash(&**x, state);
63                ptr::hash(&**y, state);
64            }
65            Self::Sub { x, y, .. } => {
66                ptr::hash(&**x, state);
67                ptr::hash(&**y, state);
68            }
69            Self::Neg { x, .. } => {
70                ptr::hash(&**x, state);
71            }
72            Self::Mul { x, y, .. } => {
73                ptr::hash(&**x, state);
74                ptr::hash(&**y, state);
75            }
76        }
77    }
78}
79
80// We intentionally do not compare degree_multiple in PartialEq and Eq because degree_multiple is
81// metadata used for optimization/debugging purposes but it does not change the underlying expression.
82impl<F: Field> PartialEq for SymbolicExpression<F> {
83    fn eq(&self, other: &Self) -> bool {
84        // First check if the variants match
85        if std::mem::discriminant(self) != std::mem::discriminant(other) {
86            return false;
87        }
88
89        // Then check equality based on variant-specific data
90        match (self, other) {
91            (Self::Variable(v1), Self::Variable(v2)) => v1 == v2,
92            // IsFirstRow, IsLastRow, and IsTransition are all unit variants,
93            // so if the discriminants match, they're equal
94            (Self::IsFirstRow, Self::IsFirstRow) => true,
95            (Self::IsLastRow, Self::IsLastRow) => true,
96            (Self::IsTransition, Self::IsTransition) => true,
97            (Self::Constant(c1), Self::Constant(c2)) => c1 == c2,
98            // For compound expressions, compare pointers to match how Hash is implemented
99            (Self::Add { x: x1, y: y1, .. }, Self::Add { x: x2, y: y2, .. }) => {
100                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
101            }
102            (Self::Sub { x: x1, y: y1, .. }, Self::Sub { x: x2, y: y2, .. }) => {
103                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
104            }
105            (Self::Neg { x: x1, .. }, Self::Neg { x: x2, .. }) => Arc::ptr_eq(x1, x2),
106            (Self::Mul { x: x1, y: y1, .. }, Self::Mul { x: x2, y: y2, .. }) => {
107                Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
108            }
109            // This should never be reached because we've already checked the discriminants
110            _ => false,
111        }
112    }
113}
114
115impl<F: Field> Eq for SymbolicExpression<F> {}
116
117impl<F: Field> SymbolicExpression<F> {
118    /// Returns the multiple of `n` (the trace length) in this expression's degree.
119    pub const fn degree_multiple(&self) -> usize {
120        match self {
121            SymbolicExpression::Variable(v) => v.degree_multiple(),
122            SymbolicExpression::IsFirstRow => 1,
123            SymbolicExpression::IsLastRow => 1,
124            SymbolicExpression::IsTransition => 0,
125            SymbolicExpression::Constant(_) => 0,
126            SymbolicExpression::Add {
127                degree_multiple, ..
128            } => *degree_multiple,
129            SymbolicExpression::Sub {
130                degree_multiple, ..
131            } => *degree_multiple,
132            SymbolicExpression::Neg {
133                degree_multiple, ..
134            } => *degree_multiple,
135            SymbolicExpression::Mul {
136                degree_multiple, ..
137            } => *degree_multiple,
138        }
139    }
140}
141
142impl<F: Field> Default for SymbolicExpression<F> {
143    fn default() -> Self {
144        Self::Constant(F::ZERO)
145    }
146}
147
148impl<F: Field> From<F> for SymbolicExpression<F> {
149    fn from(value: F) -> Self {
150        Self::Constant(value)
151    }
152}
153
154impl<F: Field> FieldAlgebra for SymbolicExpression<F> {
155    type F = F;
156
157    const ZERO: Self = Self::Constant(F::ZERO);
158    const ONE: Self = Self::Constant(F::ONE);
159    const TWO: Self = Self::Constant(F::TWO);
160    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
161
162    #[inline]
163    fn from_f(f: Self::F) -> Self {
164        f.into()
165    }
166
167    fn from_bool(b: bool) -> Self {
168        Self::Constant(F::from_bool(b))
169    }
170
171    fn from_canonical_u8(n: u8) -> Self {
172        Self::Constant(F::from_canonical_u8(n))
173    }
174
175    fn from_canonical_u16(n: u16) -> Self {
176        Self::Constant(F::from_canonical_u16(n))
177    }
178
179    fn from_canonical_u32(n: u32) -> Self {
180        Self::Constant(F::from_canonical_u32(n))
181    }
182
183    fn from_canonical_u64(n: u64) -> Self {
184        Self::Constant(F::from_canonical_u64(n))
185    }
186
187    fn from_canonical_usize(n: usize) -> Self {
188        Self::Constant(F::from_canonical_usize(n))
189    }
190
191    fn from_wrapped_u32(n: u32) -> Self {
192        Self::Constant(F::from_wrapped_u32(n))
193    }
194
195    fn from_wrapped_u64(n: u64) -> Self {
196        Self::Constant(F::from_wrapped_u64(n))
197    }
198}
199
200impl<F: Field> Add for SymbolicExpression<F> {
201    type Output = Self;
202
203    fn add(self, rhs: Self) -> Self {
204        let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
205        Self::Add {
206            x: Arc::new(self),
207            y: Arc::new(rhs),
208            degree_multiple,
209        }
210    }
211}
212
213impl<F: Field> Add<F> for SymbolicExpression<F> {
214    type Output = Self;
215
216    fn add(self, rhs: F) -> Self {
217        self + Self::from(rhs)
218    }
219}
220
221impl<F: Field> AddAssign for SymbolicExpression<F> {
222    fn add_assign(&mut self, rhs: Self) {
223        *self = self.clone() + rhs;
224    }
225}
226
227impl<F: Field> AddAssign<F> for SymbolicExpression<F> {
228    fn add_assign(&mut self, rhs: F) {
229        *self += Self::from(rhs);
230    }
231}
232
233impl<F: Field> Sum for SymbolicExpression<F> {
234    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
235        iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
236    }
237}
238
239impl<F: Field> Sum<F> for SymbolicExpression<F> {
240    fn sum<I: Iterator<Item = F>>(iter: I) -> Self {
241        iter.map(|x| Self::from(x)).sum()
242    }
243}
244
245impl<F: Field> Sub for SymbolicExpression<F> {
246    type Output = Self;
247
248    fn sub(self, rhs: Self) -> Self {
249        let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
250        Self::Sub {
251            x: Arc::new(self),
252            y: Arc::new(rhs),
253            degree_multiple,
254        }
255    }
256}
257
258impl<F: Field> Sub<F> for SymbolicExpression<F> {
259    type Output = Self;
260
261    fn sub(self, rhs: F) -> Self {
262        self - Self::from(rhs)
263    }
264}
265
266impl<F: Field> SubAssign for SymbolicExpression<F> {
267    fn sub_assign(&mut self, rhs: Self) {
268        *self = self.clone() - rhs;
269    }
270}
271
272impl<F: Field> SubAssign<F> for SymbolicExpression<F> {
273    fn sub_assign(&mut self, rhs: F) {
274        *self -= Self::from(rhs);
275    }
276}
277
278impl<F: Field> Neg for SymbolicExpression<F> {
279    type Output = Self;
280
281    fn neg(self) -> Self {
282        let degree_multiple = self.degree_multiple();
283        Self::Neg {
284            x: Arc::new(self),
285            degree_multiple,
286        }
287    }
288}
289
290impl<F: Field> Mul for SymbolicExpression<F> {
291    type Output = Self;
292
293    fn mul(self, rhs: Self) -> Self {
294        #[allow(clippy::suspicious_arithmetic_impl)]
295        let degree_multiple = self.degree_multiple() + rhs.degree_multiple();
296        Self::Mul {
297            x: Arc::new(self),
298            y: Arc::new(rhs),
299            degree_multiple,
300        }
301    }
302}
303
304impl<F: Field> Mul<F> for SymbolicExpression<F> {
305    type Output = Self;
306
307    fn mul(self, rhs: F) -> Self {
308        self * Self::from(rhs)
309    }
310}
311
312impl<F: Field> MulAssign for SymbolicExpression<F> {
313    fn mul_assign(&mut self, rhs: Self) {
314        *self = self.clone() * rhs;
315    }
316}
317
318impl<F: Field> MulAssign<F> for SymbolicExpression<F> {
319    fn mul_assign(&mut self, rhs: F) {
320        *self *= Self::from(rhs);
321    }
322}
323
324impl<F: Field> Product for SymbolicExpression<F> {
325    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
326        iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
327    }
328}
329
330impl<F: Field> Product<F> for SymbolicExpression<F> {
331    fn product<I: Iterator<Item = F>>(iter: I) -> Self {
332        iter.map(|x| Self::from(x)).product()
333    }
334}
335
336pub trait SymbolicEvaluator<F, E>
337where
338    F: Field,
339    E: Add<E, Output = E> + Sub<E, Output = E> + Mul<E, Output = E> + Neg<Output = E>,
340{
341    fn eval_const(&self, c: F) -> E;
342    fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> E;
343    fn eval_is_first_row(&self) -> E;
344    fn eval_is_last_row(&self) -> E;
345    fn eval_is_transition(&self) -> E;
346
347    fn eval_expr(&self, symbolic_expr: &SymbolicExpression<F>) -> E {
348        match symbolic_expr {
349            SymbolicExpression::Variable(var) => self.eval_var(*var),
350            SymbolicExpression::Constant(c) => self.eval_const(*c),
351            SymbolicExpression::Add { x, y, .. } => self.eval_expr(x) + self.eval_expr(y),
352            SymbolicExpression::Sub { x, y, .. } => self.eval_expr(x) - self.eval_expr(y),
353            SymbolicExpression::Neg { x, .. } => -self.eval_expr(x),
354            SymbolicExpression::Mul { x, y, .. } => self.eval_expr(x) * self.eval_expr(y),
355            SymbolicExpression::IsFirstRow => self.eval_is_first_row(),
356            SymbolicExpression::IsLastRow => self.eval_is_last_row(),
357            SymbolicExpression::IsTransition => self.eval_is_transition(),
358        }
359    }
360
361    /// Assumes that `nodes` are in topological order (if B references A, then B comes after A).
362    /// Simple serial evaluation in order.
363    fn eval_nodes(&self, nodes: &[SymbolicExpressionNode<F>]) -> Vec<E>
364    where
365        E: Clone,
366    {
367        let mut exprs: Vec<E> = Vec::with_capacity(nodes.len());
368        for node in nodes {
369            let expr = match *node {
370                SymbolicExpressionNode::Variable(var) => self.eval_var(var),
371                SymbolicExpressionNode::Constant(c) => self.eval_const(c),
372                SymbolicExpressionNode::Add {
373                    left_idx,
374                    right_idx,
375                    ..
376                } => exprs[left_idx].clone() + exprs[right_idx].clone(),
377                SymbolicExpressionNode::Sub {
378                    left_idx,
379                    right_idx,
380                    ..
381                } => exprs[left_idx].clone() - exprs[right_idx].clone(),
382                SymbolicExpressionNode::Neg { idx, .. } => -exprs[idx].clone(),
383                SymbolicExpressionNode::Mul {
384                    left_idx,
385                    right_idx,
386                    ..
387                } => exprs[left_idx].clone() * exprs[right_idx].clone(),
388                SymbolicExpressionNode::IsFirstRow => self.eval_is_first_row(),
389                SymbolicExpressionNode::IsLastRow => self.eval_is_last_row(),
390                SymbolicExpressionNode::IsTransition => self.eval_is_transition(),
391            };
392            exprs.push(expr);
393        }
394        exprs
395    }
396}