1use std::{
2 cmp::{max, min},
3 convert::identity,
4 iter::repeat,
5 ops::{Add, Div, Mul, Sub},
6};
7
8use num_bigint::{BigInt, BigUint, Sign};
9use num_traits::{FromPrimitive, One, Zero};
10use openvm_circuit_primitives::bigint::{
11 check_carry_to_zero::get_carry_max_abs_and_bits, OverflowInt,
12};
13use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra, p3_util::log2_ceil_usize};
14
15#[derive(Clone, Debug, PartialEq)]
19pub enum SymbolicExpr {
20 Input(usize),
21 Var(usize),
22 Const(usize, BigUint, usize), Add(Box<SymbolicExpr>, Box<SymbolicExpr>),
24 Sub(Box<SymbolicExpr>, Box<SymbolicExpr>),
25 Mul(Box<SymbolicExpr>, Box<SymbolicExpr>),
26 Div(Box<SymbolicExpr>, Box<SymbolicExpr>),
29 IntAdd(Box<SymbolicExpr>, isize),
31 IntMul(Box<SymbolicExpr>, isize),
33 Select(usize, Box<SymbolicExpr>, Box<SymbolicExpr>),
36}
37
38impl std::fmt::Display for SymbolicExpr {
39 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
40 match self {
41 SymbolicExpr::Input(i) => write!(f, "Input_{}", i),
42 SymbolicExpr::Var(i) => write!(f, "Var_{}", i),
43 SymbolicExpr::Const(i, _, _) => write!(f, "Const_{}", i),
44 SymbolicExpr::Add(lhs, rhs) => write!(f, "({} + {})", lhs, rhs),
45 SymbolicExpr::Sub(lhs, rhs) => write!(f, "({} - {})", lhs, rhs),
46 SymbolicExpr::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs),
47 SymbolicExpr::Div(lhs, rhs) => write!(f, "({} / {})", lhs, rhs),
48 SymbolicExpr::IntAdd(lhs, s) => write!(f, "({} + {})", lhs, s),
49 SymbolicExpr::IntMul(lhs, s) => write!(f, "({} x {})", lhs, s),
50 SymbolicExpr::Select(flag_id, lhs, rhs) => {
51 write!(f, "(if {} then {} else {})", flag_id, lhs, rhs)
52 }
53 }
54 }
55}
56
57impl Add for SymbolicExpr {
58 type Output = SymbolicExpr;
59
60 fn add(self, rhs: Self) -> Self::Output {
61 SymbolicExpr::Add(Box::new(self), Box::new(rhs))
62 }
63}
64
65impl Add<&SymbolicExpr> for SymbolicExpr {
66 type Output = SymbolicExpr;
67
68 fn add(self, rhs: &SymbolicExpr) -> Self::Output {
69 SymbolicExpr::Add(Box::new(self), Box::new(rhs.clone()))
70 }
71}
72
73impl Add for &SymbolicExpr {
74 type Output = SymbolicExpr;
75
76 fn add(self, rhs: &SymbolicExpr) -> Self::Output {
77 SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs.clone()))
78 }
79}
80
81impl Add<SymbolicExpr> for &SymbolicExpr {
82 type Output = SymbolicExpr;
83
84 fn add(self, rhs: SymbolicExpr) -> Self::Output {
85 SymbolicExpr::Add(Box::new(self.clone()), Box::new(rhs))
86 }
87}
88
89impl Sub for SymbolicExpr {
90 type Output = SymbolicExpr;
91
92 fn sub(self, rhs: Self) -> Self::Output {
93 SymbolicExpr::Sub(Box::new(self), Box::new(rhs))
94 }
95}
96
97impl Sub<&SymbolicExpr> for SymbolicExpr {
98 type Output = SymbolicExpr;
99
100 fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
101 SymbolicExpr::Sub(Box::new(self), Box::new(rhs.clone()))
102 }
103}
104
105impl Sub for &SymbolicExpr {
106 type Output = SymbolicExpr;
107
108 fn sub(self, rhs: &SymbolicExpr) -> Self::Output {
109 SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs.clone()))
110 }
111}
112
113impl Sub<SymbolicExpr> for &SymbolicExpr {
114 type Output = SymbolicExpr;
115
116 fn sub(self, rhs: SymbolicExpr) -> Self::Output {
117 SymbolicExpr::Sub(Box::new(self.clone()), Box::new(rhs))
118 }
119}
120
121impl Mul for SymbolicExpr {
122 type Output = SymbolicExpr;
123
124 fn mul(self, rhs: Self) -> Self::Output {
125 SymbolicExpr::Mul(Box::new(self), Box::new(rhs))
126 }
127}
128
129impl Mul<&SymbolicExpr> for SymbolicExpr {
130 type Output = SymbolicExpr;
131
132 fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
133 SymbolicExpr::Mul(Box::new(self), Box::new(rhs.clone()))
134 }
135}
136
137impl Mul for &SymbolicExpr {
138 type Output = SymbolicExpr;
139
140 fn mul(self, rhs: &SymbolicExpr) -> Self::Output {
141 SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs.clone()))
142 }
143}
144
145impl Mul<SymbolicExpr> for &SymbolicExpr {
146 type Output = SymbolicExpr;
147
148 fn mul(self, rhs: SymbolicExpr) -> Self::Output {
149 SymbolicExpr::Mul(Box::new(self.clone()), Box::new(rhs))
150 }
151}
152
153impl Div for SymbolicExpr {
155 type Output = SymbolicExpr;
156
157 fn div(self, rhs: Self) -> Self::Output {
158 SymbolicExpr::Div(Box::new(self), Box::new(rhs))
159 }
160}
161
162impl Div<&SymbolicExpr> for SymbolicExpr {
164 type Output = SymbolicExpr;
165
166 fn div(self, rhs: &SymbolicExpr) -> Self::Output {
167 SymbolicExpr::Div(Box::new(self), Box::new(rhs.clone()))
168 }
169}
170
171impl Div for &SymbolicExpr {
173 type Output = SymbolicExpr;
174
175 fn div(self, rhs: &SymbolicExpr) -> Self::Output {
176 SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs.clone()))
177 }
178}
179
180impl Div<SymbolicExpr> for &SymbolicExpr {
182 type Output = SymbolicExpr;
183
184 fn div(self, rhs: SymbolicExpr) -> Self::Output {
185 SymbolicExpr::Div(Box::new(self.clone()), Box::new(rhs))
186 }
187}
188
189impl SymbolicExpr {
190 fn max_abs(&self, proper_max: &BigUint) -> (BigUint, BigUint) {
197 match self {
198 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => (proper_max.clone(), BigUint::zero()),
199 SymbolicExpr::Const(_, val, _) => (val.clone(), BigUint::zero()),
200 SymbolicExpr::Add(lhs, rhs) => {
201 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
202 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
203 (lhs_max_pos + rhs_max_pos, lhs_max_neg + rhs_max_neg)
204 }
205 SymbolicExpr::Sub(lhs, rhs) => {
206 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
207 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
208 (lhs_max_pos + rhs_max_neg, lhs_max_neg + rhs_max_pos)
209 }
210 SymbolicExpr::Mul(lhs, rhs) => {
211 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
212 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
213 (
214 max(&lhs_max_pos * &rhs_max_pos, &lhs_max_neg * &rhs_max_neg),
215 max(&lhs_max_pos * &rhs_max_neg, &lhs_max_neg * &rhs_max_pos),
216 )
217 }
218 SymbolicExpr::Div(_, _) => {
219 unreachable!()
221 }
222 SymbolicExpr::IntAdd(lhs, s) => {
223 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
224 let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
225 (lhs_max_pos + &scalar, lhs_max_neg + &scalar)
227 }
228 SymbolicExpr::IntMul(lhs, s) => {
229 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
230 let scalar = BigUint::from_usize(s.unsigned_abs()).unwrap();
231 if *s < 0 {
232 (lhs_max_neg * &scalar, lhs_max_pos * &scalar)
233 } else {
234 (lhs_max_pos * &scalar, lhs_max_neg * &scalar)
235 }
236 }
237 SymbolicExpr::Select(_, lhs, rhs) => {
238 let (lhs_max_pos, lhs_max_neg) = lhs.max_abs(proper_max);
239 let (rhs_max_pos, rhs_max_neg) = rhs.max_abs(proper_max);
240 (max(lhs_max_pos, rhs_max_pos), max(lhs_max_neg, rhs_max_neg))
241 }
242 }
243 }
244
245 pub fn constraint_limb_max_abs(&self, limb_bits: usize, num_limbs: usize) -> usize {
250 let canonical_limb_max_abs = (1 << limb_bits) - 1;
251 match self {
252 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) | SymbolicExpr::Const(_, _, _) => {
253 canonical_limb_max_abs
254 }
255 SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
256 lhs.constraint_limb_max_abs(limb_bits, num_limbs)
257 + rhs.constraint_limb_max_abs(limb_bits, num_limbs)
258 }
259 SymbolicExpr::Mul(lhs, rhs) => {
260 let left_num_limbs = lhs.expr_limbs(num_limbs);
261 let right_num_limbs = rhs.expr_limbs(num_limbs);
262 lhs.constraint_limb_max_abs(limb_bits, num_limbs)
263 * rhs.constraint_limb_max_abs(limb_bits, num_limbs)
264 * min(left_num_limbs, right_num_limbs)
265 }
266 SymbolicExpr::IntAdd(lhs, i) => {
267 lhs.constraint_limb_max_abs(limb_bits, num_limbs) + i.unsigned_abs()
268 }
269 SymbolicExpr::IntMul(lhs, i) => {
270 lhs.constraint_limb_max_abs(limb_bits, num_limbs) * i.unsigned_abs()
271 }
272 SymbolicExpr::Select(_, lhs, rhs) => max(
273 lhs.constraint_limb_max_abs(limb_bits, num_limbs),
274 rhs.constraint_limb_max_abs(limb_bits, num_limbs),
275 ),
276 SymbolicExpr::Div(_, _) => {
277 unreachable!("should not have division when calling limb_max_abs")
278 }
279 }
280 }
281
282 pub fn constraint_carry_bits_with_pq(
287 &self,
288 prime: &BigUint,
289 limb_bits: usize,
290 num_limbs: usize,
291 proper_max: &BigUint,
292 ) -> usize {
293 let without_pq = self.constraint_limb_max_abs(limb_bits, num_limbs);
294 let (q_limbs, _) = self.constraint_limbs(prime, limb_bits, num_limbs, proper_max);
295 let canonical_limb_max_abs = (1 << limb_bits) - 1;
296 let limb_max_abs =
297 without_pq + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, num_limbs);
298 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
299 let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
300 carry_bits
301 }
302
303 pub fn expr_limbs(&self, num_limbs: usize) -> usize {
306 match self {
307 SymbolicExpr::Input(_) | SymbolicExpr::Var(_) => num_limbs,
308 SymbolicExpr::Const(_, _, limbs) => *limbs,
309 SymbolicExpr::Add(lhs, rhs) | SymbolicExpr::Sub(lhs, rhs) => {
310 max(lhs.expr_limbs(num_limbs), rhs.expr_limbs(num_limbs))
311 }
312 SymbolicExpr::Mul(lhs, rhs) => {
313 lhs.expr_limbs(num_limbs) + rhs.expr_limbs(num_limbs) - 1
314 }
315 SymbolicExpr::Div(_, _) => {
316 unimplemented!()
317 }
318 SymbolicExpr::IntAdd(lhs, _) => lhs.expr_limbs(num_limbs),
319 SymbolicExpr::IntMul(lhs, _) => lhs.expr_limbs(num_limbs),
320 SymbolicExpr::Select(_, lhs, rhs) => {
321 let left = lhs.expr_limbs(num_limbs);
322 let right = rhs.expr_limbs(num_limbs);
323 assert_eq!(left, right);
324 left
325 }
326 }
327 }
328
329 pub fn constraint_limbs(
336 &self,
337 prime: &BigUint,
338 limb_bits: usize,
339 num_limbs: usize,
340 proper_max: &BigUint,
341 ) -> (usize, usize) {
342 let (max_pos_abs, max_neg_abs) = self.max_abs(proper_max);
343 let max_abs = max(max_pos_abs, max_neg_abs);
344 let max_q_abs = (&max_abs + prime - BigUint::one()) / prime;
345 let q_bits = max_q_abs.bits() as usize;
346 let p_bits = prime.bits() as usize;
347 let q_limbs = q_bits.div_ceil(limb_bits);
348 let p_limbs = p_bits.div_ceil(limb_bits);
350 let qp_limbs = q_limbs + p_limbs - 1;
351
352 let expr_limbs = self.expr_limbs(num_limbs);
353 let carry_limbs = max(expr_limbs, qp_limbs);
354 (q_limbs, carry_limbs)
355 }
356
357 pub fn evaluate_bigint(
360 &self,
361 inputs: &[BigInt],
362 variables: &[BigInt],
363 flags: &[bool],
364 ) -> BigInt {
365 match self {
366 SymbolicExpr::IntAdd(lhs, s) => {
367 lhs.evaluate_bigint(inputs, variables, flags) + BigInt::from_isize(*s).unwrap()
368 }
369 SymbolicExpr::IntMul(lhs, s) => {
370 lhs.evaluate_bigint(inputs, variables, flags) * BigInt::from_isize(*s).unwrap()
371 }
372 SymbolicExpr::Input(i) => inputs[*i].clone(),
373 SymbolicExpr::Var(i) => variables[*i].clone(),
374 SymbolicExpr::Const(_, val, _) => {
375 if val.is_zero() {
376 BigInt::zero()
377 } else {
378 BigInt::from_biguint(Sign::Plus, val.clone())
379 }
380 }
381 SymbolicExpr::Add(lhs, rhs) => {
382 lhs.evaluate_bigint(inputs, variables, flags)
383 + rhs.evaluate_bigint(inputs, variables, flags)
384 }
385 SymbolicExpr::Sub(lhs, rhs) => {
386 lhs.evaluate_bigint(inputs, variables, flags)
387 - rhs.evaluate_bigint(inputs, variables, flags)
388 }
389 SymbolicExpr::Mul(lhs, rhs) => {
390 lhs.evaluate_bigint(inputs, variables, flags)
391 * rhs.evaluate_bigint(inputs, variables, flags)
392 }
393 SymbolicExpr::Select(flag_id, lhs, rhs) => {
394 if flags[*flag_id] {
395 lhs.evaluate_bigint(inputs, variables, flags)
396 } else {
397 rhs.evaluate_bigint(inputs, variables, flags)
398 }
399 }
400 SymbolicExpr::Div(_, _) => unreachable!(), }
402 }
403
404 pub fn evaluate_overflow_isize(
407 &self,
408 inputs: &[OverflowInt<isize>],
409 variables: &[OverflowInt<isize>],
410 constants: &[OverflowInt<isize>],
411 flags: &[bool],
412 ) -> OverflowInt<isize> {
413 match self {
414 SymbolicExpr::IntAdd(lhs, s) => {
415 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
416 left.int_add(*s, identity)
417 }
418 SymbolicExpr::IntMul(lhs, s) => {
419 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
420 left.int_mul(*s, identity)
421 }
422 SymbolicExpr::Input(i) => inputs[*i].clone(),
423 SymbolicExpr::Var(i) => variables[*i].clone(),
424 SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
425 SymbolicExpr::Add(lhs, rhs) => {
426 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
427 + rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
428 }
429 SymbolicExpr::Sub(lhs, rhs) => {
430 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
431 - rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
432 }
433 SymbolicExpr::Mul(lhs, rhs) => {
434 lhs.evaluate_overflow_isize(inputs, variables, constants, flags)
435 * rhs.evaluate_overflow_isize(inputs, variables, constants, flags)
436 }
437 SymbolicExpr::Select(flag_id, lhs, rhs) => {
438 let left = lhs.evaluate_overflow_isize(inputs, variables, constants, flags);
439 let right = rhs.evaluate_overflow_isize(inputs, variables, constants, flags);
440 let num_limbs = max(left.num_limbs(), right.num_limbs());
441
442 let res = if flags[*flag_id] {
443 left.limbs().to_vec()
444 } else {
445 right.limbs().to_vec()
446 };
447 let res = res.into_iter().chain(repeat(0)).take(num_limbs).collect();
448
449 OverflowInt::from_computed_limbs(
450 res,
451 max(left.limb_max_abs(), right.limb_max_abs()),
452 max(left.max_overflow_bits(), right.max_overflow_bits()),
453 )
454 }
455 SymbolicExpr::Div(_, _) => unreachable!(), }
457 }
458
459 fn isize_to_expr<AB: AirBuilder>(s: isize) -> AB::Expr {
460 if s >= 0 {
461 AB::Expr::from_canonical_usize(s as usize)
462 } else {
463 -AB::Expr::from_canonical_usize(s.unsigned_abs())
464 }
465 }
466
467 pub fn evaluate_overflow_expr<AB: AirBuilder>(
470 &self,
471 inputs: &[OverflowInt<AB::Expr>],
472 variables: &[OverflowInt<AB::Expr>],
473 constants: &[OverflowInt<AB::Expr>],
474 flags: &[AB::Var],
475 ) -> OverflowInt<AB::Expr> {
476 match self {
477 SymbolicExpr::IntAdd(lhs, s) => {
478 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
479 left.int_add(*s, Self::isize_to_expr::<AB>)
480 }
481 SymbolicExpr::IntMul(lhs, s) => {
482 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
483 left.int_mul(*s, Self::isize_to_expr::<AB>)
484 }
485 SymbolicExpr::Input(i) => inputs[*i].clone(),
486 SymbolicExpr::Var(i) => variables[*i].clone(),
487 SymbolicExpr::Const(i, _, _) => constants[*i].clone(),
488 SymbolicExpr::Add(lhs, rhs) => {
489 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
490 + rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
491 }
492 SymbolicExpr::Sub(lhs, rhs) => {
493 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
494 - rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
495 }
496 SymbolicExpr::Mul(lhs, rhs) => {
497 lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
498 * rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags)
499 }
500 SymbolicExpr::Select(flag_id, lhs, rhs) => {
501 let left = lhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
502 let right = rhs.evaluate_overflow_expr::<AB>(inputs, variables, constants, flags);
503 let num_limbs = max(left.num_limbs(), right.num_limbs());
504 let flag = flags[*flag_id];
505 let mut res = vec![];
506 for i in 0..num_limbs {
507 res.push(
508 (if i < left.num_limbs() {
509 left.limb(i).clone()
510 } else {
511 AB::Expr::ZERO
512 }) * flag.into()
513 + (if i < right.num_limbs() {
514 right.limb(i).clone()
515 } else {
516 AB::Expr::ZERO
517 }) * (AB::Expr::ONE - flag.into()),
518 );
519 }
520 OverflowInt::from_computed_limbs(
521 res,
522 max(left.limb_max_abs(), right.limb_max_abs()),
523 max(left.max_overflow_bits(), right.max_overflow_bits()),
524 )
525 }
526 SymbolicExpr::Div(_, _) => unreachable!(), }
528 }
529
530 pub fn compute(
534 &self,
535 inputs: &[BigUint],
536 variables: &[BigUint],
537 flags: &[bool],
538 prime: &BigUint,
539 ) -> BigUint {
540 let res = match self {
541 SymbolicExpr::Input(i) => inputs[*i].clone() % prime,
542 SymbolicExpr::Var(i) => variables[*i].clone(),
543 SymbolicExpr::Const(_, val, _) => val.clone(),
544 SymbolicExpr::Add(lhs, rhs) => {
545 (lhs.compute(inputs, variables, flags, prime)
546 + rhs.compute(inputs, variables, flags, prime))
547 % prime
548 }
549 SymbolicExpr::Sub(lhs, rhs) => {
550 (prime + lhs.compute(inputs, variables, flags, prime)
551 - rhs.compute(inputs, variables, flags, prime))
552 % prime
553 }
554 SymbolicExpr::Mul(lhs, rhs) => {
555 (lhs.compute(inputs, variables, flags, prime)
556 * rhs.compute(inputs, variables, flags, prime))
557 % prime
558 }
559 SymbolicExpr::Div(lhs, rhs) => {
560 let left = lhs.compute(inputs, variables, flags, prime);
561 let right = rhs.compute(inputs, variables, flags, prime);
562 let right_inv = right.modinv(prime).unwrap();
563 (left * right_inv) % prime
564 }
565 SymbolicExpr::IntAdd(lhs, s) => {
566 let left = lhs.compute(inputs, variables, flags, prime);
567 let right = if *s >= 0 {
568 BigUint::from_usize(*s as usize).unwrap()
569 } else {
570 prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
571 };
572 (left + right) % prime
573 }
574 SymbolicExpr::IntMul(lhs, s) => {
575 let left = lhs.compute(inputs, variables, flags, prime);
576 let right = if *s >= 0 {
577 BigUint::from_usize(*s as usize).unwrap()
578 } else {
579 prime - BigUint::from_usize(s.unsigned_abs()).unwrap()
580 };
581 (left * right) % prime
582 }
583 SymbolicExpr::Select(flag_id, lhs, rhs) => {
584 if flags[*flag_id] {
585 lhs.compute(inputs, variables, flags, prime)
586 } else {
587 rhs.compute(inputs, variables, flags, prime)
588 }
589 }
590 };
591 assert!(
592 res < prime.clone(),
593 "symbolic expr: {} evaluation exceeds prime",
594 self
595 );
596 res
597 }
598}