p3_uni_stark/
symbolic_expression.rs

1use alloc::rc::Rc;
2use core::cmp;
3use core::fmt::Debug;
4use core::iter::{Product, Sum};
5use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
6
7use p3_field::{Field, FieldAlgebra};
8
9use crate::symbolic_variable::SymbolicVariable;
10
11/// An expression over `SymbolicVariable`s.
12#[derive(Clone, Debug)]
13pub enum SymbolicExpression<F> {
14    Variable(SymbolicVariable<F>),
15    IsFirstRow,
16    IsLastRow,
17    IsTransition,
18    Constant(F),
19    Add {
20        x: Rc<Self>,
21        y: Rc<Self>,
22        degree_multiple: usize,
23    },
24    Sub {
25        x: Rc<Self>,
26        y: Rc<Self>,
27        degree_multiple: usize,
28    },
29    Neg {
30        x: Rc<Self>,
31        degree_multiple: usize,
32    },
33    Mul {
34        x: Rc<Self>,
35        y: Rc<Self>,
36        degree_multiple: usize,
37    },
38}
39
40impl<F> SymbolicExpression<F> {
41    /// Returns the multiple of `n` (the trace length) in this expression's degree.
42    pub const fn degree_multiple(&self) -> usize {
43        match self {
44            SymbolicExpression::Variable(v) => v.degree_multiple(),
45            SymbolicExpression::IsFirstRow => 1,
46            SymbolicExpression::IsLastRow => 1,
47            SymbolicExpression::IsTransition => 0,
48            SymbolicExpression::Constant(_) => 0,
49            SymbolicExpression::Add {
50                degree_multiple, ..
51            } => *degree_multiple,
52            SymbolicExpression::Sub {
53                degree_multiple, ..
54            } => *degree_multiple,
55            SymbolicExpression::Neg {
56                degree_multiple, ..
57            } => *degree_multiple,
58            SymbolicExpression::Mul {
59                degree_multiple, ..
60            } => *degree_multiple,
61        }
62    }
63}
64
65impl<F: Field> Default for SymbolicExpression<F> {
66    fn default() -> Self {
67        Self::Constant(F::ZERO)
68    }
69}
70
71impl<F: Field> From<F> for SymbolicExpression<F> {
72    fn from(value: F) -> Self {
73        Self::Constant(value)
74    }
75}
76
77impl<F: Field> FieldAlgebra for SymbolicExpression<F> {
78    type F = F;
79
80    const ZERO: Self = Self::Constant(F::ZERO);
81    const ONE: Self = Self::Constant(F::ONE);
82    const TWO: Self = Self::Constant(F::TWO);
83    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
84
85    #[inline]
86    fn from_f(f: Self::F) -> Self {
87        f.into()
88    }
89
90    fn from_canonical_u8(n: u8) -> Self {
91        Self::Constant(F::from_canonical_u8(n))
92    }
93
94    fn from_canonical_u16(n: u16) -> Self {
95        Self::Constant(F::from_canonical_u16(n))
96    }
97
98    fn from_canonical_u32(n: u32) -> Self {
99        Self::Constant(F::from_canonical_u32(n))
100    }
101
102    fn from_canonical_u64(n: u64) -> Self {
103        Self::Constant(F::from_canonical_u64(n))
104    }
105
106    fn from_canonical_usize(n: usize) -> Self {
107        Self::Constant(F::from_canonical_usize(n))
108    }
109
110    fn from_wrapped_u32(n: u32) -> Self {
111        Self::Constant(F::from_wrapped_u32(n))
112    }
113
114    fn from_wrapped_u64(n: u64) -> Self {
115        Self::Constant(F::from_wrapped_u64(n))
116    }
117}
118
119impl<F: Field, T> Add<T> for SymbolicExpression<F>
120where
121    T: Into<Self>,
122{
123    type Output = Self;
124
125    fn add(self, rhs: T) -> Self {
126        let rhs = rhs.into();
127        match (self, rhs) {
128            (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs + rhs),
129            (lhs, rhs) => {
130                let degree_multiple = cmp::max(lhs.degree_multiple(), rhs.degree_multiple());
131                Self::Add {
132                    x: Rc::new(lhs),
133                    y: Rc::new(rhs),
134                    degree_multiple,
135                }
136            }
137        }
138    }
139}
140
141impl<F: Field, T> AddAssign<T> for SymbolicExpression<F>
142where
143    T: Into<Self>,
144{
145    fn add_assign(&mut self, rhs: T) {
146        *self = self.clone() + rhs.into();
147    }
148}
149
150impl<F: Field, T> Sum<T> for SymbolicExpression<F>
151where
152    T: Into<Self>,
153{
154    fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
155        iter.map(Into::into)
156            .reduce(|x, y| x + y)
157            .unwrap_or(Self::ZERO)
158    }
159}
160
161impl<F: Field, T> Sub<T> for SymbolicExpression<F>
162where
163    T: Into<Self>,
164{
165    type Output = Self;
166
167    fn sub(self, rhs: T) -> Self {
168        let rhs = rhs.into();
169        match (self, rhs) {
170            (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs - rhs),
171            (lhs, rhs) => {
172                let degree_multiple = cmp::max(lhs.degree_multiple(), rhs.degree_multiple());
173                Self::Sub {
174                    x: Rc::new(lhs),
175                    y: Rc::new(rhs),
176                    degree_multiple,
177                }
178            }
179        }
180    }
181}
182
183impl<F: Field, T> SubAssign<T> for SymbolicExpression<F>
184where
185    T: Into<Self>,
186{
187    fn sub_assign(&mut self, rhs: T) {
188        *self = self.clone() - rhs.into();
189    }
190}
191
192impl<F: Field> Neg for SymbolicExpression<F> {
193    type Output = Self;
194
195    fn neg(self) -> Self {
196        match self {
197            Self::Constant(c) => Self::Constant(-c),
198            expr => {
199                let degree_multiple = expr.degree_multiple();
200                Self::Neg {
201                    x: Rc::new(expr),
202                    degree_multiple,
203                }
204            }
205        }
206    }
207}
208
209impl<F: Field, T> Mul<T> for SymbolicExpression<F>
210where
211    T: Into<Self>,
212{
213    type Output = Self;
214
215    fn mul(self, rhs: T) -> Self {
216        let rhs = rhs.into();
217        match (self, rhs) {
218            (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs * rhs),
219            (lhs, rhs) => {
220                #[allow(clippy::suspicious_arithmetic_impl)]
221                let degree_multiple = lhs.degree_multiple() + rhs.degree_multiple();
222                Self::Mul {
223                    x: Rc::new(lhs),
224                    y: Rc::new(rhs),
225                    degree_multiple,
226                }
227            }
228        }
229    }
230}
231
232impl<F: Field, T> MulAssign<T> for SymbolicExpression<F>
233where
234    T: Into<Self>,
235{
236    fn mul_assign(&mut self, rhs: T) {
237        *self = self.clone() * rhs.into();
238    }
239}
240
241impl<F: Field, T> Product<T> for SymbolicExpression<F>
242where
243    T: Into<Self>,
244{
245    fn product<I: Iterator<Item = T>>(iter: I) -> Self {
246        iter.map(Into::into)
247            .reduce(|x, y| x * y)
248            .unwrap_or(Self::ONE)
249    }
250}