1use core::{
4 fmt::Debug,
5 iter::{Product, Sum},
6 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
7};
8use std::{
9 hash::{Hash, Hasher},
10 ptr,
11 sync::Arc,
12};
13
14use p3_field::{Field, FieldAlgebra};
15use serde::{Deserialize, Serialize};
16
17use super::{dag::SymbolicExpressionNode, symbolic_variable::SymbolicVariable};
18
/// A symbolic arithmetic expression over the field `F`.
///
/// Leaves are variables, row selectors and constants; interior nodes are addition,
/// subtraction, negation and multiplication. Children of interior nodes are held behind
/// [`Arc`], so cloning an expression only bumps reference counts and a sub-expression can be
/// shared by several parents.
///
/// Note: the `PartialEq` and `Hash` impls in this file compare/hash interior nodes by the
/// pointer identity of their children, not structurally.
#[derive(Clone, Debug, Serialize, Deserialize)]
// Replace serde's inferred per-type-parameter bounds with a single `F: Field` bound.
// NOTE(review): presumably `Field` already implies the serde traits required here — confirm.
#[serde(bound = "F: Field")]
pub enum SymbolicExpression<F> {
    /// A symbolic variable leaf.
    Variable(SymbolicVariable<F>),
    /// First-row selector leaf; evaluated via `SymbolicEvaluator::eval_is_first_row`.
    IsFirstRow,
    /// Last-row selector leaf; evaluated via `SymbolicEvaluator::eval_is_last_row`.
    IsLastRow,
    /// Transition selector leaf; evaluated via `SymbolicEvaluator::eval_is_transition`.
    IsTransition,
    /// A constant field element leaf.
    Constant(F),
    /// `x + y`.
    Add {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Degree multiple cached at construction (the `Add` operator impl stores the max of
        /// the operands' degree multiples).
        degree_multiple: usize,
    },
    /// `x - y`.
    Sub {
        /// Left operand (minuend).
        x: Arc<Self>,
        /// Right operand (subtrahend).
        y: Arc<Self>,
        /// Degree multiple cached at construction (the `Sub` operator impl stores the max of
        /// the operands' degree multiples).
        degree_multiple: usize,
    },
    /// `-x`.
    Neg {
        /// The negated operand.
        x: Arc<Self>,
        /// Degree multiple cached at construction (the `Neg` operator impl copies the
        /// operand's degree multiple).
        degree_multiple: usize,
    },
    /// `x * y`.
    Mul {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Degree multiple cached at construction (the `Mul` operator impl stores the sum of
        /// the operands' degree multiples).
        degree_multiple: usize,
    },
}
49
50impl<F: Field> Hash for SymbolicExpression<F> {
51 fn hash<H: Hasher>(&self, state: &mut H) {
52 std::mem::discriminant(self).hash(state);
54 match self {
56 Self::Variable(v) => v.hash(state),
57 Self::IsFirstRow => {} Self::IsLastRow => {} Self::IsTransition => {} Self::Constant(f) => f.hash(state),
61 Self::Add { x, y, .. } => {
62 ptr::hash(&**x, state);
63 ptr::hash(&**y, state);
64 }
65 Self::Sub { x, y, .. } => {
66 ptr::hash(&**x, state);
67 ptr::hash(&**y, state);
68 }
69 Self::Neg { x, .. } => {
70 ptr::hash(&**x, state);
71 }
72 Self::Mul { x, y, .. } => {
73 ptr::hash(&**x, state);
74 ptr::hash(&**y, state);
75 }
76 }
77 }
78}
79
80impl<F: Field> PartialEq for SymbolicExpression<F> {
83 fn eq(&self, other: &Self) -> bool {
84 if std::mem::discriminant(self) != std::mem::discriminant(other) {
86 return false;
87 }
88
89 match (self, other) {
91 (Self::Variable(v1), Self::Variable(v2)) => v1 == v2,
92 (Self::IsFirstRow, Self::IsFirstRow) => true,
95 (Self::IsLastRow, Self::IsLastRow) => true,
96 (Self::IsTransition, Self::IsTransition) => true,
97 (Self::Constant(c1), Self::Constant(c2)) => c1 == c2,
98 (Self::Add { x: x1, y: y1, .. }, Self::Add { x: x2, y: y2, .. }) => {
100 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
101 }
102 (Self::Sub { x: x1, y: y1, .. }, Self::Sub { x: x2, y: y2, .. }) => {
103 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
104 }
105 (Self::Neg { x: x1, .. }, Self::Neg { x: x2, .. }) => Arc::ptr_eq(x1, x2),
106 (Self::Mul { x: x1, y: y1, .. }, Self::Mul { x: x2, y: y2, .. }) => {
107 Arc::ptr_eq(x1, x2) && Arc::ptr_eq(y1, y2)
108 }
109 _ => false,
111 }
112 }
113}
114
115impl<F: Field> Eq for SymbolicExpression<F> {}
116
117impl<F: Field> SymbolicExpression<F> {
118 pub const fn degree_multiple(&self) -> usize {
120 match self {
121 SymbolicExpression::Variable(v) => v.degree_multiple(),
122 SymbolicExpression::IsFirstRow => 1,
123 SymbolicExpression::IsLastRow => 1,
124 SymbolicExpression::IsTransition => 0,
125 SymbolicExpression::Constant(_) => 0,
126 SymbolicExpression::Add {
127 degree_multiple, ..
128 } => *degree_multiple,
129 SymbolicExpression::Sub {
130 degree_multiple, ..
131 } => *degree_multiple,
132 SymbolicExpression::Neg {
133 degree_multiple, ..
134 } => *degree_multiple,
135 SymbolicExpression::Mul {
136 degree_multiple, ..
137 } => *degree_multiple,
138 }
139 }
140}
141
142impl<F: Field> Default for SymbolicExpression<F> {
143 fn default() -> Self {
144 Self::Constant(F::ZERO)
145 }
146}
147
148impl<F: Field> From<F> for SymbolicExpression<F> {
149 fn from(value: F) -> Self {
150 Self::Constant(value)
151 }
152}
153
154impl<F: Field> FieldAlgebra for SymbolicExpression<F> {
155 type F = F;
156
157 const ZERO: Self = Self::Constant(F::ZERO);
158 const ONE: Self = Self::Constant(F::ONE);
159 const TWO: Self = Self::Constant(F::TWO);
160 const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
161
162 #[inline]
163 fn from_f(f: Self::F) -> Self {
164 f.into()
165 }
166
167 fn from_bool(b: bool) -> Self {
168 Self::Constant(F::from_bool(b))
169 }
170
171 fn from_canonical_u8(n: u8) -> Self {
172 Self::Constant(F::from_canonical_u8(n))
173 }
174
175 fn from_canonical_u16(n: u16) -> Self {
176 Self::Constant(F::from_canonical_u16(n))
177 }
178
179 fn from_canonical_u32(n: u32) -> Self {
180 Self::Constant(F::from_canonical_u32(n))
181 }
182
183 fn from_canonical_u64(n: u64) -> Self {
184 Self::Constant(F::from_canonical_u64(n))
185 }
186
187 fn from_canonical_usize(n: usize) -> Self {
188 Self::Constant(F::from_canonical_usize(n))
189 }
190
191 fn from_wrapped_u32(n: u32) -> Self {
192 Self::Constant(F::from_wrapped_u32(n))
193 }
194
195 fn from_wrapped_u64(n: u64) -> Self {
196 Self::Constant(F::from_wrapped_u64(n))
197 }
198}
199
200impl<F: Field> Add for SymbolicExpression<F> {
201 type Output = Self;
202
203 fn add(self, rhs: Self) -> Self {
204 let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
205 Self::Add {
206 x: Arc::new(self),
207 y: Arc::new(rhs),
208 degree_multiple,
209 }
210 }
211}
212
213impl<F: Field> Add<F> for SymbolicExpression<F> {
214 type Output = Self;
215
216 fn add(self, rhs: F) -> Self {
217 self + Self::from(rhs)
218 }
219}
220
221impl<F: Field> AddAssign for SymbolicExpression<F> {
222 fn add_assign(&mut self, rhs: Self) {
223 *self = self.clone() + rhs;
224 }
225}
226
227impl<F: Field> AddAssign<F> for SymbolicExpression<F> {
228 fn add_assign(&mut self, rhs: F) {
229 *self += Self::from(rhs);
230 }
231}
232
233impl<F: Field> Sum for SymbolicExpression<F> {
234 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
235 iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
236 }
237}
238
239impl<F: Field> Sum<F> for SymbolicExpression<F> {
240 fn sum<I: Iterator<Item = F>>(iter: I) -> Self {
241 iter.map(|x| Self::from(x)).sum()
242 }
243}
244
245impl<F: Field> Sub for SymbolicExpression<F> {
246 type Output = Self;
247
248 fn sub(self, rhs: Self) -> Self {
249 let degree_multiple = self.degree_multiple().max(rhs.degree_multiple());
250 Self::Sub {
251 x: Arc::new(self),
252 y: Arc::new(rhs),
253 degree_multiple,
254 }
255 }
256}
257
258impl<F: Field> Sub<F> for SymbolicExpression<F> {
259 type Output = Self;
260
261 fn sub(self, rhs: F) -> Self {
262 self - Self::from(rhs)
263 }
264}
265
266impl<F: Field> SubAssign for SymbolicExpression<F> {
267 fn sub_assign(&mut self, rhs: Self) {
268 *self = self.clone() - rhs;
269 }
270}
271
272impl<F: Field> SubAssign<F> for SymbolicExpression<F> {
273 fn sub_assign(&mut self, rhs: F) {
274 *self -= Self::from(rhs);
275 }
276}
277
278impl<F: Field> Neg for SymbolicExpression<F> {
279 type Output = Self;
280
281 fn neg(self) -> Self {
282 let degree_multiple = self.degree_multiple();
283 Self::Neg {
284 x: Arc::new(self),
285 degree_multiple,
286 }
287 }
288}
289
290impl<F: Field> Mul for SymbolicExpression<F> {
291 type Output = Self;
292
293 fn mul(self, rhs: Self) -> Self {
294 #[allow(clippy::suspicious_arithmetic_impl)]
295 let degree_multiple = self.degree_multiple() + rhs.degree_multiple();
296 Self::Mul {
297 x: Arc::new(self),
298 y: Arc::new(rhs),
299 degree_multiple,
300 }
301 }
302}
303
304impl<F: Field> Mul<F> for SymbolicExpression<F> {
305 type Output = Self;
306
307 fn mul(self, rhs: F) -> Self {
308 self * Self::from(rhs)
309 }
310}
311
312impl<F: Field> MulAssign for SymbolicExpression<F> {
313 fn mul_assign(&mut self, rhs: Self) {
314 *self = self.clone() * rhs;
315 }
316}
317
318impl<F: Field> MulAssign<F> for SymbolicExpression<F> {
319 fn mul_assign(&mut self, rhs: F) {
320 *self *= Self::from(rhs);
321 }
322}
323
324impl<F: Field> Product for SymbolicExpression<F> {
325 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
326 iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
327 }
328}
329
330impl<F: Field> Product<F> for SymbolicExpression<F> {
331 fn product<I: Iterator<Item = F>>(iter: I) -> Self {
332 iter.map(|x| Self::from(x)).product()
333 }
334}
335
336pub trait SymbolicEvaluator<F, E>
337where
338 F: Field,
339 E: Add<E, Output = E> + Sub<E, Output = E> + Mul<E, Output = E> + Neg<Output = E>,
340{
341 fn eval_const(&self, c: F) -> E;
342 fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> E;
343 fn eval_is_first_row(&self) -> E;
344 fn eval_is_last_row(&self) -> E;
345 fn eval_is_transition(&self) -> E;
346
347 fn eval_expr(&self, symbolic_expr: &SymbolicExpression<F>) -> E {
348 match symbolic_expr {
349 SymbolicExpression::Variable(var) => self.eval_var(*var),
350 SymbolicExpression::Constant(c) => self.eval_const(*c),
351 SymbolicExpression::Add { x, y, .. } => self.eval_expr(x) + self.eval_expr(y),
352 SymbolicExpression::Sub { x, y, .. } => self.eval_expr(x) - self.eval_expr(y),
353 SymbolicExpression::Neg { x, .. } => -self.eval_expr(x),
354 SymbolicExpression::Mul { x, y, .. } => self.eval_expr(x) * self.eval_expr(y),
355 SymbolicExpression::IsFirstRow => self.eval_is_first_row(),
356 SymbolicExpression::IsLastRow => self.eval_is_last_row(),
357 SymbolicExpression::IsTransition => self.eval_is_transition(),
358 }
359 }
360
361 fn eval_nodes(&self, nodes: &[SymbolicExpressionNode<F>]) -> Vec<E>
364 where
365 E: Clone,
366 {
367 let mut exprs: Vec<E> = Vec::with_capacity(nodes.len());
368 for node in nodes {
369 let expr = match *node {
370 SymbolicExpressionNode::Variable(var) => self.eval_var(var),
371 SymbolicExpressionNode::Constant(c) => self.eval_const(c),
372 SymbolicExpressionNode::Add {
373 left_idx,
374 right_idx,
375 ..
376 } => exprs[left_idx].clone() + exprs[right_idx].clone(),
377 SymbolicExpressionNode::Sub {
378 left_idx,
379 right_idx,
380 ..
381 } => exprs[left_idx].clone() - exprs[right_idx].clone(),
382 SymbolicExpressionNode::Neg { idx, .. } => -exprs[idx].clone(),
383 SymbolicExpressionNode::Mul {
384 left_idx,
385 right_idx,
386 ..
387 } => exprs[left_idx].clone() * exprs[right_idx].clone(),
388 SymbolicExpressionNode::IsFirstRow => self.eval_is_first_row(),
389 SymbolicExpressionNode::IsLastRow => self.eval_is_last_row(),
390 SymbolicExpressionNode::IsTransition => self.eval_is_transition(),
391 };
392 exprs.push(expr);
393 }
394 exprs
395 }
396}