openvm_mod_circuit_builder/symbolic_expr.rs

1use std::{
2    cmp::{max, min},
3    convert::identity,
4    iter::repeat,
5    ops::{Add, Div, Mul, Sub},
6};
7
8use num_bigint::{BigInt, BigUint, Sign};
9use num_traits::{FromPrimitive, One, Zero};
10use openvm_circuit_primitives::bigint::{
11    check_carry_to_zero::get_carry_max_abs_and_bits, OverflowInt,
12};
13use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra, p3_util::log2_ceil_usize};
14
15/// Example: If there are 4 inputs (x1, y1, x2, y2), and one intermediate variable lambda,
16/// Mul(Var(0), Var(0)) - Input(0) - Input(2) =>
17/// lambda * lambda - x1 - x2
18#[derive(Clone, Debug, PartialEq)]
19pub enum SymbolicExpr {
20    Input(usize),
21    Var(usize),
22    Const(usize, BigUint, usize), // (index, value, number of limbs)
23    Add(Box<SymbolicExpr>, Box<SymbolicExpr>),
24    Sub(Box<SymbolicExpr>, Box<SymbolicExpr>),
25    Mul(Box<SymbolicExpr>, Box<SymbolicExpr>),
26    // Division is not allowed in "constraints", but can only be used in "computes"
27    // Note that division by zero in "computes" will panic.
28    Div(Box<SymbolicExpr>, Box<SymbolicExpr>),
29    // Add integer
30    IntAdd(Box<SymbolicExpr>, isize),
31    // Multiply each limb with an integer. For BigInt this is just scalar multiplication.
32    IntMul(Box<SymbolicExpr>, isize),
33    // Select one of the two expressions based on the flag.
34    // The two expressions must have the same structure (number of limbs etc), e.g. a+b and a-b.
35    Select(usize, Box<SymbolicExpr>, Box<SymbolicExpr>),
36}
37
38impl std::fmt::Display for SymbolicExpr {
39    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
40        match self {
41            SymbolicExpr::Input(i) => write!(f, "Input_{}", i),
42            SymbolicExpr::Var(i) => write!(f, "Var_{}", i),
43            SymbolicExpr::Const(i, _, _) => write!(f, "Const_{}", i),
44            SymbolicExpr::Add(lhs, rhs) => write!(f, "({} + {})", lhs, rhs),
45            SymbolicExpr::Sub(lhs, rhs) => write!(f, "({} - {})", lhs, rhs),
46            SymbolicExpr::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs),
47            SymbolicExpr::Div(lhs, rhs) => write!(f, "({} / {})", lhs, rhs),
48            SymbolicExpr::IntAdd(lhs, s) => write!(f, "({} + {})", lhs, s),
49            SymbolicExpr::IntMul(lhs, s) => write!(f, "({} x {})", lhs, s),
50            SymbolicExpr::Select(flag_id, lhs, rhs) => {
51                write!(f, "(if {} then {} else {})", flag_id, lhs, rhs)
52            }
53        }
54    }
55}
56
57impl Add for SymbolicExpr {
58    type Output = SymbolicExpr;
59
60    fn add(self, rhs: Self) -> Self::Output {
61        SymbolicExpr::Add(Box::new(self), Box::new(rhs))
62    }
63}
64
65impl Add<&SymbolicExpr> for SymbolicExpr {
66    type Output = SymbolicExpr;
67
68    fn add(self, rhs: &SymbolicExpr) -> Self::Output {
69        SymbolicExpr::Add(Box::new(self), Box::new(rhs.clone()))
70    }
71}
72
73impl Add for &SymbolicExpr {
74    type Output = SymbolicExpr;
75
76    fn add(self, rhs: &SymbolicExpr) -> Self::Output {
77        SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs.clone()))
78    }
79}
80
81impl Add<SymbolicExpr> for &SymbolicExpr {
82    type Output = SymbolicExpr;
83
84    fn add(self, rhs: SymbolicExpr) -> Self::Output {
85        SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs))
86    }
87}
88
89impl Sub for SymbolicExpr {
90    type Output = SymbolicExpr;
91
92    fn sub(self, rhs: Self) -> Self::Output {
93        SymbolicExpr::Sub(Box::new(self), Box::new(rhs))
94    }
95}
96
97impl Sub<&SymbolicExpr> for SymbolicExpr {
98    type Output = SymbolicExpr;
99
100    fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
101        SymbolicExpr::Sub(Box::new(self), Box::new(rhs.clone()))
102    }
103}
104
105impl Sub for &SymbolicExpr {
106    type Output = SymbolicExpr;
107
108    fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
109        SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs.clone()))
110    }
111}
112
113impl Sub<SymbolicExpr> for &SymbolicExpr {
114    type Output = SymbolicExpr;
115
116    fn sub(self, rhs: SymbolicExpr) -> Self::Output {
117        SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs))
118    }
119}
120
121impl Mul for SymbolicExpr {
122    type Output = SymbolicExpr;
123
124    fn mul(self, rhs: Self) -> Self::Output {
125        SymbolicExpr::Mul(Box::new(self), Box::new(rhs))
126    }
127}
128
129impl Mul<&SymbolicExpr> for SymbolicExpr {
130    type Output = SymbolicExpr;
131
132    fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
133        SymbolicExpr::Mul(Box::new(self), Box::new(rhs.clone()))
134    }
135}
136
137impl Mul for &SymbolicExpr {
138    type Output = SymbolicExpr;
139
140    fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
141        SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs.clone()))
142    }
143}
144
145impl Mul<SymbolicExpr> for &SymbolicExpr {
146    type Output = SymbolicExpr;
147
148    fn mul(self, rhs: SymbolicExpr) -> Self::Output {
149        SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs))
150    }
151}
152
153// Note that division by zero will panic.
154impl Div for SymbolicExpr {
155    type Output = SymbolicExpr;
156
157    fn div(self, rhs: Self) -> Self::Output {
158        SymbolicExpr::Div(Box::new(self), Box::new(rhs))
159    }
160}
161
162// Note that division by zero will panic.
163impl Div<&SymbolicExpr> for SymbolicExpr {
164    type Output = SymbolicExpr;
165
166    fn div(self, rhs: &SymbolicExpr) -> Self::Output {
167        SymbolicExpr::Div(Box::new(self), Box::new(rhs.clone()))
168    }
169}
170
171// Note that division by zero will panic.
172impl Div for &SymbolicExpr {
173    type Output = SymbolicExpr;
174
175    fn div(self, rhs: &SymbolicExpr) -> Self::Output {
176        SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs.clone()))
177    }
178}
179
180// Note that division by zero will panic.
181impl Div<SymbolicExpr> for &SymbolicExpr {
182    type Output = SymbolicExpr;
183
184    fn div(self, rhs: SymbolicExpr) -> Self::Output {
185        SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs))
186    }
187}
188
189impl SymbolicExpr {
190    /// Returns maximum absolute positive and negative value of the expression.
191    /// That is, if `(r, l) = expr.max_abs(p)` then `l,r >= 0` and `-l <= expr <= r`.
192    /// Needed in `constraint_limbs` to estimate the number of limbs of q.
193    ///
194    /// It is assumed that any `Input` or `Var` is a non-negative big integer with value
195    /// in the range `[0, proper_max]`.
196    fn max_abs(&self, proper_max: &BigUint) -> (BigUint, BigUint) {
197        match self {
198            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => (proper_max.clone(), BigUint::zero()),
199            SymbolicExpr::Const(_, val, _) => (val.clone(), BigUint::zero()),
200            SymbolicExpr::Add(lhs, rhs) => {
201                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
202                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
203                (lhs_max_pos + rhs_max_pos, lhs_max_neg + rhs_max_neg)
204            }
205            SymbolicExpr::Sub(lhs, rhs) => {
206                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
207                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
208                (lhs_max_pos + rhs_max_neg, lhs_max_neg + rhs_max_pos)
209            }
210            SymbolicExpr::Mul(lhs, rhs) => {
211                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
212                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
213                (
214                    max(&lhs_max_pos * &rhs_max_pos, &lhs_max_neg * &rhs_max_neg),
215                    max(&lhs_max_pos * &rhs_max_neg, &lhs_max_neg * &rhs_max_pos),
216                )
217            }
218            SymbolicExpr::Div(_, _) => {
219                // Should not have division in expression when calling this.
220                unreachable!()
221            }
222            SymbolicExpr::IntAdd(lhs, s) => {
223                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
224                let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
225                // Optimization opportunity: since `s` is a constant, we can likely do better than this bound.
226                (lhs_max_pos + &scalar, lhs_max_neg + &scalar)
227            }
228            SymbolicExpr::IntMul(lhs, s) => {
229                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
230                let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
231                if *s < 0 {
232                    (lhs_max_neg * &scalar, lhs_max_pos * &scalar)
233                } else {
234                    (lhs_max_pos * &scalar, lhs_max_neg * &scalar)
235                }
236            }
237            SymbolicExpr::Select(_, lhs, rhs) => {
238                let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
239                let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
240                (max(lhs_max_pos, rhs_max_pos), max(lhs_max_neg, rhs_max_neg))
241            }
242        }
243    }
244
245    /// Returns the maximum possible size, in bits, of each limb in `self.expr`.
246    /// This is already tracked in `FieldVariable`. However when auto saving in `FieldVariable::div`,
247    /// we need to know it from the `SymbolicExpr` only.
248    /// self should be a constraint expr.
249    pub fn constraint_limb_max_abs(&self, limb_bits: usize, num_limbs: usize) -> usize {
250        let canonical_limb_max_abs = (1 << limb_bits) - 1;
251        match self {
252            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) | SymbolicExpr::Const(_, _, _) => {
253                canonical_limb_max_abs
254            }
255            SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
256                lhs.constraint_limb_max_abs(limb_bits, num_limbs)
257                    + rhs.constraint_limb_max_abs(limb_bits, num_limbs)
258            }
259            SymbolicExpr::Mul(lhs, rhs) => {
260                let left_num_limbs = lhs.expr_limbs(num_limbs);
261                let right_num_limbs = rhs.expr_limbs(num_limbs);
262                lhs.constraint_limb_max_abs(limb_bits, num_limbs)
263                    * rhs.constraint_limb_max_abs(limb_bits, num_limbs)
264                    * min(left_num_limbs, right_num_limbs)
265            }
266            SymbolicExpr::IntAdd(lhs, i) => {
267                lhs.constraint_limb_max_abs(limb_bits, num_limbs) + i.unsigned_abs()
268            }
269            SymbolicExpr::IntMul(lhs, i) => {
270                lhs.constraint_limb_max_abs(limb_bits, num_limbs) * i.unsigned_abs()
271            }
272            SymbolicExpr::Select(_, lhs, rhs) => max(
273                lhs.constraint_limb_max_abs(limb_bits, num_limbs),
274                rhs.constraint_limb_max_abs(limb_bits, num_limbs),
275            ),
276            SymbolicExpr::Div(_, _) => {
277                unreachable!("should not have division when calling limb_max_abs")
278            }
279        }
280    }
281
282    /// Returns the maximum possible size, in bits, of each carry in `self.expr - q * p`.
283    /// self should be a constraint expr.
284    ///
285    /// The cached value `proper_max` should equal `2^{limb_bits * num_limbs} - 1`.
286    pub fn constraint_carry_bits_with_pq(
287        &self,
288        prime: &BigUint,
289        limb_bits: usize,
290        num_limbs: usize,
291        proper_max: &BigUint,
292    ) -> usize {
293        let without_pq = self.constraint_limb_max_abs(limb_bits, num_limbs);
294        let (q_limbs, _) = self.constraint_limbs(prime, limb_bits, num_limbs, proper_max);
295        let canonical_limb_max_abs = (1 << limb_bits) - 1;
296        let limb_max_abs =
297            without_pq + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, num_limbs);
298        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
299        let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
300        carry_bits
301    }
302
303    /// Returns the number of limbs needed to represent the expression.
304    /// The parameter `num_limbs` is the number of limbs of a canonical field element.
305    pub fn expr_limbs(&self, num_limbs: usize) -> usize {
306        match self {
307            SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => num_limbs,
308            SymbolicExpr::Const(_, _, limbs) => *limbs,
309            SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
310                max(lhs.expr_limbs(num_limbs), rhs.expr_limbs(num_limbs))
311            }
312            SymbolicExpr::Mul(lhs, rhs) => {
313                lhs.expr_limbs(num_limbs) + rhs.expr_limbs(num_limbs) - 1
314            }
315            SymbolicExpr::Div(_, _) => {
316                unimplemented!()
317            }
318            SymbolicExpr::IntAdd(lhs, _) => lhs.expr_limbs(num_limbs),
319            SymbolicExpr::IntMul(lhs, _) => lhs.expr_limbs(num_limbs),
320            SymbolicExpr::Select(_, lhs, rhs) => {
321                let left = lhs.expr_limbs(num_limbs);
322                let right = rhs.expr_limbs(num_limbs);
323                assert_eq!(left, right);
324                left
325            }
326        }
327    }
328
329    /// Let `q` be such that `self.expr = q * p`.
330    /// Returns (q_limbs, carry_limbs) where q_limbs is the number of limbs in q
331    /// and carry_limbs is the number of limbs in the carry of the constraint self.expr - q * p = 0.
332    /// self should be a constraint expression.
333    ///
334    /// The cached value `proper_max` should equal `2^{limb_bits * num_limbs} - 1`.
335    pub fn constraint_limbs(
336        &self,
337        prime: &BigUint,
338        limb_bits: usize,
339        num_limbs: usize,
340        proper_max: &BigUint,
341    ) -> (usize, usize) {
342        let (max_pos_abs, max_neg_abs) = self.max_abs(proper_max);
343        let max_abs = max(max_pos_abs, max_neg_abs);
344        let max_q_abs = (&max_abs + prime - BigUint::one()) / prime;
345        let q_bits = max_q_abs.bits() as usize;
346        let p_bits = prime.bits() as usize;
347        let q_limbs = q_bits.div_ceil(limb_bits);
348        // Attention! This must match with prime_overflow in `FieldExpr::generate_subrow`
349        let p_limbs = p_bits.div_ceil(limb_bits);
350        let qp_limbs = q_limbs + p_limbs - 1;
351
352        let expr_limbs = self.expr_limbs(num_limbs);
353        let carry_limbs = max(expr_limbs, qp_limbs);
354        (q_limbs, carry_limbs)
355    }
356
357    /// Used in trace gen to compute `q``.
358    /// self should be a constraint expression.
359    pub fn evaluate_bigint(
360        &self,
361        inputs: &[BigInt],
362        variables: &[BigInt],
363        flags: &[bool],
364    ) -> BigInt {
365        match self {
366            SymbolicExpr::IntAdd(lhs, s) => {
367                lhs.evaluate_bigint(inputs, variables, flags) + BigInt::from_isize(*s).unwrap()
368            }
369            SymbolicExpr::IntMul(lhs, s) => {
370                lhs.evaluate_bigint(inputs, variables, flags) * BigInt::from_isize(*s).unwrap()
371            }
372            SymbolicExpr::Input(i) => inputs[*i].clone(),
373            SymbolicExpr::Var(i) => variables[*i].clone(),
374            SymbolicExpr::Const(_, val, _) => {
375                if val.is_zero() {
376                    BigInt::zero()
377                } else {
378                    BigInt::from_biguint(Sign::Plus, val.clone())
379                }
380            }
381            SymbolicExpr::Add(lhs, rhs) => {
382                lhs.evaluate_bigint(inputs, variables, flags)
383                    + rhs.evaluate_bigint(inputs, variables, flags)
384            }
385            SymbolicExpr::Sub(lhs, rhs) => {
386                lhs.evaluate_bigint(inputs, variables, flags)
387                    - rhs.evaluate_bigint(inputs, variables, flags)
388            }
389            SymbolicExpr::Mul(lhs, rhs) => {
390                lhs.evaluate_bigint(inputs, variables, flags)
391                    * rhs.evaluate_bigint(inputs, variables, flags)
392            }
393            SymbolicExpr::Select(flag_id, lhs, rhs) => {
394                if flags[*flag_id] {
395                    lhs.evaluate_bigint(inputs, variables, flags)
396                } else {
397                    rhs.evaluate_bigint(inputs, variables, flags)
398                }
399            }
400            SymbolicExpr::Div(_, _) => unreachable!(), // Division is not allowed in constraints.
401        }
402    }
403
404    /// Used in trace gen to compute carries.
405    /// self should be a constraint expression.
406    pub fn evaluate_overflow_isize(
407        &self,
408        inputs: &[OverflowInt<isize>],
409        variables: &[OverflowInt<isize>],
410        constants: &[OverflowInt<isize>],
411        flags: &[bool],
412    ) -> OverflowInt<isize> {
413        match self {
414            SymbolicExpr::IntAdd(lhs, s) => {
415                let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
416                left.int_add(*s, identity)
417            }
418            SymbolicExpr::IntMul(lhs, s) => {
419                let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
420                left.int_mul(*s, identity)
421            }
422            SymbolicExpr::Input(i) => inputs[*i].clone(),
423            SymbolicExpr::Var(i) => variables[*i].clone(),
424            SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
425            SymbolicExpr::Add(lhs, rhs) => {
426                lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
427                    + rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
428            }
429            SymbolicExpr::Sub(lhs, rhs) => {
430                lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
431                    - rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
432            }
433            SymbolicExpr::Mul(lhs, rhs) => {
434                lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
435                    * rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
436            }
437            SymbolicExpr::Select(flag_id, lhs, rhs) => {
438                let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
439                let right = rhs.evaluate_overflow_isize(inputs, variables, constants, flags);
440                let num_limbs = max(left.num_limbs(), right.num_limbs());
441
442                let res = if flags[*flag_id] {
443                    left.limbs().to_vec()
444                } else {
445                    right.limbs().to_vec()
446                };
447                let res = res.into_iter().chain(repeat(0)).take(num_limbs).collect();
448
449                OverflowInt::from_computed_limbs(
450                    res,
451                    max(left.limb_max_abs(), right.limb_max_abs()),
452                    max(left.max_overflow_bits(), right.max_overflow_bits()),
453                )
454            }
455            SymbolicExpr::Div(_, _) => unreachable!(), // Division is not allowed in constraints.
456        }
457    }
458
459    fn isize_to_expr<AB: AirBuilder>(s: isize) -> AB::Expr {
460        if s >= 0 {
461            AB::Expr::from_canonical_usize(s as usize)
462        } else {
463            -AB::Expr::from_canonical_usize(s.unsigned_abs())
464        }
465    }
466
467    /// Used in AIR eval.
468    /// self should be a constraint expression.
469    pub fn evaluate_overflow_expr<AB: AirBuilder>(
470        &self,
471        inputs: &[OverflowInt<AB::Expr>],
472        variables: &[OverflowInt<AB::Expr>],
473        constants: &[OverflowInt<AB::Expr>],
474        flags: &[AB::Var],
475    ) -> OverflowInt<AB::Expr> {
476        match self {
477            SymbolicExpr::IntAdd(lhs, s) => {
478                let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
479                left.int_add(*s, Self::isize_to_expr::<AB>)
480            }
481            SymbolicExpr::IntMul(lhs, s) => {
482                let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
483                left.int_mul(*s, Self::isize_to_expr::<AB>)
484            }
485            SymbolicExpr::Input(i) => inputs[*i].clone(),
486            SymbolicExpr::Var(i) => variables[*i].clone(),
487            SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
488            SymbolicExpr::Add(lhs, rhs) => {
489                lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
490                    + rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
491            }
492            SymbolicExpr::Sub(lhs, rhs) => {
493                lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
494                    - rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
495            }
496            SymbolicExpr::Mul(lhs, rhs) => {
497                lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
498                    * rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
499            }
500            SymbolicExpr::Select(flag_id, lhs, rhs) => {
501                let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
502                let right = rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
503                let num_limbs = max(left.num_limbs(), right.num_limbs());
504                let flag = flags[*flag_id];
505                let mut res = vec![];
506                for i in 0..num_limbs {
507                    res.push(
508                        (if i < left.num_limbs() {
509                            left.limb(i).clone()
510                        } else {
511                            AB::Expr::ZERO
512                        }) * flag.into()
513                            + (if i < right.num_limbs() {
514                                right.limb(i).clone()
515                            } else {
516                                AB::Expr::ZERO
517                            }) * (AB::Expr::ONE - flag.into()),
518                    );
519                }
520                OverflowInt::from_computed_limbs(
521                    res,
522                    max(left.limb_max_abs(), right.limb_max_abs()),
523                    max(left.max_overflow_bits(), right.max_overflow_bits()),
524                )
525            }
526            SymbolicExpr::Div(_, _) => unreachable!(), // Division is not allowed in constraints.
527        }
528    }
529
530    /// Result will be within [0, prime).
531    /// self should be a compute expression.
532    /// Note that division by zero will panic.
533    pub fn compute(
534        &self,
535        inputs: &[BigUint],
536        variables: &[BigUint],
537        flags: &[bool],
538        prime: &BigUint,
539    ) -> BigUint {
540        let res = match self {
541            SymbolicExpr::Input(i) => inputs[*i].clone() % prime,
542            SymbolicExpr::Var(i) => variables[*i].clone(),
543            SymbolicExpr::Const(_, val, _) => val.clone(),
544            SymbolicExpr::Add(lhs, rhs) => {
545                (lhs.compute(inputs, variables, flags, prime)
546                    + rhs.compute(inputs, variables, flags, prime))
547                    % prime
548            }
549            SymbolicExpr::Sub(lhs, rhs) => {
550                (prime + lhs.compute(inputs, variables, flags, prime)
551                    - rhs.compute(inputs, variables, flags, prime))
552                    % prime
553            }
554            SymbolicExpr::Mul(lhs, rhs) => {
555                (lhs.compute(inputs, variables, flags, prime)
556                    * rhs.compute(inputs, variables, flags, prime))
557                    % prime
558            }
559            SymbolicExpr::Div(lhs, rhs) => {
560                let left = lhs.compute(inputs, variables, flags, prime);
561                let right = rhs.compute(inputs, variables, flags, prime);
562                let right_inv = right.modinv(prime).unwrap();
563                (left * right_inv) % prime
564            }
565            SymbolicExpr::IntAdd(lhs, s) => {
566                let left = lhs.compute(inputs, variables, flags, prime);
567                let right = if *s >= 0 {
568                    BigUint::from_usize(*s as usize).unwrap()
569                } else {
570                    prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
571                };
572                (left + right) % prime
573            }
574            SymbolicExpr::IntMul(lhs, s) => {
575                let left = lhs.compute(inputs, variables, flags, prime);
576                let right = if *s >= 0 {
577                    BigUint::from_usize(*s as usize).unwrap()
578                } else {
579                    prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
580                };
581                (left * right) % prime
582            }
583            SymbolicExpr::Select(flag_id, lhs, rhs) => {
584                if flags[*flag_id] {
585                    lhs.compute(inputs, variables, flags, prime)
586                } else {
587                    rhs.compute(inputs, variables, flags, prime)
588                }
589            }
590        };
591        assert!(
592            res < prime.clone(),
593            "symbolic expr: {} evaluation exceeds prime",
594            self
595        );
596        res
597    }
598}