1use std::{
2 cmp::{max, min},
3 convert::identity,
4 iter::repeat,
5 ops::{Add, Div, Mul, Sub},
6};
7
8use num_bigint::{BigInt, BigUint, Sign};
9use num_traits::{FromPrimitive, One, Zero};
10use openvm_circuit_primitives::bigint::{
11 check_carry_to_zero::get_carry_max_abs_and_bits, OverflowInt,
12};
13use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra, p3_util::log2_ceil_usize};
14
15#[derive(Clone, Debug, PartialEq)]
19pub enum SymbolicExpr {
20 Input(usize),
21 Var(usize),
22 Const(usize, BigUint, usize), Add(Box<SymbolicExpr>, Box<SymbolicExpr>),
24 Sub(Box<SymbolicExpr>, Box<SymbolicExpr>),
25 Mul(Box<SymbolicExpr>, Box<SymbolicExpr>),
26 Div(Box<SymbolicExpr>, Box<SymbolicExpr>),
29 IntAdd(Box<SymbolicExpr>, isize),
31 IntMul(Box<SymbolicExpr>, isize),
33 Select(usize, Box<SymbolicExpr>, Box<SymbolicExpr>),
36}
37
38impl std::fmt::Display for SymbolicExpr {
39 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
40 match self {
41 SymbolicExpr::Input(i) => write!(f, "Input_{}", i),
42 SymbolicExpr::Var(i) => write!(f, "Var_{}", i),
43 SymbolicExpr::Const(i, _, _) => write!(f, "Const_{}", i),
44 SymbolicExpr::Add(lhs, rhs) => write!(f, "({} + {})", lhs, rhs),
45 SymbolicExpr::Sub(lhs, rhs) => write!(f, "({} - {})", lhs, rhs),
46 SymbolicExpr::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs),
47 SymbolicExpr::Div(lhs, rhs) => write!(f, "({} / {})", lhs, rhs),
48 SymbolicExpr::IntAdd(lhs, s) => write!(f, "({} + {})", lhs, s),
49 SymbolicExpr::IntMul(lhs, s) => write!(f, "({} x {})", lhs, s),
50 SymbolicExpr::Select(flag_id, lhs, rhs) => {
51 write!(f, "(if {} then {} else {})", flag_id, lhs, rhs)
52 }
53 }
54 }
55}
56
57impl Add for SymbolicExpr {
58 type Output = SymbolicExpr;
59
60 fn add(self, rhs: Self) -> Self::Output {
61 SymbolicExpr::Add(Box::new(self), Box::new(rhs))
62 }
63}
64
65impl Add<&SymbolicExpr> for SymbolicExpr {
66 type Output = SymbolicExpr;
67
68 fn add(self, rhs: &SymbolicExpr) -> Self::Output {
69 SymbolicExpr::Add(Box::new(self), Box::new(rhs.clone()))
70 }
71}
72
73impl Add for &SymbolicExpr {
74 type Output = SymbolicExpr;
75
76 fn add(self, rhs: &SymbolicExpr) -> Self::Output {
77 SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs.clone()))
78 }
79}
80
81impl Add<SymbolicExpr> for &SymbolicExpr {
82 type Output = SymbolicExpr;
83
84 fn add(self, rhs: SymbolicExpr) -> Self::Output {
85 SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs))
86 }
87}
88
89impl Sub for SymbolicExpr {
90 type Output = SymbolicExpr;
91
92 fn sub(self, rhs: Self) -> Self::Output {
93 SymbolicExpr::Sub(Box::new(self), Box::new(rhs))
94 }
95}
96
97impl Sub<&SymbolicExpr> for SymbolicExpr {
98 type Output = SymbolicExpr;
99
100 fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
101 SymbolicExpr::Sub(Box::new(self), Box::new(rhs.clone()))
102 }
103}
104
105impl Sub for &SymbolicExpr {
106 type Output = SymbolicExpr;
107
108 fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
109 SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs.clone()))
110 }
111}
112
113impl Sub<SymbolicExpr> for &SymbolicExpr {
114 type Output = SymbolicExpr;
115
116 fn sub(self, rhs: SymbolicExpr) -> Self::Output {
117 SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs))
118 }
119}
120
121impl Mul for SymbolicExpr {
122 type Output = SymbolicExpr;
123
124 fn mul(self, rhs: Self) -> Self::Output {
125 SymbolicExpr::Mul(Box::new(self), Box::new(rhs))
126 }
127}
128
129impl Mul<&SymbolicExpr> for SymbolicExpr {
130 type Output = SymbolicExpr;
131
132 fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
133 SymbolicExpr::Mul(Box::new(self), Box::new(rhs.clone()))
134 }
135}
136
137impl Mul for &SymbolicExpr {
138 type Output = SymbolicExpr;
139
140 fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
141 SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs.clone()))
142 }
143}
144
145impl Mul<SymbolicExpr> for &SymbolicExpr {
146 type Output = SymbolicExpr;
147
148 fn mul(self, rhs: SymbolicExpr) -> Self::Output {
149 SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs))
150 }
151}
152
153impl Div for SymbolicExpr {
155 type Output = SymbolicExpr;
156
157 fn div(self, rhs: Self) -> Self::Output {
158 SymbolicExpr::Div(Box::new(self), Box::new(rhs))
159 }
160}
161
162impl Div<&SymbolicExpr> for SymbolicExpr {
164 type Output = SymbolicExpr;
165
166 fn div(self, rhs: &SymbolicExpr) -> Self::Output {
167 SymbolicExpr::Div(Box::new(self), Box::new(rhs.clone()))
168 }
169}
170
171impl Div for &SymbolicExpr {
173 type Output = SymbolicExpr;
174
175 fn div(self, rhs: &SymbolicExpr) -> Self::Output {
176 SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs.clone()))
177 }
178}
179
180impl Div<SymbolicExpr> for &SymbolicExpr {
182 type Output = SymbolicExpr;
183
184 fn div(self, rhs: SymbolicExpr) -> Self::Output {
185 SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs))
186 }
187}
188
189impl SymbolicExpr {
190 fn max_abs(&self, proper_max: &BigUint) -> (BigUint, BigUint) {
197 match self {
198 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => (proper_max.clone(), BigUint::zero()),
199 SymbolicExpr::Const(_, val, _) => (val.clone(), BigUint::zero()),
200 SymbolicExpr::Add(lhs, rhs) => {
201 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
202 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
203 (lhs_max_pos + rhs_max_pos, lhs_max_neg + rhs_max_neg)
204 }
205 SymbolicExpr::Sub(lhs, rhs) => {
206 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
207 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
208 (lhs_max_pos + rhs_max_neg, lhs_max_neg + rhs_max_pos)
209 }
210 SymbolicExpr::Mul(lhs, rhs) => {
211 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
212 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
213 (
214 max(&lhs_max_pos * &rhs_max_pos, &lhs_max_neg * &rhs_max_neg),
215 max(&lhs_max_pos * &rhs_max_neg, &lhs_max_neg * &rhs_max_pos),
216 )
217 }
218 SymbolicExpr::Div(_, _) => {
219 unreachable!()
221 }
222 SymbolicExpr::IntAdd(lhs, s) => {
223 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
224 let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
225 (lhs_max_pos + &scalar, lhs_max_neg + &scalar)
228 }
229 SymbolicExpr::IntMul(lhs, s) => {
230 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
231 let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
232 if *s < 0 {
233 (lhs_max_neg * &scalar, lhs_max_pos * &scalar)
234 } else {
235 (lhs_max_pos * &scalar, lhs_max_neg * &scalar)
236 }
237 }
238 SymbolicExpr::Select(_, lhs, rhs) => {
239 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
240 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
241 (max(lhs_max_pos, rhs_max_pos), max(lhs_max_neg, rhs_max_neg))
242 }
243 }
244 }
245
246 pub fn constraint_limb_max_abs(&self, limb_bits: usize, num_limbs: usize) -> usize {
251 let canonical_limb_max_abs = (1 << limb_bits) - 1;
252 match self {
253 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) | SymbolicExpr::Const(_, _, _) => {
254 canonical_limb_max_abs
255 }
256 SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
257 lhs.constraint_limb_max_abs(limb_bits, num_limbs)
258 + rhs.constraint_limb_max_abs(limb_bits, num_limbs)
259 }
260 SymbolicExpr::Mul(lhs, rhs) => {
261 let left_num_limbs = lhs.expr_limbs(num_limbs);
262 let right_num_limbs = rhs.expr_limbs(num_limbs);
263 lhs.constraint_limb_max_abs(limb_bits, num_limbs)
264 * rhs.constraint_limb_max_abs(limb_bits, num_limbs)
265 * min(left_num_limbs, right_num_limbs)
266 }
267 SymbolicExpr::IntAdd(lhs, i) => {
268 lhs.constraint_limb_max_abs(limb_bits, num_limbs) + i.unsigned_abs()
269 }
270 SymbolicExpr::IntMul(lhs, i) => {
271 lhs.constraint_limb_max_abs(limb_bits, num_limbs) * i.unsigned_abs()
272 }
273 SymbolicExpr::Select(_, lhs, rhs) => max(
274 lhs.constraint_limb_max_abs(limb_bits, num_limbs),
275 rhs.constraint_limb_max_abs(limb_bits, num_limbs),
276 ),
277 SymbolicExpr::Div(_, _) => {
278 unreachable!("should not have division when calling limb_max_abs")
279 }
280 }
281 }
282
283 pub fn constraint_carry_bits_with_pq(
288 &self,
289 prime: &BigUint,
290 limb_bits: usize,
291 num_limbs: usize,
292 proper_max: &BigUint,
293 ) -> usize {
294 let without_pq = self.constraint_limb_max_abs(limb_bits, num_limbs);
295 let (q_limbs, _) = self.constraint_limbs(prime, limb_bits, num_limbs, proper_max);
296 let canonical_limb_max_abs = (1 << limb_bits) - 1;
297 let limb_max_abs =
298 without_pq + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, num_limbs);
299 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
300 let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
301 carry_bits
302 }
303
304 pub fn expr_limbs(&self, num_limbs: usize) -> usize {
307 match self {
308 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => num_limbs,
309 SymbolicExpr::Const(_, _, limbs) => *limbs,
310 SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
311 max(lhs.expr_limbs(num_limbs), rhs.expr_limbs(num_limbs))
312 }
313 SymbolicExpr::Mul(lhs, rhs) => {
314 lhs.expr_limbs(num_limbs) + rhs.expr_limbs(num_limbs) - 1
315 }
316 SymbolicExpr::Div(_, _) => {
317 unimplemented!()
318 }
319 SymbolicExpr::IntAdd(lhs, _) => lhs.expr_limbs(num_limbs),
320 SymbolicExpr::IntMul(lhs, _) => lhs.expr_limbs(num_limbs),
321 SymbolicExpr::Select(_, lhs, rhs) => {
322 let left = lhs.expr_limbs(num_limbs);
323 let right = rhs.expr_limbs(num_limbs);
324 assert_eq!(left, right);
325 left
326 }
327 }
328 }
329
330 pub fn constraint_limbs(
337 &self,
338 prime: &BigUint,
339 limb_bits: usize,
340 num_limbs: usize,
341 proper_max: &BigUint,
342 ) -> (usize, usize) {
343 let (max_pos_abs, max_neg_abs) = self.max_abs(proper_max);
344 let max_abs = max(max_pos_abs, max_neg_abs);
345 let max_q_abs = (&max_abs + prime - BigUint::one()) / prime;
346 let q_bits = max_q_abs.bits() as usize;
347 let p_bits = prime.bits() as usize;
348 let q_limbs = q_bits.div_ceil(limb_bits);
349 let p_limbs = p_bits.div_ceil(limb_bits);
351 let qp_limbs = q_limbs + p_limbs - 1;
352
353 let expr_limbs = self.expr_limbs(num_limbs);
354 let carry_limbs = max(expr_limbs, qp_limbs);
355 (q_limbs, carry_limbs)
356 }
357
358 pub fn evaluate_bigint(
361 &self,
362 inputs: &[BigInt],
363 variables: &[BigInt],
364 flags: &[bool],
365 ) -> BigInt {
366 match self {
367 SymbolicExpr::IntAdd(lhs, s) => {
368 lhs.evaluate_bigint(inputs, variables, flags) + BigInt::from_isize(*s).unwrap()
369 }
370 SymbolicExpr::IntMul(lhs, s) => {
371 lhs.evaluate_bigint(inputs, variables, flags) * BigInt::from_isize(*s).unwrap()
372 }
373 SymbolicExpr::Input(i) => inputs[*i].clone(),
374 SymbolicExpr::Var(i) => variables[*i].clone(),
375 SymbolicExpr::Const(_, val, _) => {
376 if val.is_zero() {
377 BigInt::zero()
378 } else {
379 BigInt::from_biguint(Sign::Plus, val.clone())
380 }
381 }
382 SymbolicExpr::Add(lhs, rhs) => {
383 lhs.evaluate_bigint(inputs, variables, flags)
384 + rhs.evaluate_bigint(inputs, variables, flags)
385 }
386 SymbolicExpr::Sub(lhs, rhs) => {
387 lhs.evaluate_bigint(inputs, variables, flags)
388 - rhs.evaluate_bigint(inputs, variables, flags)
389 }
390 SymbolicExpr::Mul(lhs, rhs) => {
391 lhs.evaluate_bigint(inputs, variables, flags)
392 * rhs.evaluate_bigint(inputs, variables, flags)
393 }
394 SymbolicExpr::Select(flag_id, lhs, rhs) => {
395 if flags[*flag_id] {
396 lhs.evaluate_bigint(inputs, variables, flags)
397 } else {
398 rhs.evaluate_bigint(inputs, variables, flags)
399 }
400 }
401 SymbolicExpr::Div(_, _) => unreachable!(), }
403 }
404
405 pub fn evaluate_overflow_isize(
408 &self,
409 inputs: &[OverflowInt<isize>],
410 variables: &[OverflowInt<isize>],
411 constants: &[OverflowInt<isize>],
412 flags: &[bool],
413 ) -> OverflowInt<isize> {
414 match self {
415 SymbolicExpr::IntAdd(lhs, s) => {
416 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
417 left.int_add(*s, identity)
418 }
419 SymbolicExpr::IntMul(lhs, s) => {
420 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
421 left.int_mul(*s, identity)
422 }
423 SymbolicExpr::Input(i) => inputs[*i].clone(),
424 SymbolicExpr::Var(i) => variables[*i].clone(),
425 SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
426 SymbolicExpr::Add(lhs, rhs) => {
427 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
428 + rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
429 }
430 SymbolicExpr::Sub(lhs, rhs) => {
431 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
432 - rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
433 }
434 SymbolicExpr::Mul(lhs, rhs) => {
435 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
436 * rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
437 }
438 SymbolicExpr::Select(flag_id, lhs, rhs) => {
439 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
440 let right = rhs.evaluate_overflow_isize(inputs, variables, constants, flags);
441 let num_limbs = max(left.num_limbs(), right.num_limbs());
442
443 let res = if flags[*flag_id] {
444 left.limbs().to_vec()
445 } else {
446 right.limbs().to_vec()
447 };
448 let res = res.into_iter().chain(repeat(0)).take(num_limbs).collect();
449
450 OverflowInt::from_computed_limbs(
451 res,
452 max(left.limb_max_abs(), right.limb_max_abs()),
453 max(left.max_overflow_bits(), right.max_overflow_bits()),
454 )
455 }
456 SymbolicExpr::Div(_, _) => unreachable!(), }
458 }
459
460 fn isize_to_expr<AB: AirBuilder>(s: isize) -> AB::Expr {
461 if s >= 0 {
462 AB::Expr::from_canonical_usize(s as usize)
463 } else {
464 -AB::Expr::from_canonical_usize(s.unsigned_abs())
465 }
466 }
467
468 pub fn evaluate_overflow_expr<AB: AirBuilder>(
471 &self,
472 inputs: &[OverflowInt<AB::Expr>],
473 variables: &[OverflowInt<AB::Expr>],
474 constants: &[OverflowInt<AB::Expr>],
475 flags: &[AB::Var],
476 ) -> OverflowInt<AB::Expr> {
477 match self {
478 SymbolicExpr::IntAdd(lhs, s) => {
479 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
480 left.int_add(*s, Self::isize_to_expr::<AB>)
481 }
482 SymbolicExpr::IntMul(lhs, s) => {
483 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
484 left.int_mul(*s, Self::isize_to_expr::<AB>)
485 }
486 SymbolicExpr::Input(i) => inputs[*i].clone(),
487 SymbolicExpr::Var(i) => variables[*i].clone(),
488 SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
489 SymbolicExpr::Add(lhs, rhs) => {
490 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
491 + rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
492 }
493 SymbolicExpr::Sub(lhs, rhs) => {
494 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
495 - rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
496 }
497 SymbolicExpr::Mul(lhs, rhs) => {
498 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
499 * rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
500 }
501 SymbolicExpr::Select(flag_id, lhs, rhs) => {
502 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
503 let right = rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
504 let num_limbs = max(left.num_limbs(), right.num_limbs());
505 let flag = flags[*flag_id];
506 let mut res = vec![];
507 for i in 0..num_limbs {
508 res.push(
509 (if i < left.num_limbs() {
510 left.limb(i).clone()
511 } else {
512 AB::Expr::ZERO
513 }) * flag.into()
514 + (if i < right.num_limbs() {
515 right.limb(i).clone()
516 } else {
517 AB::Expr::ZERO
518 }) * (AB::Expr::ONE - flag.into()),
519 );
520 }
521 OverflowInt::from_computed_limbs(
522 res,
523 max(left.limb_max_abs(), right.limb_max_abs()),
524 max(left.max_overflow_bits(), right.max_overflow_bits()),
525 )
526 }
527 SymbolicExpr::Div(_, _) => unreachable!(), }
529 }
530
531 pub fn compute(
535 &self,
536 inputs: &[BigUint],
537 variables: &[BigUint],
538 flags: &[bool],
539 prime: &BigUint,
540 ) -> BigUint {
541 let res = match self {
542 SymbolicExpr::Input(i) => inputs[*i].clone() % prime,
543 SymbolicExpr::Var(i) => variables[*i].clone(),
544 SymbolicExpr::Const(_, val, _) => val.clone(),
545 SymbolicExpr::Add(lhs, rhs) => {
546 (lhs.compute(inputs, variables, flags, prime)
547 + rhs.compute(inputs, variables, flags, prime))
548 % prime
549 }
550 SymbolicExpr::Sub(lhs, rhs) => {
551 (prime + lhs.compute(inputs, variables, flags, prime)
552 - rhs.compute(inputs, variables, flags, prime))
553 % prime
554 }
555 SymbolicExpr::Mul(lhs, rhs) => {
556 (lhs.compute(inputs, variables, flags, prime)
557 * rhs.compute(inputs, variables, flags, prime))
558 % prime
559 }
560 SymbolicExpr::Div(lhs, rhs) => {
561 let left = lhs.compute(inputs, variables, flags, prime);
562 let right = rhs.compute(inputs, variables, flags, prime);
563 let right_inv = right.modinv(prime).unwrap();
564 (left * right_inv) % prime
565 }
566 SymbolicExpr::IntAdd(lhs, s) => {
567 let left = lhs.compute(inputs, variables, flags, prime);
568 let right = if *s >= 0 {
569 BigUint::from_usize(*s as usize).unwrap()
570 } else {
571 prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
572 };
573 (left + right) % prime
574 }
575 SymbolicExpr::IntMul(lhs, s) => {
576 let left = lhs.compute(inputs, variables, flags, prime);
577 let right = if *s >= 0 {
578 BigUint::from_usize(*s as usize).unwrap()
579 } else {
580 prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
581 };
582 (left * right) % prime
583 }
584 SymbolicExpr::Select(flag_id, lhs, rhs) => {
585 if flags[*flag_id] {
586 lhs.compute(inputs, variables, flags, prime)
587 } else {
588 rhs.compute(inputs, variables, flags, prime)
589 }
590 }
591 };
592 assert!(
593 res < prime.clone(),
594 "symbolic expr: {} evaluation exceeds prime",
595 self
596 );
597 res
598 }
599}