1use std::{
2 cmp::{max, min},
3 convert::identity,
4 iter::repeat,
5 ops::{Add, Div, Mul, Sub},
6};
7
8use num_bigint::{BigInt, BigUint, Sign};
9use num_traits::{FromPrimitive, One, Zero};
10use openvm_circuit_primitives::bigint::{
11 check_carry_to_zero::get_carry_max_abs_and_bits, OverflowInt,
12};
13use openvm_stark_backend::{
14 p3_air::AirBuilder, p3_field::PrimeCharacteristicRing, p3_util::log2_ceil_usize,
15};
16
/// A symbolic expression tree whose leaves reference inputs, variables and constants by index
/// and whose interior nodes combine boxed sub-expressions.
///
/// Only `compute` can evaluate `Div`; the bounding, limb-counting and other evaluation methods
/// panic when they encounter it.
#[derive(Clone, Debug, PartialEq)]
pub enum SymbolicExpr {
    /// The input at this index (`inputs[i]` in the evaluation methods).
    Input(usize),
    /// The variable at this index (`variables[i]` in the evaluation methods).
    Var(usize),
    /// A constant: `(index into constants, value, number of limbs)`.
    Const(usize, BigUint, usize),
    /// `lhs + rhs`.
    Add(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// `lhs - rhs`.
    Sub(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// `lhs * rhs`.
    Mul(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// `lhs / rhs` (modular inverse of `rhs`); supported only by `compute`.
    Div(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// `lhs + s` for a signed machine integer `s`.
    IntAdd(Box<SymbolicExpr>, isize),
    /// `lhs * s` for a signed machine integer `s`.
    IntMul(Box<SymbolicExpr>, isize),
    /// `(flag index, lhs, rhs)`: evaluates to `lhs` when `flags[flag index]` is set, otherwise
    /// to `rhs`.
    Select(usize, Box<SymbolicExpr>, Box<SymbolicExpr>),
}
39
40impl std::fmt::Display for SymbolicExpr {
41 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42 match self {
43 SymbolicExpr::Input(i) => write!(f, "Input_{i}"),
44 SymbolicExpr::Var(i) => write!(f, "Var_{i}"),
45 SymbolicExpr::Const(i, _, _) => write!(f, "Const_{i}"),
46 SymbolicExpr::Add(lhs, rhs) => write!(f, "({lhs} + {rhs})"),
47 SymbolicExpr::Sub(lhs, rhs) => write!(f, "({lhs} - {rhs})"),
48 SymbolicExpr::Mul(lhs, rhs) => write!(f, "{lhs} * {rhs}"),
49 SymbolicExpr::Div(lhs, rhs) => write!(f, "({lhs} / {rhs})"),
50 SymbolicExpr::IntAdd(lhs, s) => write!(f, "({lhs} + {s})"),
51 SymbolicExpr::IntMul(lhs, s) => write!(f, "({lhs} x {s})"),
52 SymbolicExpr::Select(flag_id, lhs, rhs) => {
53 write!(f, "(if {flag_id} then {lhs} else {rhs})")
54 }
55 }
56 }
57}
58
59impl Add for SymbolicExpr {
60 type Output = SymbolicExpr;
61
62 fn add(self, rhs: Self) -> Self::Output {
63 SymbolicExpr::Add(Box::new(self), Box::new(rhs))
64 }
65}
66
67impl Add<&SymbolicExpr> for SymbolicExpr {
68 type Output = SymbolicExpr;
69
70 fn add(self, rhs: &SymbolicExpr) -> Self::Output {
71 SymbolicExpr::Add(Box::new(self), Box::new(rhs.clone()))
72 }
73}
74
75impl Add for &SymbolicExpr {
76 type Output = SymbolicExpr;
77
78 fn add(self, rhs: &SymbolicExpr) -> Self::Output {
79 SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs.clone()))
80 }
81}
82
83impl Add<SymbolicExpr> for &SymbolicExpr {
84 type Output = SymbolicExpr;
85
86 fn add(self, rhs: SymbolicExpr) -> Self::Output {
87 SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs))
88 }
89}
90
91impl Sub for SymbolicExpr {
92 type Output = SymbolicExpr;
93
94 fn sub(self, rhs: Self) -> Self::Output {
95 SymbolicExpr::Sub(Box::new(self), Box::new(rhs))
96 }
97}
98
99impl Sub<&SymbolicExpr> for SymbolicExpr {
100 type Output = SymbolicExpr;
101
102 fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
103 SymbolicExpr::Sub(Box::new(self), Box::new(rhs.clone()))
104 }
105}
106
107impl Sub for &SymbolicExpr {
108 type Output = SymbolicExpr;
109
110 fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
111 SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs.clone()))
112 }
113}
114
115impl Sub<SymbolicExpr> for &SymbolicExpr {
116 type Output = SymbolicExpr;
117
118 fn sub(self, rhs: SymbolicExpr) -> Self::Output {
119 SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs))
120 }
121}
122
123impl Mul for SymbolicExpr {
124 type Output = SymbolicExpr;
125
126 fn mul(self, rhs: Self) -> Self::Output {
127 SymbolicExpr::Mul(Box::new(self), Box::new(rhs))
128 }
129}
130
131impl Mul<&SymbolicExpr> for SymbolicExpr {
132 type Output = SymbolicExpr;
133
134 fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
135 SymbolicExpr::Mul(Box::new(self), Box::new(rhs.clone()))
136 }
137}
138
139impl Mul for &SymbolicExpr {
140 type Output = SymbolicExpr;
141
142 fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
143 SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs.clone()))
144 }
145}
146
147impl Mul<SymbolicExpr> for &SymbolicExpr {
148 type Output = SymbolicExpr;
149
150 fn mul(self, rhs: SymbolicExpr) -> Self::Output {
151 SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs))
152 }
153}
154
155impl Div for SymbolicExpr {
157 type Output = SymbolicExpr;
158
159 fn div(self, rhs: Self) -> Self::Output {
160 SymbolicExpr::Div(Box::new(self), Box::new(rhs))
161 }
162}
163
164impl Div<&SymbolicExpr> for SymbolicExpr {
166 type Output = SymbolicExpr;
167
168 fn div(self, rhs: &SymbolicExpr) -> Self::Output {
169 SymbolicExpr::Div(Box::new(self), Box::new(rhs.clone()))
170 }
171}
172
173impl Div for &SymbolicExpr {
175 type Output = SymbolicExpr;
176
177 fn div(self, rhs: &SymbolicExpr) -> Self::Output {
178 SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs.clone()))
179 }
180}
181
182impl Div<SymbolicExpr> for &SymbolicExpr {
184 type Output = SymbolicExpr;
185
186 fn div(self, rhs: SymbolicExpr) -> Self::Output {
187 SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs))
188 }
189}
190
191impl SymbolicExpr {
192 fn max_abs(&self, proper_max: &BigUint) -> (BigUint, BigUint) {
199 match self {
200 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => (proper_max.clone(), BigUint::zero()),
201 SymbolicExpr::Const(_, val, _) => (val.clone(), BigUint::zero()),
202 SymbolicExpr::Add(lhs, rhs) => {
203 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
204 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
205 (lhs_max_pos + rhs_max_pos, lhs_max_neg + rhs_max_neg)
206 }
207 SymbolicExpr::Sub(lhs, rhs) => {
208 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
209 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
210 (lhs_max_pos + rhs_max_neg, lhs_max_neg + rhs_max_pos)
211 }
212 SymbolicExpr::Mul(lhs, rhs) => {
213 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
214 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
215 (
216 max(&lhs_max_pos * &rhs_max_pos, &lhs_max_neg * &rhs_max_neg),
217 max(&lhs_max_pos * &rhs_max_neg, &lhs_max_neg * &rhs_max_pos),
218 )
219 }
220 SymbolicExpr::Div(_, _) => {
221 unreachable!()
223 }
224 SymbolicExpr::IntAdd(lhs, s) => {
225 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
226 let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
227 (lhs_max_pos + &scalar, lhs_max_neg + &scalar)
230 }
231 SymbolicExpr::IntMul(lhs, s) => {
232 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
233 let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
234 if *s < 0 {
235 (lhs_max_neg * &scalar, lhs_max_pos * &scalar)
236 } else {
237 (lhs_max_pos * &scalar, lhs_max_neg * &scalar)
238 }
239 }
240 SymbolicExpr::Select(_, lhs, rhs) => {
241 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
242 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
243 (max(lhs_max_pos, rhs_max_pos), max(lhs_max_neg, rhs_max_neg))
244 }
245 }
246 }
247
248 pub fn constraint_limb_max_abs(&self, limb_bits: usize, num_limbs: usize) -> usize {
253 let canonical_limb_max_abs = (1 << limb_bits) - 1;
254 match self {
255 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) | SymbolicExpr::Const(_, _, _) => {
256 canonical_limb_max_abs
257 }
258 SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
259 lhs.constraint_limb_max_abs(limb_bits, num_limbs)
260 + rhs.constraint_limb_max_abs(limb_bits, num_limbs)
261 }
262 SymbolicExpr::Mul(lhs, rhs) => {
263 let left_num_limbs = lhs.expr_limbs(num_limbs);
264 let right_num_limbs = rhs.expr_limbs(num_limbs);
265 lhs.constraint_limb_max_abs(limb_bits, num_limbs)
266 * rhs.constraint_limb_max_abs(limb_bits, num_limbs)
267 * min(left_num_limbs, right_num_limbs)
268 }
269 SymbolicExpr::IntAdd(lhs, i) => {
270 lhs.constraint_limb_max_abs(limb_bits, num_limbs) + i.unsigned_abs()
271 }
272 SymbolicExpr::IntMul(lhs, i) => {
273 lhs.constraint_limb_max_abs(limb_bits, num_limbs) * i.unsigned_abs()
274 }
275 SymbolicExpr::Select(_, lhs, rhs) => max(
276 lhs.constraint_limb_max_abs(limb_bits, num_limbs),
277 rhs.constraint_limb_max_abs(limb_bits, num_limbs),
278 ),
279 SymbolicExpr::Div(_, _) => {
280 unreachable!("should not have division when calling limb_max_abs")
281 }
282 }
283 }
284
285 pub fn constraint_carry_bits_with_pq(
290 &self,
291 prime: &BigUint,
292 limb_bits: usize,
293 num_limbs: usize,
294 proper_max: &BigUint,
295 ) -> usize {
296 let without_pq = self.constraint_limb_max_abs(limb_bits, num_limbs);
297 let (q_limbs, _) = self.constraint_limbs(prime, limb_bits, num_limbs, proper_max);
298 let canonical_limb_max_abs = (1 << limb_bits) - 1;
299 let limb_max_abs =
300 without_pq + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, num_limbs);
301 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
302 let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
303 carry_bits
304 }
305
306 pub fn expr_limbs(&self, num_limbs: usize) -> usize {
309 match self {
310 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => num_limbs,
311 SymbolicExpr::Const(_, _, limbs) => *limbs,
312 SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
313 max(lhs.expr_limbs(num_limbs), rhs.expr_limbs(num_limbs))
314 }
315 SymbolicExpr::Mul(lhs, rhs) => {
316 lhs.expr_limbs(num_limbs) + rhs.expr_limbs(num_limbs) - 1
317 }
318 SymbolicExpr::Div(_, _) => {
319 unimplemented!()
320 }
321 SymbolicExpr::IntAdd(lhs, _) => lhs.expr_limbs(num_limbs),
322 SymbolicExpr::IntMul(lhs, _) => lhs.expr_limbs(num_limbs),
323 SymbolicExpr::Select(_, lhs, rhs) => {
324 let left = lhs.expr_limbs(num_limbs);
325 let right = rhs.expr_limbs(num_limbs);
326 assert_eq!(left, right);
327 left
328 }
329 }
330 }
331
332 pub fn constraint_limbs(
339 &self,
340 prime: &BigUint,
341 limb_bits: usize,
342 num_limbs: usize,
343 proper_max: &BigUint,
344 ) -> (usize, usize) {
345 let (max_pos_abs, max_neg_abs) = self.max_abs(proper_max);
346 let max_abs = max(max_pos_abs, max_neg_abs);
347 let max_q_abs = (&max_abs + prime - BigUint::one()) / prime;
348 let q_bits = max_q_abs.bits() as usize;
349 let p_bits = prime.bits() as usize;
350 let q_limbs = q_bits.div_ceil(limb_bits);
351 let p_limbs = p_bits.div_ceil(limb_bits);
353 let qp_limbs = q_limbs + p_limbs - 1;
354
355 let expr_limbs = self.expr_limbs(num_limbs);
356 let carry_limbs = max(expr_limbs, qp_limbs);
357 (q_limbs, carry_limbs)
358 }
359
360 pub fn evaluate_bigint(
363 &self,
364 inputs: &[BigInt],
365 variables: &[BigInt],
366 flags: &[bool],
367 ) -> BigInt {
368 match self {
369 SymbolicExpr::IntAdd(lhs, s) => {
370 lhs.evaluate_bigint(inputs, variables, flags) + BigInt::from_isize(*s).unwrap()
371 }
372 SymbolicExpr::IntMul(lhs, s) => {
373 lhs.evaluate_bigint(inputs, variables, flags) * BigInt::from_isize(*s).unwrap()
374 }
375 SymbolicExpr::Input(i) => inputs[*i].clone(),
376 SymbolicExpr::Var(i) => variables[*i].clone(),
377 SymbolicExpr::Const(_, val, _) => {
378 if val.is_zero() {
379 BigInt::zero()
380 } else {
381 BigInt::from_biguint(Sign::Plus, val.clone())
382 }
383 }
384 SymbolicExpr::Add(lhs, rhs) => {
385 lhs.evaluate_bigint(inputs, variables, flags)
386 + rhs.evaluate_bigint(inputs, variables, flags)
387 }
388 SymbolicExpr::Sub(lhs, rhs) => {
389 lhs.evaluate_bigint(inputs, variables, flags)
390 - rhs.evaluate_bigint(inputs, variables, flags)
391 }
392 SymbolicExpr::Mul(lhs, rhs) => {
393 lhs.evaluate_bigint(inputs, variables, flags)
394 * rhs.evaluate_bigint(inputs, variables, flags)
395 }
396 SymbolicExpr::Select(flag_id, lhs, rhs) => {
397 if flags[*flag_id] {
398 lhs.evaluate_bigint(inputs, variables, flags)
399 } else {
400 rhs.evaluate_bigint(inputs, variables, flags)
401 }
402 }
403 SymbolicExpr::Div(_, _) => unreachable!(), }
405 }
406
407 pub fn evaluate_overflow_isize(
410 &self,
411 inputs: &[OverflowInt<isize>],
412 variables: &[OverflowInt<isize>],
413 constants: &[OverflowInt<isize>],
414 flags: &[bool],
415 ) -> OverflowInt<isize> {
416 match self {
417 SymbolicExpr::IntAdd(lhs, s) => {
418 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
419 left.int_add(*s, identity)
420 }
421 SymbolicExpr::IntMul(lhs, s) => {
422 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
423 left.int_mul(*s, identity)
424 }
425 SymbolicExpr::Input(i) => inputs[*i].clone(),
426 SymbolicExpr::Var(i) => variables[*i].clone(),
427 SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
428 SymbolicExpr::Add(lhs, rhs) => {
429 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
430 + rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
431 }
432 SymbolicExpr::Sub(lhs, rhs) => {
433 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
434 - rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
435 }
436 SymbolicExpr::Mul(lhs, rhs) => {
437 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
438 * rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
439 }
440 SymbolicExpr::Select(flag_id, lhs, rhs) => {
441 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
442 let right = rhs.evaluate_overflow_isize(inputs, variables, constants, flags);
443 let num_limbs = max(left.num_limbs(), right.num_limbs());
444
445 let res = if flags[*flag_id] {
446 left.limbs().to_vec()
447 } else {
448 right.limbs().to_vec()
449 };
450 let res = res.into_iter().chain(repeat(0)).take(num_limbs).collect();
451
452 OverflowInt::from_computed_limbs(
453 res,
454 max(left.limb_max_abs(), right.limb_max_abs()),
455 max(left.max_overflow_bits(), right.max_overflow_bits()),
456 )
457 }
458 SymbolicExpr::Div(_, _) => unreachable!(), }
460 }
461
462 fn isize_to_expr<AB: AirBuilder>(s: isize) -> AB::Expr {
463 if s >= 0 {
464 AB::Expr::from_usize(s as usize)
465 } else {
466 -AB::Expr::from_usize(s.unsigned_abs())
467 }
468 }
469
470 pub fn evaluate_overflow_expr<AB: AirBuilder>(
473 &self,
474 inputs: &[OverflowInt<AB::Expr>],
475 variables: &[OverflowInt<AB::Expr>],
476 constants: &[OverflowInt<AB::Expr>],
477 flags: &[AB::Var],
478 ) -> OverflowInt<AB::Expr> {
479 match self {
480 SymbolicExpr::IntAdd(lhs, s) => {
481 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
482 left.int_add(*s, Self::isize_to_expr::<AB>)
483 }
484 SymbolicExpr::IntMul(lhs, s) => {
485 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
486 left.int_mul(*s, Self::isize_to_expr::<AB>)
487 }
488 SymbolicExpr::Input(i) => inputs[*i].clone(),
489 SymbolicExpr::Var(i) => variables[*i].clone(),
490 SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
491 SymbolicExpr::Add(lhs, rhs) => {
492 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
493 + rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
494 }
495 SymbolicExpr::Sub(lhs, rhs) => {
496 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
497 - rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
498 }
499 SymbolicExpr::Mul(lhs, rhs) => {
500 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
501 * rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
502 }
503 SymbolicExpr::Select(flag_id, lhs, rhs) => {
504 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
505 let right = rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
506 let num_limbs = max(left.num_limbs(), right.num_limbs());
507 let flag = &flags[*flag_id];
508 let mut res = vec![];
509 for i in 0..num_limbs {
510 res.push(
511 (if i < left.num_limbs() {
512 left.limb(i).clone()
513 } else {
514 AB::Expr::ZERO
515 }) * flag.clone()
516 + (if i < right.num_limbs() {
517 right.limb(i).clone()
518 } else {
519 AB::Expr::ZERO
520 }) * (AB::Expr::ONE - flag.clone()),
521 );
522 }
523 OverflowInt::from_computed_limbs(
524 res,
525 max(left.limb_max_abs(), right.limb_max_abs()),
526 max(left.max_overflow_bits(), right.max_overflow_bits()),
527 )
528 }
529 SymbolicExpr::Div(_, _) => unreachable!(), }
531 }
532
533 pub fn compute(
537 &self,
538 inputs: &[BigUint],
539 variables: &[BigUint],
540 flags: &[bool],
541 prime: &BigUint,
542 ) -> BigUint {
543 let res = match self {
544 SymbolicExpr::Input(i) => inputs[*i].clone() % prime,
545 SymbolicExpr::Var(i) => variables[*i].clone(),
546 SymbolicExpr::Const(_, val, _) => val.clone(),
547 SymbolicExpr::Add(lhs, rhs) => {
548 (lhs.compute(inputs, variables, flags, prime)
549 + rhs.compute(inputs, variables, flags, prime))
550 % prime
551 }
552 SymbolicExpr::Sub(lhs, rhs) => {
553 (prime + lhs.compute(inputs, variables, flags, prime)
554 - rhs.compute(inputs, variables, flags, prime))
555 % prime
556 }
557 SymbolicExpr::Mul(lhs, rhs) => {
558 (lhs.compute(inputs, variables, flags, prime)
559 * rhs.compute(inputs, variables, flags, prime))
560 % prime
561 }
562 SymbolicExpr::Div(lhs, rhs) => {
563 let left = lhs.compute(inputs, variables, flags, prime);
564 let right = rhs.compute(inputs, variables, flags, prime);
565 let right_inv = right.modinv(prime).unwrap();
566 (left * right_inv) % prime
567 }
568 SymbolicExpr::IntAdd(lhs, s) => {
569 let left = lhs.compute(inputs, variables, flags, prime);
570 let right = if *s >= 0 {
571 BigUint::from_usize(*s as usize).unwrap()
572 } else {
573 prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
574 };
575 (left + right) % prime
576 }
577 SymbolicExpr::IntMul(lhs, s) => {
578 let left = lhs.compute(inputs, variables, flags, prime);
579 let right = if *s >= 0 {
580 BigUint::from_usize(*s as usize).unwrap()
581 } else {
582 prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
583 };
584 (left * right) % prime
585 }
586 SymbolicExpr::Select(flag_id, lhs, rhs) => {
587 if flags[*flag_id] {
588 lhs.compute(inputs, variables, flags, prime)
589 } else {
590 rhs.compute(inputs, variables, flags, prime)
591 }
592 }
593 };
594 assert!(
595 res < prime.clone(),
596 "symbolic expr: {self} evaluation exceeds prime"
597 );
598 res
599 }
600}