1use core::{
4 fmt::Debug,
5 iter::{Product, Sum},
6 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
7};
8use std::{
9 hash::{Hash, Hasher},
10 ptr,
11 sync::Arc,
12};
13
14use p3_field::{Field, FieldAlgebra};
15use serde::{Deserialize, Serialize};
16
17use super::{dag::SymbolicExpressionNode, symbolic_variable::SymbolicVariable};
18
/// Symbolic arithmetic expression over a field `F`.
///
/// Leaves are variables, constants and three row selectors. Composite nodes hold their
/// operands behind [`Arc`], so cloning an expression clones pointers rather than sub-trees.
///
/// `Hash`, `PartialEq` and `Eq` are implemented by hand in this file instead of being
/// derived: those impls identify the operands of composite nodes by pointer and ignore
/// `degree_multiple`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub enum SymbolicExpression<F> {
    /// A symbolic variable; evaluated through `SymbolicEvaluator::eval_var`.
    Variable(SymbolicVariable<F>),
    /// Row selector evaluated through `SymbolicEvaluator::eval_is_first_row`.
    IsFirstRow,
    /// Row selector evaluated through `SymbolicEvaluator::eval_is_last_row`.
    IsLastRow,
    /// Row selector evaluated through `SymbolicEvaluator::eval_is_transition`.
    IsTransition,
    /// A constant field element.
    Constant(F),
    /// The sum `x + y`.
    Add {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached degree multiple; the `Add` operator impl stores the maximum of the
        /// operands' degree multiples here.
        degree_multiple: usize,
    },
    /// The difference `x - y`.
    Sub {
        /// Left operand (minuend).
        x: Arc<Self>,
        /// Right operand (subtrahend).
        y: Arc<Self>,
        /// Cached degree multiple; the `Sub` operator impl stores the maximum of the
        /// operands' degree multiples here.
        degree_multiple: usize,
    },
    /// The negation `-x`.
    Neg {
        /// The negated operand.
        x: Arc<Self>,
        /// Cached degree multiple; the `Neg` operator impl copies the operand's value.
        degree_multiple: usize,
    },
    /// The product `x * y`.
    Mul {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached degree multiple; the `Mul` operator impl stores the sum of the
        /// operands' degree multiples here.
        degree_multiple: usize,
    },
}
49
50impl<F: Field> Hash for SymbolicExpression<F> {
51 fn hash<H: Hasher>(&self, state: &mut H) {
52 std::mem::discriminant(self).hash(state);
54 match self {
56 Self::Variable(v) => v.hash(state),
57 Self::IsFirstRow => {} Self::IsLastRow => {} Self::IsTransition => {} Self::Constant(f) => f.hash(state),
61 Self::Add { x, y, .. } => {
62 ptr::hash(&**x, state);
63 ptr::hash(&**y, state);
64 }
65 Self::Sub { x, y, .. } => {
66 ptr::hash(&**x, state);
67 ptr::hash(&**y, state);
68 }
69 Self::Neg { x, .. } => {
70 ptr::hash(&**x, state);
71 }
72 Self::Mul { x, y, .. } => {
73 ptr::hash(&**x, state);
74 ptr::hash(&**y, state);
75 }
76 }
77 }
78}
79
80impl<F: Field> PartialEq for SymbolicExpression<F> {
84 fn eq(&self, other: &Self) -> bool {
85 if std::mem::discriminant(self) != std::mem::discriminant(other) {
87 return false;
88 }
89
90 match (self, other) {
92 (Self::Variable(v1), Self::Variable(v2)) => v1 == v2,
93 (Self::IsFirstRow, Self::IsFirstRow) => true,
96 (Self::IsLastRow, Self::IsLastRow) => true,
97 (Self::IsTransition, Self::IsTransition) => true,
98 (Self::Constant(c1), Self::Constant(c2)) => c1 == c2,
99 (Self::Add { x: x1, y: y1, .. }, Self::Add { x: x2, y: y2, .. }) => {
101 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
102 }
103 (Self::Sub { x: x1, y: y1, .. }, Self::Sub { x: x2, y: y2, .. }) => {
104 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
105 }
106 (Self::Neg { x: x1, .. }, Self::Neg { x: x2, .. }) => Arc::ptr_eq(x1, x2),
107 (Self::Mul { x: x1, y: y1, .. }, Self::Mul { x: x2, y: y2, .. }) => {
108 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
109 }
110 _ => false,
112 }
113 }
114}
115
116impl<F: Field> Eq for SymbolicExpression<F> {}
117
118impl<F: Field> SymbolicExpression<F> {
119 pub const fn degree_multiple(&self) -> usize {
121 match self {
122 SymbolicExpression::Variable(v) => v.degree_multiple(),
123 SymbolicExpression::IsFirstRow => 1,
124 SymbolicExpression::IsLastRow => 1,
125 SymbolicExpression::IsTransition => 0,
126 SymbolicExpression::Constant(_) => 0,
127 SymbolicExpression::Add {
128 degree_multiple, ..
129 } => *degree_multiple,
130 SymbolicExpression::Sub {
131 degree_multiple, ..
132 } => *degree_multiple,
133 SymbolicExpression::Neg {
134 degree_multiple, ..
135 } => *degree_multiple,
136 SymbolicExpression::Mul {
137 degree_multiple, ..
138 } => *degree_multiple,
139 }
140 }
141}
142
143impl<F: Field> Default for SymbolicExpression<F> {
144 fn default() -> Self {
145 Self::Constant(F::ZERO)
146 }
147}
148
149impl<F: Field> From<F> for SymbolicExpression<F> {
150 fn from(value: F) -> Self {
151 Self::Constant(value)
152 }
153}
154
155impl<F: Field> FieldAlgebra for SymbolicExpression<F> {
156 type F = F;
157
158 const ZERO: Self = Self::Constant(F::ZERO);
159 const ONE: Self = Self::Constant(F::ONE);
160 const TWO: Self = Self::Constant(F::TWO);
161 const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
162
163 #[inline]
164 fn from_f(f: Self::F) -> Self {
165 f.into()
166 }
167
168 fn from_bool(b: bool) -> Self {
169 Self::Constant(F::from_bool(b))
170 }
171
172 fn from_canonical_u8(n: u8) -> Self {
173 Self::Constant(F::from_canonical_u8(n))
174 }
175
176 fn from_canonical_u16(n: u16) -> Self {
177 Self::Constant(F::from_canonical_u16(n))
178 }
179
180 fn from_canonical_u32(n: u32) -> Self {
181 Self::Constant(F::from_canonical_u32(n))
182 }
183
184 fn from_canonical_u64(n: u64) -> Self {
185 Self::Constant(F::from_canonical_u64(n))
186 }
187
188 fn from_canonical_usize(n: usize) -> Self {
189 Self::Constant(F::from_canonical_usize(n))
190 }
191
192 fn from_wrapped_u32(n: u32) -> Self {
193 Self::Constant(F::from_wrapped_u32(n))
194 }
195
196 fn from_wrapped_u64(n: u64) -> Self {
197 Self::Constant(F::from_wrapped_u64(n))
198 }
199}
200
201impl<F: Field> Add for SymbolicExpression<F> {
202 type Output = Self;
203
204 fn add(self, rhs: Self) -> Self {
205 let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
206 Self::Add {
207 x: Arc::new(self),
208 y: Arc::new(rhs),
209 degree_multiple,
210 }
211 }
212}
213
214impl<F: Field> Add<F> for SymbolicExpression<F> {
215 type Output = Self;
216
217 fn add(self, rhs: F) -> Self {
218 self + Self::from(rhs)
219 }
220}
221
222impl<F: Field> AddAssign for SymbolicExpression<F> {
223 fn add_assign(&mut self, rhs: Self) {
224 *self = self.clone() + rhs;
225 }
226}
227
228impl<F: Field> AddAssign<F> for SymbolicExpression<F> {
229 fn add_assign(&mut self, rhs: F) {
230 *self += Self::from(rhs);
231 }
232}
233
234impl<F: Field> Sum for SymbolicExpression<F> {
235 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
236 iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
237 }
238}
239
240impl<F: Field> Sum<F> for SymbolicExpression<F> {
241 fn sum<I: Iterator<Item = F>>(iter: I) -> Self {
242 iter.map(|x| Self::from(x)).sum()
243 }
244}
245
246impl<F: Field> Sub for SymbolicExpression<F> {
247 type Output = Self;
248
249 fn sub(self, rhs: Self) -> Self {
250 let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
251 Self::Sub {
252 x: Arc::new(self),
253 y: Arc::new(rhs),
254 degree_multiple,
255 }
256 }
257}
258
259impl<F: Field> Sub<F> for SymbolicExpression<F> {
260 type Output = Self;
261
262 fn sub(self, rhs: F) -> Self {
263 self - Self::from(rhs)
264 }
265}
266
267impl<F: Field> SubAssign for SymbolicExpression<F> {
268 fn sub_assign(&mut self, rhs: Self) {
269 *self = self.clone() - rhs;
270 }
271}
272
273impl<F: Field> SubAssign<F> for SymbolicExpression<F> {
274 fn sub_assign(&mut self, rhs: F) {
275 *self -= Self::from(rhs);
276 }
277}
278
279impl<F: Field> Neg for SymbolicExpression<F> {
280 type Output = Self;
281
282 fn neg(self) -> Self {
283 let degree_multiple = self.degree_multiple();
284 Self::Neg {
285 x: Arc::new(self),
286 degree_multiple,
287 }
288 }
289}
290
291impl<F: Field> Mul for SymbolicExpression<F> {
292 type Output = Self;
293
294 fn mul(self, rhs: Self) -> Self {
295 #[allow(clippy::suspicious_arithmetic_impl)]
296 let degree_multiple = self.degree_multiple() + rhs.degree_multiple();
297 Self::Mul {
298 x: Arc::new(self),
299 y: Arc::new(rhs),
300 degree_multiple,
301 }
302 }
303}
304
305impl<F: Field> Mul<F> for SymbolicExpression<F> {
306 type Output = Self;
307
308 fn mul(self, rhs: F) -> Self {
309 self * Self::from(rhs)
310 }
311}
312
313impl<F: Field> MulAssign for SymbolicExpression<F> {
314 fn mul_assign(&mut self, rhs: Self) {
315 *self = self.clone() * rhs;
316 }
317}
318
319impl<F: Field> MulAssign<F> for SymbolicExpression<F> {
320 fn mul_assign(&mut self, rhs: F) {
321 *self *= Self::from(rhs);
322 }
323}
324
325impl<F: Field> Product for SymbolicExpression<F> {
326 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
327 iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
328 }
329}
330
331impl<F: Field> Product<F> for SymbolicExpression<F> {
332 fn product<I: Iterator<Item = F>>(iter: I) -> Self {
333 iter.map(|x| Self::from(x)).product()
334 }
335}
336
337pub trait SymbolicEvaluator<F, E>
338where
339 F: Field,
340 E: Add<E, Output = E> + Sub<E, Output = E> + Mul<E, Output = E> + Neg<Output = E>,
341{
342 fn eval_const(&self, c: F) -> E;
343 fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> E;
344 fn eval_is_first_row(&self) -> E;
345 fn eval_is_last_row(&self) -> E;
346 fn eval_is_transition(&self) -> E;
347
348 fn eval_expr(&self, symbolic_expr: &SymbolicExpression<F>) -> E {
349 match symbolic_expr {
350 SymbolicExpression::Variable(var) => self.eval_var(*var),
351 SymbolicExpression::Constant(c) => self.eval_const(*c),
352 SymbolicExpression::Add { x, y, .. } => self.eval_expr(x) + self.eval_expr(y),
353 SymbolicExpression::Sub { x, y, .. } => self.eval_expr(x) - self.eval_expr(y),
354 SymbolicExpression::Neg { x, .. } => -self.eval_expr(x),
355 SymbolicExpression::Mul { x, y, .. } => self.eval_expr(x) * self.eval_expr(y),
356 SymbolicExpression::IsFirstRow => self.eval_is_first_row(),
357 SymbolicExpression::IsLastRow => self.eval_is_last_row(),
358 SymbolicExpression::IsTransition => self.eval_is_transition(),
359 }
360 }
361
362 fn eval_nodes(&self, nodes: &[SymbolicExpressionNode<F>]) -> Vec<E>
365 where
366 E: Clone,
367 {
368 let mut exprs: Vec<E> = Vec::with_capacity(nodes.len());
369 for node in nodes {
370 let expr = match *node {
371 SymbolicExpressionNode::Variable(var) => self.eval_var(var),
372 SymbolicExpressionNode::Constant(c) => self.eval_const(c),
373 SymbolicExpressionNode::Add {
374 left_idx,
375 right_idx,
376 ..
377 } => exprs[left_idx].clone() + exprs[right_idx].clone(),
378 SymbolicExpressionNode::Sub {
379 left_idx,
380 right_idx,
381 ..
382 } => exprs[left_idx].clone() - exprs[right_idx].clone(),
383 SymbolicExpressionNode::Neg { idx, .. } => -exprs[idx].clone(),
384 SymbolicExpressionNode::Mul {
385 left_idx,
386 right_idx,
387 ..
388 } => exprs[left_idx].clone() * exprs[right_idx].clone(),
389 SymbolicExpressionNode::IsFirstRow => self.eval_is_first_row(),
390 SymbolicExpressionNode::IsLastRow => self.eval_is_last_row(),
391 SymbolicExpressionNode::IsTransition => self.eval_is_transition(),
392 };
393 exprs.push(expr);
394 }
395 exprs
396 }
397}