openvm_mod_circuit_builder/
symbolic_expr.rs

1use std::{
2    cmp::{max, min},
3    convert::identity,
4    iter::repeat,
5    ops::{Add, Div, Mul, Sub},
6};
7
8use num_bigint::{BigInt, BigUint, Sign};
9use num_traits::{FromPrimitive, One, Zero};
10use openvm_circuit_primitives::bigint::{
11    check_carry_to_zero::get_carry_max_abs_and_bits, OverflowInt,
12};
13use openvm_stark_backend::{
14    p3_air::AirBuilder, p3_field::PrimeCharacteristicRing, p3_util::log2_ceil_usize,
15};
16
17/// Example: If there are 4 inputs (x1, y1, x2, y2), and one intermediate variable lambda,
18/// Mul(Var(0), Var(0)) - Input(0) - Input(2) =>
19/// lambda * lambda - x1 - x2
20#[derive(Clone, Debug, PartialEq)]
21pub enum SymbolicExpr {
22    Input(usize),
23    Var(usize),
24    Const(usize, BigUint, usize), // (index, value, number of limbs)
25    Add(Box<SymbolicExpr>, Box<SymbolicExpr>),
26    Sub(Box<SymbolicExpr>, Box<SymbolicExpr>),
27    Mul(Box<SymbolicExpr>, Box<SymbolicExpr>),
28    // Division is not allowed in "constraints", but can only be used in "computes"
29    // Note that division by zero in "computes" will panic.
30    Div(Box<SymbolicExpr>, Box<SymbolicExpr>),
31    // Add integer
32    IntAdd(Box<SymbolicExpr>, isize),
33    // Multiply each limb with an integer. For BigInt this is just scalar multiplication.
34    IntMul(Box<SymbolicExpr>, isize),
35    // Select one of the two expressions based on the flag.
36    // The two expressions must have the same structure (number of limbs etc), e.g. a+b and a-b.
37    Select(usize, Box<SymbolicExpr>, Box<SymbolicExpr>),
38}
39
40impl std::fmt::Display for SymbolicExpr {
41    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42        match self {
43            SymbolicExpr::Input(i) => write!(f, "Input_{i}"),
44            SymbolicExpr::Var(i) => write!(f, "Var_{i}"),
45            SymbolicExpr::Const(i, _, _) => write!(f, "Const_{i}"),
46            SymbolicExpr::Add(lhs, rhs) => write!(f, "({lhs} + {rhs})"),
47            SymbolicExpr::Sub(lhs, rhs) => write!(f, "({lhs} - {rhs})"),
48            SymbolicExpr::Mul(lhs, rhs) => write!(f, "{lhs} * {rhs}"),
49            SymbolicExpr::Div(lhs, rhs) => write!(f, "({lhs} / {rhs})"),
50            SymbolicExpr::IntAdd(lhs, s) => write!(f, "({lhs} + {s})"),
51            SymbolicExpr::IntMul(lhs, s) => write!(f, "({lhs} x {s})"),
52            SymbolicExpr::Select(flag_id, lhs, rhs) => {
53                write!(f, "(if {flag_id} then {lhs} else {rhs})")
54            }
55        }
56    }
57}
58
59impl Add for SymbolicExpr {
60    type Output = SymbolicExpr;
61
62    fn add(self, rhs: Self) -> Self::Output {
63        SymbolicExpr::Add(Box::new(self), Box::new(rhs))
64    }
65}
66
67impl Add<&SymbolicExpr> for SymbolicExpr {
68    type Output = SymbolicExpr;
69
70    fn add(self, rhs: &SymbolicExpr) -> Self::Output {
71        SymbolicExpr::Add(Box::new(self), Box::new(rhs.clone()))
72    }
73}
74
75impl Add for &SymbolicExpr {
76    type Output = SymbolicExpr;
77
78    fn add(self, rhs: &SymbolicExpr) -> Self::Output {
79        SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs.clone()))
80    }
81}
82
83impl Add<SymbolicExpr> for &SymbolicExpr {
84    type Output = SymbolicExpr;
85
86    fn add(self, rhs: SymbolicExpr) -> Self::Output {
87        SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs))
88    }
89}
90
91impl Sub for SymbolicExpr {
92    type Output = SymbolicExpr;
93
94    fn sub(self, rhs: Self) -> Self::Output {
95        SymbolicExpr::Sub(Box::new(self), Box::new(rhs))
96    }
97}
98
99impl Sub<&SymbolicExpr> for SymbolicExpr {
100    type Output = SymbolicExpr;
101
102    fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
103        SymbolicExpr::Sub(Box::new(self), Box::new(rhs.clone()))
104    }
105}
106
107impl Sub for &SymbolicExpr {
108    type Output = SymbolicExpr;
109
110    fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
111        SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs.clone()))
112    }
113}
114
115impl Sub<SymbolicExpr> for &SymbolicExpr {
116    type Output = SymbolicExpr;
117
118    fn sub(self, rhs: SymbolicExpr) -> Self::Output {
119        SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs))
120    }
121}
122
123impl Mul for SymbolicExpr {
124    type Output = SymbolicExpr;
125
126    fn mul(self, rhs: Self) -> Self::Output {
127        SymbolicExpr::Mul(Box::new(self), Box::new(rhs))
128    }
129}
130
131impl Mul<&SymbolicExpr> for SymbolicExpr {
132    type Output = SymbolicExpr;
133
134    fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
135        SymbolicExpr::Mul(Box::new(self), Box::new(rhs.clone()))
136    }
137}
138
139impl Mul for &SymbolicExpr {
140    type Output = SymbolicExpr;
141
142    fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
143        SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs.clone()))
144    }
145}
146
147impl Mul<SymbolicExpr> for &SymbolicExpr {
148    type Output = SymbolicExpr;
149
150    fn mul(self, rhs: SymbolicExpr) -> Self::Output {
151        SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs))
152    }
153}
154
155// Note that division by zero will panic.
156impl Div for SymbolicExpr {
157    type Output = SymbolicExpr;
158
159    fn div(self, rhs: Self) -> Self::Output {
160        SymbolicExpr::Div(Box::new(self), Box::new(rhs))
161    }
162}
163
164// Note that division by zero will panic.
165impl Div<&SymbolicExpr> for SymbolicExpr {
166    type Output = SymbolicExpr;
167
168    fn div(self, rhs: &SymbolicExpr) -> Self::Output {
169        SymbolicExpr::Div(Box::new(self), Box::new(rhs.clone()))
170    }
171}
172
173// Note that division by zero will panic.
174impl Div for &SymbolicExpr {
175    type Output = SymbolicExpr;
176
177    fn div(self, rhs: &SymbolicExpr) -> Self::Output {
178        SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs.clone()))
179    }
180}
181
182// Note that division by zero will panic.
183impl Div<SymbolicExpr> for &SymbolicExpr {
184    type Output = SymbolicExpr;
185
186    fn div(self, rhs: SymbolicExpr) -> Self::Output {
187        SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs))
188    }
189}
190
191impl SymbolicExpr {
192    /// Returns maximum absolute positive and negative value of the expression.
193    /// That is, if `(r, l) = expr.max_abs(p)` then `l,r >= 0` and `-l <= expr <= r`.
194    /// Needed in `constraint_limbs` to estimate the number of limbs of q.
195    ///
196    /// It is assumed that any `Input` or `Var` is a non-negative big integer with value
197    /// in the range `[0, proper_max]`.
198    fn max_abs(&self, proper_max: &BigUint) -> (BigUint, BigUint) {
199        match self {
200            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => (proper_max.clone(), BigUint::zero()),
201            SymbolicExpr::Const(_, val, _) => (val.clone(), BigUint::zero()),
202            SymbolicExpr::Add(lhs, rhs) => {
203                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
204                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
205                (lhs_max_pos + rhs_max_pos, lhs_max_neg + rhs_max_neg)
206            }
207            SymbolicExpr::Sub(lhs, rhs) => {
208                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
209                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
210                (lhs_max_pos + rhs_max_neg, lhs_max_neg + rhs_max_pos)
211            }
212            SymbolicExpr::Mul(lhs, rhs) => {
213                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
214                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
215                (
216                    max(&lhs_max_pos * &rhs_max_pos, &lhs_max_neg * &rhs_max_neg),
217                    max(&lhs_max_pos * &rhs_max_neg, &lhs_max_neg * &rhs_max_pos),
218                )
219            }
220            SymbolicExpr::Div(_, _) => {
221                // Should not have division in expression when calling this.
222                unreachable!()
223            }
224            SymbolicExpr::IntAdd(lhs, s) => {
225                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
226                let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
227                // Optimization opportunity: since `s` is a constant, we can likely do better than
228                // this bound.
229                (lhs_max_pos + &scalar, lhs_max_neg + &scalar)
230            }
231            SymbolicExpr::IntMul(lhs, s) => {
232                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
233                let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
234                if *s < 0 {
235                    (lhs_max_neg * &scalar, lhs_max_pos * &scalar)
236                } else {
237                    (lhs_max_pos * &scalar, lhs_max_neg * &scalar)
238                }
239            }
240            SymbolicExpr::Select(_, lhs, rhs) => {
241                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
242                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
243                (max(lhs_max_pos, rhs_max_pos), max(lhs_max_neg, rhs_max_neg))
244            }
245        }
246    }
247
248    /// Returns the maximum possible size, in bits, of each limb in `self.expr`.
249    /// This is already tracked in `FieldVariable`. However when auto saving in
250    /// `FieldVariable::div`, we need to know it from the `SymbolicExpr` only.
251    /// self should be a constraint expr.
252    pub fn constraint_limb_max_abs(&self, limb_bits: usize, num_limbs: usize) -> usize {
253        let canonical_limb_max_abs = (1 << limb_bits) - 1;
254        match self {
255            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) | SymbolicExpr::Const(_, _, _) => {
256                canonical_limb_max_abs
257            }
258            SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
259                lhs.constraint_limb_max_abs(limb_bits, num_limbs)
260                    + rhs.constraint_limb_max_abs(limb_bits, num_limbs)
261            }
262            SymbolicExpr::Mul(lhs, rhs) => {
263                let left_num_limbs = lhs.expr_limbs(num_limbs);
264                let right_num_limbs = rhs.expr_limbs(num_limbs);
265                lhs.constraint_limb_max_abs(limb_bits, num_limbs)
266                    * rhs.constraint_limb_max_abs(limb_bits, num_limbs)
267                    * min(left_num_limbs, right_num_limbs)
268            }
269            SymbolicExpr::IntAdd(lhs, i) => {
270                lhs.constraint_limb_max_abs(limb_bits, num_limbs) + i.unsigned_abs()
271            }
272            SymbolicExpr::IntMul(lhs, i) => {
273                lhs.constraint_limb_max_abs(limb_bits, num_limbs) * i.unsigned_abs()
274            }
275            SymbolicExpr::Select(_, lhs, rhs) => max(
276                lhs.constraint_limb_max_abs(limb_bits, num_limbs),
277                rhs.constraint_limb_max_abs(limb_bits, num_limbs),
278            ),
279            SymbolicExpr::Div(_, _) => {
280                unreachable!("should not have division when calling limb_max_abs")
281            }
282        }
283    }
284
285    /// Returns the maximum possible size, in bits, of each carry in `self.expr - q * p`.
286    /// self should be a constraint expr.
287    ///
288    /// The cached value `proper_max` should equal `2^{limb_bits * num_limbs} - 1`.
289    pub fn constraint_carry_bits_with_pq(
290        &self,
291        prime: &BigUint,
292        limb_bits: usize,
293        num_limbs: usize,
294        proper_max: &BigUint,
295    ) -> usize {
296        let without_pq = self.constraint_limb_max_abs(limb_bits, num_limbs);
297        let (q_limbs, _) = self.constraint_limbs(prime, limb_bits, num_limbs, proper_max);
298        let canonical_limb_max_abs = (1 << limb_bits) - 1;
299        let limb_max_abs =
300            without_pq + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, num_limbs);
301        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
302        let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
303        carry_bits
304    }
305
306    /// Returns the number of limbs needed to represent the expression.
307    /// The parameter `num_limbs` is the number of limbs of a canonical field element.
308    pub fn expr_limbs(&self, num_limbs: usize) -> usize {
309        match self {
310            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => num_limbs,
311            SymbolicExpr::Const(_, _, limbs) => *limbs,
312            SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
313                max(lhs.expr_limbs(num_limbs), rhs.expr_limbs(num_limbs))
314            }
315            SymbolicExpr::Mul(lhs, rhs) => {
316                lhs.expr_limbs(num_limbs) + rhs.expr_limbs(num_limbs) - 1
317            }
318            SymbolicExpr::Div(_, _) => {
319                unimplemented!()
320            }
321            SymbolicExpr::IntAdd(lhs, _) => lhs.expr_limbs(num_limbs),
322            SymbolicExpr::IntMul(lhs, _) => lhs.expr_limbs(num_limbs),
323            SymbolicExpr::Select(_, lhs, rhs) => {
324                let left = lhs.expr_limbs(num_limbs);
325                let right = rhs.expr_limbs(num_limbs);
326                assert_eq!(left, right);
327                left
328            }
329        }
330    }
331
332    /// Let `q` be such that `self.expr = q * p`.
333    /// Returns (q_limbs, carry_limbs) where q_limbs is the number of limbs in q
334    /// and carry_limbs is the number of limbs in the carry of the constraint self.expr - q * p = 0.
335    /// self should be a constraint expression.
336    ///
337    /// The cached value `proper_max` should equal `2^{limb_bits * num_limbs} - 1`.
338    pub fn constraint_limbs(
339        &self,
340        prime: &BigUint,
341        limb_bits: usize,
342        num_limbs: usize,
343        proper_max: &BigUint,
344    ) -> (usize, usize) {
345        let (max_pos_abs, max_neg_abs) = self.max_abs(proper_max);
346        let max_abs = max(max_pos_abs, max_neg_abs);
347        let max_q_abs = (&max_abs + prime - BigUint::one()) / prime;
348        let q_bits = max_q_abs.bits() as usize;
349        let p_bits = prime.bits() as usize;
350        let q_limbs = q_bits.div_ceil(limb_bits);
351        // Attention! This must match with prime_overflow in `FieldExpr::generate_subrow`
352        let p_limbs = p_bits.div_ceil(limb_bits);
353        let qp_limbs = q_limbs + p_limbs - 1;
354
355        let expr_limbs = self.expr_limbs(num_limbs);
356        let carry_limbs = max(expr_limbs, qp_limbs);
357        (q_limbs, carry_limbs)
358    }
359
360    /// Used in trace gen to compute `q``.
361    /// self should be a constraint expression.
362    pub fn evaluate_bigint(
363        &self,
364        inputs: &[BigInt],
365        variables: &[BigInt],
366        flags: &[bool],
367    ) -> BigInt {
368        match self {
369            SymbolicExpr::IntAdd(lhs, s) => {
370                lhs.evaluate_bigint(inputs, variables, flags) + BigInt::from_isize(*s).unwrap()
371            }
372            SymbolicExpr::IntMul(lhs, s) => {
373                lhs.evaluate_bigint(inputs, variables, flags) * BigInt::from_isize(*s).unwrap()
374            }
375            SymbolicExpr::Input(i) => inputs[*i].clone(),
376            SymbolicExpr::Var(i) => variables[*i].clone(),
377            SymbolicExpr::Const(_, val, _) => {
378                if val.is_zero() {
379                    BigInt::zero()
380                } else {
381                    BigInt::from_biguint(Sign::Plus, val.clone())
382                }
383            }
384            SymbolicExpr::Add(lhs, rhs) => {
385                lhs.evaluate_bigint(inputs, variables, flags)
386                    + rhs.evaluate_bigint(inputs, variables, flags)
387            }
388            SymbolicExpr::Sub(lhs, rhs) => {
389                lhs.evaluate_bigint(inputs, variables, flags)
390                    - rhs.evaluate_bigint(inputs, variables, flags)
391            }
392            SymbolicExpr::Mul(lhs, rhs) => {
393                lhs.evaluate_bigint(inputs, variables, flags)
394                    * rhs.evaluate_bigint(inputs, variables, flags)
395            }
396            SymbolicExpr::Select(flag_id, lhs, rhs) => {
397                if flags[*flag_id] {
398                    lhs.evaluate_bigint(inputs, variables, flags)
399                } else {
400                    rhs.evaluate_bigint(inputs, variables, flags)
401                }
402            }
403            SymbolicExpr::Div(_, _) => unreachable!(), // Division is not allowed in constraints.
404        }
405    }
406
407    /// Used in trace gen to compute carries.
408    /// self should be a constraint expression.
409    pub fn evaluate_overflow_isize(
410        &self,
411        inputs: &[OverflowInt<isize>],
412        variables: &[OverflowInt<isize>],
413        constants: &[OverflowInt<isize>],
414        flags: &[bool],
415    ) -> OverflowInt<isize> {
416        match self {
417            SymbolicExpr::IntAdd(lhs, s) => {
418                let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
419                left.int_add(*s, identity)
420            }
421            SymbolicExpr::IntMul(lhs, s) => {
422                let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
423                left.int_mul(*s, identity)
424            }
425            SymbolicExpr::Input(i) => inputs[*i].clone(),
426            SymbolicExpr::Var(i) => variables[*i].clone(),
427            SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
428            SymbolicExpr::Add(lhs, rhs) => {
429                lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
430                    + rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
431            }
432            SymbolicExpr::Sub(lhs, rhs) => {
433                lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
434                    - rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
435            }
436            SymbolicExpr::Mul(lhs, rhs) => {
437                lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
438                    * rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
439            }
440            SymbolicExpr::Select(flag_id, lhs, rhs) => {
441                let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
442                let right = rhs.evaluate_overflow_isize(inputs, variables, constants, flags);
443                let num_limbs = max(left.num_limbs(), right.num_limbs());
444
445                let res = if flags[*flag_id] {
446                    left.limbs().to_vec()
447                } else {
448                    right.limbs().to_vec()
449                };
450                let res = res.into_iter().chain(repeat(0)).take(num_limbs).collect();
451
452                OverflowInt::from_computed_limbs(
453                    res,
454                    max(left.limb_max_abs(), right.limb_max_abs()),
455                    max(left.max_overflow_bits(), right.max_overflow_bits()),
456                )
457            }
458            SymbolicExpr::Div(_, _) => unreachable!(), // Division is not allowed in constraints.
459        }
460    }
461
462    fn isize_to_expr<AB: AirBuilder>(s: isize) -> AB::Expr {
463        if s >= 0 {
464            AB::Expr::from_usize(s as usize)
465        } else {
466            -AB::Expr::from_usize(s.unsigned_abs())
467        }
468    }
469
470    /// Used in AIR eval.
471    /// self should be a constraint expression.
472    pub fn evaluate_overflow_expr<AB: AirBuilder>(
473        &self,
474        inputs: &[OverflowInt<AB::Expr>],
475        variables: &[OverflowInt<AB::Expr>],
476        constants: &[OverflowInt<AB::Expr>],
477        flags: &[AB::Var],
478    ) -> OverflowInt<AB::Expr> {
479        match self {
480            SymbolicExpr::IntAdd(lhs, s) => {
481                let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
482                left.int_add(*s, Self::isize_to_expr::<AB>)
483            }
484            SymbolicExpr::IntMul(lhs, s) => {
485                let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
486                left.int_mul(*s, Self::isize_to_expr::<AB>)
487            }
488            SymbolicExpr::Input(i) => inputs[*i].clone(),
489            SymbolicExpr::Var(i) => variables[*i].clone(),
490            SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
491            SymbolicExpr::Add(lhs, rhs) => {
492                lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
493                    + rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
494            }
495            SymbolicExpr::Sub(lhs, rhs) => {
496                lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
497                    - rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
498            }
499            SymbolicExpr::Mul(lhs, rhs) => {
500                lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
501                    * rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
502            }
503            SymbolicExpr::Select(flag_id, lhs, rhs) => {
504                let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
505                let right = rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
506                let num_limbs = max(left.num_limbs(), right.num_limbs());
507                let flag = &flags[*flag_id];
508                let mut res = vec![];
509                for i in 0..num_limbs {
510                    res.push(
511                        (if i < left.num_limbs() {
512                            left.limb(i).clone()
513                        } else {
514                            AB::Expr::ZERO
515                        }) * flag.clone()
516                            + (if i < right.num_limbs() {
517                                right.limb(i).clone()
518                            } else {
519                                AB::Expr::ZERO
520                            }) * (AB::Expr::ONE - flag.clone()),
521                    );
522                }
523                OverflowInt::from_computed_limbs(
524                    res,
525                    max(left.limb_max_abs(), right.limb_max_abs()),
526                    max(left.max_overflow_bits(), right.max_overflow_bits()),
527                )
528            }
529            SymbolicExpr::Div(_, _) => unreachable!(), // Division is not allowed in constraints.
530        }
531    }
532
533    /// Result will be within [0, prime).
534    /// self should be a compute expression.
535    /// Note that division by zero will panic.
536    pub fn compute(
537        &self,
538        inputs: &[BigUint],
539        variables: &[BigUint],
540        flags: &[bool],
541        prime: &BigUint,
542    ) -> BigUint {
543        let res = match self {
544            SymbolicExpr::Input(i) => inputs[*i].clone() % prime,
545            SymbolicExpr::Var(i) => variables[*i].clone(),
546            SymbolicExpr::Const(_, val, _) => val.clone(),
547            SymbolicExpr::Add(lhs, rhs) => {
548                (lhs.compute(inputs, variables, flags, prime)
549                    + rhs.compute(inputs, variables, flags, prime))
550                    % prime
551            }
552            SymbolicExpr::Sub(lhs, rhs) => {
553                (prime + lhs.compute(inputs, variables, flags, prime)
554                    - rhs.compute(inputs, variables, flags, prime))
555                    % prime
556            }
557            SymbolicExpr::Mul(lhs, rhs) => {
558                (lhs.compute(inputs, variables, flags, prime)
559                    * rhs.compute(inputs, variables, flags, prime))
560                    % prime
561            }
562            SymbolicExpr::Div(lhs, rhs) => {
563                let left = lhs.compute(inputs, variables, flags, prime);
564                let right = rhs.compute(inputs, variables, flags, prime);
565                let right_inv = right.modinv(prime).unwrap();
566                (left * right_inv) % prime
567            }
568            SymbolicExpr::IntAdd(lhs, s) => {
569                let left = lhs.compute(inputs, variables, flags, prime);
570                let right = if *s >= 0 {
571                    BigUint::from_usize(*s as usize).unwrap()
572                } else {
573                    prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
574                };
575                (left + right) % prime
576            }
577            SymbolicExpr::IntMul(lhs, s) => {
578                let left = lhs.compute(inputs, variables, flags, prime);
579                let right = if *s >= 0 {
580                    BigUint::from_usize(*s as usize).unwrap()
581                } else {
582                    prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
583                };
584                (left * right) % prime
585            }
586            SymbolicExpr::Select(flag_id, lhs, rhs) => {
587                if flags[*flag_id] {
588                    lhs.compute(inputs, variables, flags, prime)
589                } else {
590                    rhs.compute(inputs, variables, flags, prime)
591                }
592            }
593        };
594        assert!(
595            res < prime.clone(),
596            "symbolic expr: {self} evaluation exceeds prime"
597        );
598        res
599    }
600}