1use core::{
4 fmt::Debug,
5 iter::{Product, Sum},
6 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
7};
8use std::{
9 hash::{Hash, Hasher},
10 ptr,
11 sync::Arc,
12};
13
14use p3_field::{Algebra, Field, PrimeCharacteristicRing};
15use serde::{Deserialize, Serialize};
16
17use super::{dag::SymbolicExpressionNode, symbolic_variable::SymbolicVariable};
18
/// A symbolic expression tree over the field `F`.
///
/// Leaves are variables, constants and the three row-selector markers; interior nodes are
/// `Add`, `Sub`, `Neg` and `Mul`, whose operands are shared through [`Arc`].
///
/// `Hash` and `PartialEq` are implemented by hand in this file rather than derived: for the
/// interior nodes they look only at the *addresses* of the operands, not at their contents.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub enum SymbolicExpression<F> {
    /// A leaf referring to a [`SymbolicVariable`].
    Variable(SymbolicVariable<F>),
    /// Leaf evaluated through `SymbolicEvaluator::eval_is_first_row`
    /// (presumably the first-row selector; confirm against the evaluator implementations).
    IsFirstRow,
    /// Leaf evaluated through `SymbolicEvaluator::eval_is_last_row`
    /// (presumably the last-row selector; confirm against the evaluator implementations).
    IsLastRow,
    /// Leaf evaluated through `SymbolicEvaluator::eval_is_transition`
    /// (presumably the transition selector; confirm against the evaluator implementations).
    IsTransition,
    /// A constant field element.
    Constant(F),
    /// The sum `x + y`.
    Add {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached degree multiple; the `Add` operator sets it to the maximum of the operands'.
        degree_multiple: usize,
    },
    /// The difference `x - y`.
    Sub {
        /// Left operand (minuend).
        x: Arc<Self>,
        /// Right operand (subtrahend).
        y: Arc<Self>,
        /// Cached degree multiple; the `Sub` operator sets it to the maximum of the operands'.
        degree_multiple: usize,
    },
    /// The negation `-x`.
    Neg {
        /// Operand being negated.
        x: Arc<Self>,
        /// Cached degree multiple; the `Neg` operator copies it from the operand.
        degree_multiple: usize,
    },
    /// The product `x * y`.
    Mul {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached degree multiple; the `Mul` operator sets it to the sum of the operands'.
        degree_multiple: usize,
    },
}
49
50impl<F: Field> Hash for SymbolicExpression<F> {
51 fn hash<H: Hasher>(&self, state: &mut H) {
52 std::mem::discriminant(self).hash(state);
54 match self {
56 Self::Variable(v) => v.hash(state),
57 Self::IsFirstRow => {} Self::IsLastRow => {} Self::IsTransition => {} Self::Constant(f) => f.hash(state),
61 Self::Add { x, y, .. } => {
62 ptr::hash(&**x, state);
63 ptr::hash(&**y, state);
64 }
65 Self::Sub { x, y, .. } => {
66 ptr::hash(&**x, state);
67 ptr::hash(&**y, state);
68 }
69 Self::Neg { x, .. } => {
70 ptr::hash(&**x, state);
71 }
72 Self::Mul { x, y, .. } => {
73 ptr::hash(&**x, state);
74 ptr::hash(&**y, state);
75 }
76 }
77 }
78}
79
80impl<F: Field> PartialEq for SymbolicExpression<F> {
84 fn eq(&self, other: &Self) -> bool {
85 if std::mem::discriminant(self) != std::mem::discriminant(other) {
87 return false;
88 }
89
90 match (self, other) {
92 (Self::Variable(v1), Self::Variable(v2)) => v1 == v2,
93 (Self::IsFirstRow, Self::IsFirstRow) => true,
96 (Self::IsLastRow, Self::IsLastRow) => true,
97 (Self::IsTransition, Self::IsTransition) => true,
98 (Self::Constant(c1), Self::Constant(c2)) => c1 == c2,
99 (Self::Add { x: x1, y: y1, .. }, Self::Add { x: x2, y: y2, .. }) => {
101 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
102 }
103 (Self::Sub { x: x1, y: y1, .. }, Self::Sub { x: x2, y: y2, .. }) => {
104 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
105 }
106 (Self::Neg { x: x1, .. }, Self::Neg { x: x2, .. }) => Arc::ptr_eq(x1, x2),
107 (Self::Mul { x: x1, y: y1, .. }, Self::Mul { x: x2, y: y2, .. }) => {
108 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
109 }
110 _ => false,
112 }
113 }
114}
115
116impl<F: Field> Eq for SymbolicExpression<F> {}
117
118impl<F: Field> SymbolicExpression<F> {
119 pub const fn degree_multiple(&self) -> usize {
121 match self {
122 SymbolicExpression::Variable(v) => v.degree_multiple(),
123 SymbolicExpression::IsFirstRow => 1,
124 SymbolicExpression::IsLastRow => 1,
125 SymbolicExpression::IsTransition => 0,
126 SymbolicExpression::Constant(_) => 0,
127 SymbolicExpression::Add {
128 degree_multiple, ..
129 } => *degree_multiple,
130 SymbolicExpression::Sub {
131 degree_multiple, ..
132 } => *degree_multiple,
133 SymbolicExpression::Neg {
134 degree_multiple, ..
135 } => *degree_multiple,
136 SymbolicExpression::Mul {
137 degree_multiple, ..
138 } => *degree_multiple,
139 }
140 }
141}
142
143impl<F: Field> Default for SymbolicExpression<F> {
144 fn default() -> Self {
145 Self::Constant(F::ZERO)
146 }
147}
148
149impl<F: Field> From<F> for SymbolicExpression<F> {
150 fn from(value: F) -> Self {
151 Self::Constant(value)
152 }
153}
154
155impl<F: Field> PrimeCharacteristicRing for SymbolicExpression<F> {
156 type PrimeSubfield = F::PrimeSubfield;
157
158 const ZERO: Self = Self::Constant(F::ZERO);
159 const ONE: Self = Self::Constant(F::ONE);
160 const TWO: Self = Self::Constant(F::TWO);
161 const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
162
163 #[inline]
164 fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
165 F::from_prime_subfield(f).into()
166 }
167}
168
169impl<F: Field> Add for SymbolicExpression<F> {
170 type Output = Self;
171
172 fn add(self, rhs: Self) -> Self {
173 let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
174 Self::Add {
175 x: Arc::new(self),
176 y: Arc::new(rhs),
177 degree_multiple,
178 }
179 }
180}
181
182impl<F: Field> Add<F> for SymbolicExpression<F> {
183 type Output = Self;
184
185 fn add(self, rhs: F) -> Self {
186 self + Self::from(rhs)
187 }
188}
189
190impl<F: Field> AddAssign for SymbolicExpression<F> {
191 fn add_assign(&mut self, rhs: Self) {
192 *self = self.clone() + rhs;
193 }
194}
195
196impl<F: Field> AddAssign<SymbolicVariable<F>> for SymbolicExpression<F> {
197 fn add_assign(&mut self, rhs: SymbolicVariable<F>) {
198 *self += SymbolicExpression::from(rhs);
199 }
200}
201
202impl<F: Field> AddAssign<F> for SymbolicExpression<F> {
203 fn add_assign(&mut self, rhs: F) {
204 *self += Self::from(rhs);
205 }
206}
207
208impl<F: Field> Sum for SymbolicExpression<F> {
209 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
210 iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
211 }
212}
213
214impl<F: Field> Sum<F> for SymbolicExpression<F> {
215 fn sum<I: Iterator<Item = F>>(iter: I) -> Self {
216 iter.map(|x| Self::from(x)).sum()
217 }
218}
219
220impl<F: Field> Sub for SymbolicExpression<F> {
221 type Output = Self;
222
223 fn sub(self, rhs: Self) -> Self {
224 let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
225 Self::Sub {
226 x: Arc::new(self),
227 y: Arc::new(rhs),
228 degree_multiple,
229 }
230 }
231}
232
233impl<F: Field> Sub<F> for SymbolicExpression<F> {
234 type Output = Self;
235
236 fn sub(self, rhs: F) -> Self {
237 self - Self::from(rhs)
238 }
239}
240
241impl<F: Field> SubAssign for SymbolicExpression<F> {
242 fn sub_assign(&mut self, rhs: Self) {
243 *self = self.clone() - rhs;
244 }
245}
246
247impl<F: Field> SubAssign<F> for SymbolicExpression<F> {
248 fn sub_assign(&mut self, rhs: F) {
249 *self -= Self::from(rhs);
250 }
251}
252
253impl<F: Field> SubAssign<SymbolicVariable<F>> for SymbolicExpression<F> {
254 fn sub_assign(&mut self, rhs: SymbolicVariable<F>) {
255 *self -= SymbolicExpression::from(rhs);
256 }
257}
258
259impl<F: Field> Neg for SymbolicExpression<F> {
260 type Output = Self;
261
262 fn neg(self) -> Self {
263 let degree_multiple = self.degree_multiple();
264 Self::Neg {
265 x: Arc::new(self),
266 degree_multiple,
267 }
268 }
269}
270
271impl<F: Field> Mul for SymbolicExpression<F> {
272 type Output = Self;
273
274 fn mul(self, rhs: Self) -> Self {
275 #[allow(clippy::suspicious_arithmetic_impl)]
276 let degree_multiple = self.degree_multiple() + rhs.degree_multiple();
277 Self::Mul {
278 x: Arc::new(self),
279 y: Arc::new(rhs),
280 degree_multiple,
281 }
282 }
283}
284
285impl<F: Field> Mul<F> for SymbolicExpression<F> {
286 type Output = Self;
287
288 fn mul(self, rhs: F) -> Self {
289 self * Self::from(rhs)
290 }
291}
292
293impl<F: Field> MulAssign for SymbolicExpression<F> {
294 fn mul_assign(&mut self, rhs: Self) {
295 *self = self.clone() * rhs;
296 }
297}
298
299impl<F: Field> MulAssign<F> for SymbolicExpression<F> {
300 fn mul_assign(&mut self, rhs: F) {
301 *self *= Self::from(rhs);
302 }
303}
304
305impl<F: Field> MulAssign<SymbolicVariable<F>> for SymbolicExpression<F> {
306 fn mul_assign(&mut self, rhs: SymbolicVariable<F>) {
307 *self *= SymbolicExpression::from(rhs);
308 }
309}
310
311impl<F: Field> Product for SymbolicExpression<F> {
312 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
313 iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
314 }
315}
316
317impl<F: Field> Product<F> for SymbolicExpression<F> {
318 fn product<I: Iterator<Item = F>>(iter: I) -> Self {
319 iter.map(|x| Self::from(x)).product()
320 }
321}
322
// Empty `Algebra` impls: symbolic expressions form an algebra over the base field `F` and
// over `SymbolicVariable<F>`. Nothing is overridden here, so any behaviour comes from the
// trait's own definition in `p3_field` (not visible in this file; presumably satisfied by
// the operator and `From` impls for these types; confirm against the trait's bounds).
impl<F: Field> Algebra<F> for SymbolicExpression<F> {}
impl<F: Field> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> {}
325
326pub trait SymbolicEvaluator<F, E>
327where
328 F: Field,
329 E: Add<E, Output = E> + Sub<E, Output = E> + Mul<E, Output = E> + Neg<Output = E>,
330{
331 fn eval_const(&self, c: F) -> E;
332 fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> E;
333 fn eval_is_first_row(&self) -> E;
334 fn eval_is_last_row(&self) -> E;
335 fn eval_is_transition(&self) -> E;
336
337 fn eval_expr(&self, symbolic_expr: &SymbolicExpression<F>) -> E {
338 match symbolic_expr {
339 SymbolicExpression::Variable(var) => self.eval_var(*var),
340 SymbolicExpression::Constant(c) => self.eval_const(*c),
341 SymbolicExpression::Add { x, y, .. } => self.eval_expr(x) + self.eval_expr(y),
342 SymbolicExpression::Sub { x, y, .. } => self.eval_expr(x) - self.eval_expr(y),
343 SymbolicExpression::Neg { x, .. } => -self.eval_expr(x),
344 SymbolicExpression::Mul { x, y, .. } => self.eval_expr(x) * self.eval_expr(y),
345 SymbolicExpression::IsFirstRow => self.eval_is_first_row(),
346 SymbolicExpression::IsLastRow => self.eval_is_last_row(),
347 SymbolicExpression::IsTransition => self.eval_is_transition(),
348 }
349 }
350
351 fn eval_nodes(&self, nodes: &[SymbolicExpressionNode<F>]) -> Vec<E>
354 where
355 E: Clone,
356 {
357 let mut exprs: Vec<E> = Vec::with_capacity(nodes.len());
358 for node in nodes {
359 let expr = match *node {
360 SymbolicExpressionNode::Variable(var) => self.eval_var(var),
361 SymbolicExpressionNode::Constant(c) => self.eval_const(c),
362 SymbolicExpressionNode::Add {
363 left_idx,
364 right_idx,
365 ..
366 } => exprs[left_idx].clone() + exprs[right_idx].clone(),
367 SymbolicExpressionNode::Sub {
368 left_idx,
369 right_idx,
370 ..
371 } => exprs[left_idx].clone() - exprs[right_idx].clone(),
372 SymbolicExpressionNode::Neg { idx, .. } => -exprs[idx].clone(),
373 SymbolicExpressionNode::Mul {
374 left_idx,
375 right_idx,
376 ..
377 } => exprs[left_idx].clone() * exprs[right_idx].clone(),
378 SymbolicExpressionNode::IsFirstRow => self.eval_is_first_row(),
379 SymbolicExpressionNode::IsLastRow => self.eval_is_last_row(),
380 SymbolicExpressionNode::IsTransition => self.eval_is_transition(),
381 };
382 exprs.push(expr);
383 }
384 exprs
385 }
386}