1use std::{cell::RefCell, cmp::min, iter, ops::Deref, rc::Rc};
2
3use itertools::{zip_eq, Itertools};
4use num_bigint::{BigInt, BigUint, Sign};
5use num_traits::{One, Zero};
6use openvm_circuit_primitives::{
7 bigint::{
8 check_carry_mod_to_zero::{CheckCarryModToZeroCols, CheckCarryModToZeroSubAir},
9 check_carry_to_zero::get_carry_max_abs_and_bits,
10 utils::*,
11 OverflowInt,
12 },
13 var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
14 SubAir, TraceSubRowGenerator,
15};
16use openvm_stark_backend::{
17 interaction::InteractionBuilder,
18 p3_air::{Air, AirBuilder, BaseAir},
19 p3_field::{Field, FieldAlgebra, PrimeField64},
20 p3_matrix::Matrix,
21 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
22};
23
24use super::{FieldVariable, SymbolicExpr};
25
26#[derive(Clone)]
27pub struct ExprBuilderConfig {
28 pub modulus: BigUint,
29 pub num_limbs: usize,
30 pub limb_bits: usize,
31}
32
33impl ExprBuilderConfig {
34 pub fn check_valid(&self) {
35 assert!(self.modulus.bits() <= (self.num_limbs * self.limb_bits) as u64);
36 }
37}
38
39#[derive(Clone)]
40pub struct ExprBuilder {
41 pub prime: BigUint,
43 pub prime_bigint: BigInt,
45 pub prime_limbs: Vec<usize>,
46
47 pub num_input: usize,
48 pub num_flags: usize,
49
50 pub num_variables: usize,
52
53 pub constants: Vec<(BigUint, Vec<usize>)>, pub limb_bits: usize,
57 pub num_limbs: usize,
59 proper_max: BigUint,
60 pub range_checker_bits: usize,
62 pub max_carry_bits: usize,
64
65 pub q_limbs: Vec<usize>,
67 pub carry_limbs: Vec<usize>,
69
70 pub constraints: Vec<SymbolicExpr>,
72
73 pub computes: Vec<SymbolicExpr>,
75
76 pub output_indices: Vec<usize>,
77
78 debug: bool,
80
81 finalized: bool,
84
85 needs_setup: bool,
91}
92
/// Total bit budget from which `ExprBuilder::new` derives `max_carry_bits`
/// (`MODULUS_BITS - limb_bits - 2`). NOTE(review): presumably the bit width of the native
/// trace field's prime — confirm.
const MODULUS_BITS: usize = 31;
95
96impl ExprBuilder {
97 pub fn new(config: ExprBuilderConfig, range_checker_bits: usize) -> Self {
98 let prime_bigint = BigInt::from_biguint(Sign::Plus, config.modulus.clone());
99 let proper_max = (BigUint::one() << (config.num_limbs * config.limb_bits)) - BigUint::one();
100 let max_carry_bits = MODULUS_BITS - config.limb_bits - 2;
102 assert!(config.limb_bits + 2 < MODULUS_BITS);
104 Self {
105 prime: config.modulus.clone(),
106 prime_bigint,
107 prime_limbs: big_uint_to_limbs(&config.modulus, config.limb_bits),
108 num_input: 0,
109 num_flags: 0,
110 limb_bits: config.limb_bits,
111 num_limbs: config.num_limbs,
112 proper_max,
113 range_checker_bits,
114 max_carry_bits: min(max_carry_bits, range_checker_bits),
115 num_variables: 0,
116 constants: vec![],
117 q_limbs: vec![],
118 carry_limbs: vec![],
119 constraints: vec![],
120 computes: vec![],
121 output_indices: vec![],
122 debug: false,
123 finalized: false,
124 needs_setup: false,
125 }
126 }
127
128 pub fn set_debug(&mut self) {
130 self.debug = true;
131 }
132
133 #[allow(unused)]
134 fn debug_print(&self, msg: &str) {
135 if self.debug {
136 println!("{}", msg);
137 }
138 }
139
140 pub fn is_finalized(&self) -> bool {
141 self.finalized
142 }
143
144 pub fn finalize(&mut self, needs_setup: bool) {
145 self.finalized = true;
146 self.needs_setup = needs_setup;
147
148 assert!(needs_setup || self.num_flags == 0);
150
151 if needs_setup && self.num_flags == 0 {
153 self.new_flag();
154 }
155 }
156
157 pub fn new_input(builder: Rc<RefCell<ExprBuilder>>) -> FieldVariable {
158 let mut borrowed = builder.borrow_mut();
159 let num_limbs = borrowed.num_limbs;
160 let limb_bits = borrowed.limb_bits;
161 borrowed.num_input += 1;
162 let (num_input, max_carry_bits) = (borrowed.num_input, borrowed.max_carry_bits);
163 drop(borrowed);
164 FieldVariable {
165 expr: SymbolicExpr::Input(num_input - 1),
166 builder: builder.clone(),
167 limb_max_abs: (1 << limb_bits) - 1,
168 max_overflow_bits: limb_bits,
169 expr_limbs: num_limbs,
170 max_carry_bits,
171 }
172 }
173
174 pub fn new_flag(&mut self) -> usize {
175 self.num_flags += 1;
176 self.num_flags - 1
177 }
178
179 pub fn needs_setup(&self) -> bool {
180 assert!(self.finalized); self.needs_setup
182 }
183
184 pub fn new_var(&mut self) -> (usize, SymbolicExpr) {
188 self.num_variables += 1;
189 self.constraints.push(SymbolicExpr::Input(0));
192 self.computes.push(SymbolicExpr::Input(0));
193 self.q_limbs.push(0);
194 self.carry_limbs.push(0);
195 (
196 self.num_variables - 1,
197 SymbolicExpr::Var(self.num_variables - 1),
198 )
199 }
200
201 pub fn new_const(builder: Rc<RefCell<ExprBuilder>>, value: BigUint) -> FieldVariable {
205 let mut borrowed = builder.borrow_mut();
206 let index = borrowed.constants.len();
207 let limb_bits = borrowed.limb_bits;
208 let num_limbs = borrowed.num_limbs;
209 let limbs = big_uint_to_num_limbs(&value, limb_bits, num_limbs);
210 let max_carry_bits = borrowed.max_carry_bits;
211 borrowed.constants.push((value.clone(), limbs));
212 drop(borrowed);
213
214 FieldVariable {
215 expr: SymbolicExpr::Const(index, value, num_limbs),
216 builder,
217 limb_max_abs: (1 << limb_bits) - 1,
218 max_overflow_bits: limb_bits,
219 expr_limbs: num_limbs,
220 max_carry_bits,
221 }
222 }
223
224 pub fn set_constraint(&mut self, index: usize, constraint: SymbolicExpr) {
225 let (q_limbs, carry_limbs) = constraint.constraint_limbs(
226 &self.prime,
227 self.limb_bits,
228 self.num_limbs,
229 &self.proper_max,
230 );
231 self.constraints[index] = constraint;
232 self.q_limbs[index] = q_limbs;
233 self.carry_limbs[index] = carry_limbs;
234 }
235
236 pub fn set_compute(&mut self, index: usize, compute: SymbolicExpr) {
237 self.computes[index] = compute;
238 }
239
240 pub fn proper_max(&self) -> &BigUint {
244 &self.proper_max
245 }
246}
247
248#[derive(Clone)]
249pub struct FieldExpr {
250 pub builder: ExprBuilder,
251
252 pub check_carry_mod_to_zero: CheckCarryModToZeroSubAir,
253
254 pub range_bus: VariableRangeCheckerBus,
255
256 pub setup_values: Vec<BigUint>,
258}
259
260impl FieldExpr {
261 pub fn new(
262 builder: ExprBuilder,
263 range_bus: VariableRangeCheckerBus,
264 needs_setup: bool,
265 ) -> Self {
266 let mut builder = builder;
267 builder.finalize(needs_setup);
268 let subair = CheckCarryModToZeroSubAir::new(
269 builder.prime.clone(),
270 builder.limb_bits,
271 range_bus.inner.index,
272 range_bus.range_max_bits,
273 );
274 FieldExpr {
275 builder,
276 check_carry_mod_to_zero: subair,
277 range_bus,
278 setup_values: vec![],
279 }
280 }
281
282 pub fn new_with_setup_values(
283 builder: ExprBuilder,
284 range_bus: VariableRangeCheckerBus,
285 needs_setup: bool,
286 setup_values: Vec<BigUint>,
287 ) -> Self {
288 let mut ret = Self::new(builder, range_bus, needs_setup);
289 ret.setup_values = setup_values;
290 ret
291 }
292
293 pub fn num_inputs(&self) -> usize {
294 self.builder.num_input
295 }
296
297 pub fn num_vars(&self) -> usize {
298 self.builder.num_variables
299 }
300
301 pub fn num_flags(&self) -> usize {
302 self.builder.num_flags
303 }
304
305 pub fn output_indices(&self) -> &[usize] {
306 &self.builder.output_indices
307 }
308}
309
310impl Deref for FieldExpr {
311 type Target = ExprBuilder;
312
313 fn deref(&self) -> &ExprBuilder {
314 &self.builder
315 }
316}
317
318impl<F: Field> BaseAirWithPublicValues<F> for FieldExpr {}
319impl<F: Field> PartitionedBaseAir<F> for FieldExpr {}
320impl<F: Field> BaseAir<F> for FieldExpr {
321 fn width(&self) -> usize {
322 assert!(self.builder.is_finalized());
323 self.num_limbs * (self.builder.num_input + self.builder.num_variables)
324 + self.builder.q_limbs.iter().sum::<usize>()
325 + self.builder.carry_limbs.iter().sum::<usize>()
326 + self.builder.num_flags
327 + 1 }
329}
330
331impl<AB: InteractionBuilder> Air<AB> for FieldExpr {
332 fn eval(&self, builder: &mut AB) {
333 let main = builder.main();
334 let local = main.row_slice(0);
335 SubAir::eval(self, builder, &local);
336 }
337}
338
339impl<AB: InteractionBuilder> SubAir<AB> for FieldExpr {
340 type AirContext<'a>
342 = &'a [AB::Var]
343 where
344 AB: 'a,
345 AB::Var: 'a,
346 AB::Expr: 'a;
347
348 fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
349 where
350 AB::Var: 'a,
351 AB::Expr: 'a,
352 {
353 assert!(self.builder.is_finalized());
354 let FieldExprCols {
355 is_valid,
356 inputs,
357 vars,
358 q_limbs,
359 carry_limbs,
360 flags,
361 } = self.load_vars(local);
362
363 builder.assert_bool(is_valid);
364
365 if self.builder.needs_setup() {
366 let is_setup = flags.iter().fold(is_valid.into(), |acc, &x| acc - x);
367 builder.assert_bool(is_setup.clone());
368 let expected = iter::empty()
375 .chain({
376 let mut prime_limbs = self.builder.prime_limbs.clone();
377 prime_limbs.resize(self.builder.num_limbs, 0);
378 prime_limbs
379 })
380 .chain(self.setup_values.iter().flat_map(|x| {
381 big_uint_to_num_limbs(x, self.builder.limb_bits, self.builder.num_limbs)
382 .into_iter()
383 }))
384 .collect_vec();
385
386 let reads: Vec<AB::Expr> = inputs
387 .clone()
388 .into_iter()
389 .flatten()
390 .map(Into::into)
391 .take(expected.len())
392 .collect();
393
394 for (lhs, rhs) in zip_eq(&reads, expected) {
395 builder
396 .when(is_setup.clone())
397 .assert_eq(lhs.clone(), AB::F::from_canonical_usize(rhs));
398 }
399 }
400
401 let inputs = load_overflow::<AB>(inputs, self.limb_bits);
402 let vars = load_overflow::<AB>(vars, self.limb_bits);
403 let constants: Vec<_> = self
404 .constants
405 .iter()
406 .map(|(_, limbs)| {
407 let limbs_expr: Vec<_> = limbs
408 .iter()
409 .map(|limb| AB::Expr::from_canonical_usize(*limb))
410 .collect();
411 OverflowInt::from_canonical_unsigned_limbs(limbs_expr, self.limb_bits)
412 })
413 .collect();
414
415 for flag in flags.iter() {
416 builder.assert_bool(*flag);
417 }
418 for i in 0..self.constraints.len() {
419 let expr = self.constraints[i]
420 .evaluate_overflow_expr::<AB>(&inputs, &vars, &constants, &flags);
421 self.check_carry_mod_to_zero.eval(
422 builder,
423 (
424 expr,
425 CheckCarryModToZeroCols {
426 carries: carry_limbs[i].clone(),
427 quotient: q_limbs[i].clone(),
428 },
429 is_valid.into(),
430 ),
431 );
432 }
433
434 for var in vars.iter() {
435 for limb in var.limbs().iter() {
436 range_check(
437 builder,
438 self.range_bus.inner.index,
439 self.range_bus.range_max_bits,
440 self.limb_bits,
441 limb.clone(),
442 is_valid,
443 );
444 }
445 }
446 }
447}
448
/// A list of limb vectors, one inner `Vec` per element.
type Vecs<T> = Vec<Vec<T>>;

/// One row of a [`FieldExpr`] trace, split into its column groups.
///
/// In the flat row the groups appear in this order: `is_valid`, `inputs`, `vars`, `q_limbs`,
/// `carry_limbs`, `flags`.
pub struct FieldExprCols<T> {
    /// Row-enable column.
    pub is_valid: T,
    /// Limbs of each input.
    pub inputs: Vecs<T>,
    /// Limbs of each variable.
    pub vars: Vecs<T>,
    /// Quotient limbs of each constraint.
    pub q_limbs: Vecs<T>,
    /// Carry limbs of each constraint.
    pub carry_limbs: Vecs<T>,
    /// Flag columns.
    pub flags: Vec<T>,
}
459
460impl<F: PrimeField64> TraceSubRowGenerator<F> for FieldExpr {
461 type TraceContext<'a> = (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>);
462 type ColsMut<'a> = &'a mut [F];
463
464 fn generate_subrow<'a>(
465 &'a self,
466 (range_checker, inputs, flags): (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>),
467 sub_row: &'a mut [F],
468 ) {
469 assert!(self.builder.is_finalized());
470 assert_eq!(inputs.len(), self.num_input);
471 assert_eq!(self.num_variables, self.constraints.len());
472
473 assert_eq!(flags.len(), self.builder.num_flags);
474
475 let limb_bits = self.limb_bits;
476 let mut vars = vec![BigUint::zero(); self.num_variables];
477
478 let input_bigint = inputs
480 .iter()
481 .map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
482 .collect::<Vec<BigInt>>();
483 let mut vars_bigint = vec![BigInt::zero(); self.num_variables];
484
485 let input_overflow = inputs
487 .iter()
488 .map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
489 .collect::<Vec<_>>();
490 let zero = OverflowInt::<isize>::from_canonical_unsigned_limbs(vec![0], limb_bits);
491 let mut vars_overflow = vec![zero; self.num_variables];
492 let prime_overflow = OverflowInt::<isize>::from_biguint(&self.prime, self.limb_bits, None);
495
496 let constants: Vec<_> = self
497 .constants
498 .iter()
499 .map(|(_, limbs)| {
500 let limbs_isize: Vec<_> = limbs.iter().map(|i| *i as isize).collect();
501 OverflowInt::from_canonical_unsigned_limbs(limbs_isize, self.limb_bits)
502 })
503 .collect();
504
505 let mut all_q = vec![];
506 let mut all_carry = vec![];
507 for i in 0..self.constraints.len() {
508 let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
509 vars[i] = r.clone();
510 vars_bigint[i] = BigInt::from_biguint(Sign::Plus, r);
511 vars_overflow[i] =
512 OverflowInt::<isize>::from_biguint(&vars[i], self.limb_bits, Some(self.num_limbs));
513 }
514 for i in 0..self.constraints.len() {
517 let expr_bigint =
519 self.constraints[i].evaluate_bigint(&input_bigint, &vars_bigint, &flags);
520 let q = &expr_bigint / &self.prime_bigint;
521 debug_assert_eq!(expr_bigint, &q * &self.prime_bigint);
523 let q_limbs = big_int_to_num_limbs(&q, limb_bits, self.q_limbs[i]);
524 assert_eq!(q_limbs.len(), self.q_limbs[i]); for &q in q_limbs.iter() {
526 range_checker.add_count((q + (1 << limb_bits)) as u32, limb_bits + 1);
527 }
528 let q_overflow = OverflowInt::from_canonical_signed_limbs(q_limbs.clone(), limb_bits);
529 let expr = self.constraints[i].evaluate_overflow_isize(
531 &input_overflow,
532 &vars_overflow,
533 &constants,
534 &flags,
535 );
536 let expr = expr - q_overflow * prime_overflow.clone();
537 let carries = expr.calculate_carries(limb_bits);
538 assert_eq!(carries.len(), self.carry_limbs[i]); let max_overflow_bits = expr.max_overflow_bits();
540 let (carry_min_abs, carry_bits) =
541 get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
542 for &carry in carries.iter() {
543 range_checker.add_count((carry + carry_min_abs as isize) as u32, carry_bits);
544 }
545 all_q.push(vec_isize_to_f::<F>(q_limbs));
546 all_carry.push(vec_isize_to_f::<F>(carries));
547 }
548 for var in vars_overflow.iter() {
549 for limb in var.limbs().iter() {
550 range_checker.add_count(*limb as u32, limb_bits);
551 }
552 }
553
554 let input_limbs = input_overflow
555 .iter()
556 .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
557 .collect::<Vec<_>>();
558 let vars_limbs = vars_overflow
559 .iter()
560 .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
561 .collect::<Vec<_>>();
562
563 sub_row.copy_from_slice(
564 &[
565 vec![F::ONE],
566 input_limbs.concat(),
567 vars_limbs.concat(),
568 all_q.concat(),
569 all_carry.concat(),
570 flags.iter().map(|x| F::from_bool(*x)).collect::<Vec<_>>(),
571 ]
572 .concat(),
573 );
574 }
575}
576
577impl FieldExpr {
578 pub fn canonical_num_limbs(&self) -> usize {
579 self.builder.num_limbs
580 }
581
582 pub fn canonical_limb_bits(&self) -> usize {
583 self.builder.limb_bits
584 }
585
586 pub fn execute(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
587 assert!(self.builder.is_finalized());
588
589 #[cfg(debug_assertions)]
590 {
591 let is_setup = self.builder.needs_setup() && flags.iter().all(|&x| !x);
592 if is_setup {
593 assert_eq!(inputs[0], self.builder.prime);
594 assert!(inputs.len() > self.setup_values.len());
596 for (expected, actual) in self.setup_values.iter().zip(inputs.iter().skip(1)) {
597 assert_eq!(expected, actual);
598 }
599 }
600 }
601
602 let mut vars = vec![BigUint::zero(); self.num_variables];
603 for i in 0..self.constraints.len() {
604 let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
605 vars[i] = r.clone();
606 }
607 vars
608 }
609
610 pub fn execute_with_output(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
611 let vars = self.execute(inputs, flags);
612 self.builder
613 .output_indices
614 .iter()
615 .map(|i| vars[*i].clone())
616 .collect()
617 }
618
619 pub fn load_vars<T: Clone>(&self, arr: &[T]) -> FieldExprCols<T> {
620 assert!(self.builder.is_finalized());
621 let is_valid = arr[0].clone();
622 let mut idx = 1;
623 let mut inputs = vec![];
624 for _ in 0..self.num_input {
625 inputs.push(arr[idx..idx + self.num_limbs].to_vec());
626 idx += self.num_limbs;
627 }
628 let mut vars = vec![];
629 for _ in 0..self.num_variables {
630 vars.push(arr[idx..idx + self.num_limbs].to_vec());
631 idx += self.num_limbs;
632 }
633 let mut q_limbs = vec![];
634 for q in self.q_limbs.iter() {
635 q_limbs.push(arr[idx..idx + q].to_vec());
636 idx += q;
637 }
638 let mut carry_limbs = vec![];
639 for c in self.carry_limbs.iter() {
640 carry_limbs.push(arr[idx..idx + c].to_vec());
641 idx += c;
642 }
643 let flags = arr[idx..idx + self.num_flags].to_vec();
644 FieldExprCols {
645 is_valid,
646 inputs,
647 vars,
648 q_limbs,
649 carry_limbs,
650 flags,
651 }
652 }
653}
654
655fn load_overflow<AB: AirBuilder>(
656 arr: Vecs<AB::Var>,
657 limb_bits: usize,
658) -> Vec<OverflowInt<AB::Expr>> {
659 let mut result = vec![];
660 for x in arr.into_iter() {
661 let limbs: Vec<AB::Expr> = x.iter().map(|x| (*x).into()).collect();
662 result.push(OverflowInt::<AB::Expr>::from_canonical_unsigned_limbs(
663 limbs, limb_bits,
664 ));
665 }
666 result
667}