1use std::{cell::RefCell, cmp::min, iter, ops::Deref, rc::Rc};
2
3use itertools::{zip_eq, Itertools};
4use num_bigint::{BigInt, BigUint, Sign};
5use num_traits::{One, Zero};
6use openvm_circuit_primitives::{
7 bigint::{
8 check_carry_mod_to_zero::{CheckCarryModToZeroCols, CheckCarryModToZeroSubAir},
9 check_carry_to_zero::get_carry_max_abs_and_bits,
10 utils::*,
11 OverflowInt,
12 },
13 var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
14 SubAir, TraceSubRowGenerator,
15};
16use openvm_stark_backend::{
17 interaction::InteractionBuilder,
18 p3_air::{Air, AirBuilder, BaseAir},
19 p3_field::{Field, FieldAlgebra, PrimeField64},
20 p3_matrix::Matrix,
21 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
22};
23
24use super::{FieldVariable, SymbolicExpr};
25
26#[derive(Clone)]
27pub struct ExprBuilderConfig {
28 pub modulus: BigUint,
29 pub num_limbs: usize,
30 pub limb_bits: usize,
31}
32
33impl ExprBuilderConfig {
34 pub fn check_valid(&self) {
35 assert!(self.modulus.bits() <= (self.num_limbs * self.limb_bits) as u64);
36 }
37}
38
39#[derive(Clone)]
40pub struct ExprBuilder {
41 pub prime: BigUint,
43 pub prime_bigint: BigInt,
45 pub prime_limbs: Vec<usize>,
46
47 pub num_input: usize,
48 pub num_flags: usize,
49
50 pub num_variables: usize,
52
53 pub constants: Vec<(BigUint, Vec<usize>)>, pub limb_bits: usize,
57 pub num_limbs: usize,
59 proper_max: BigUint,
60 pub range_checker_bits: usize,
62 pub max_carry_bits: usize,
64
65 pub q_limbs: Vec<usize>,
67 pub carry_limbs: Vec<usize>,
69
70 pub constraints: Vec<SymbolicExpr>,
72
73 pub computes: Vec<SymbolicExpr>,
75
76 pub output_indices: Vec<usize>,
77
78 debug: bool,
80
81 finalized: bool,
83
84 needs_setup: bool,
90}
91
/// Bit budget used by `ExprBuilder::new`. It requires `limb_bits + 2 < MODULUS_BITS` and caps
/// carries at `MODULUS_BITS - limb_bits - 2` bits.
/// NOTE(review): presumably the bit size of the native proving field's modulus; confirm.
const MODULUS_BITS: usize = 31;
94
95impl ExprBuilder {
96 pub fn new(config: ExprBuilderConfig, range_checker_bits: usize) -> Self {
97 let prime_bigint = BigInt::from_biguint(Sign::Plus, config.modulus.clone());
98 let proper_max = (BigUint::one() << (config.num_limbs * config.limb_bits)) - BigUint::one();
99 let max_carry_bits = MODULUS_BITS - config.limb_bits - 2;
101 assert!(config.limb_bits + 2 < MODULUS_BITS);
103 Self {
104 prime: config.modulus.clone(),
105 prime_bigint,
106 prime_limbs: big_uint_to_limbs(&config.modulus, config.limb_bits),
107 num_input: 0,
108 num_flags: 0,
109 limb_bits: config.limb_bits,
110 num_limbs: config.num_limbs,
111 proper_max,
112 range_checker_bits,
113 max_carry_bits: min(max_carry_bits, range_checker_bits),
114 num_variables: 0,
115 constants: vec![],
116 q_limbs: vec![],
117 carry_limbs: vec![],
118 constraints: vec![],
119 computes: vec![],
120 output_indices: vec![],
121 debug: false,
122 finalized: false,
123 needs_setup: false,
124 }
125 }
126
127 pub fn set_debug(&mut self) {
129 self.debug = true;
130 }
131
132 #[allow(unused)]
133 fn debug_print(&self, msg: &str) {
134 if self.debug {
135 println!("{}", msg);
136 }
137 }
138
139 pub fn is_finalized(&self) -> bool {
140 self.finalized
141 }
142
143 pub fn finalize(&mut self, needs_setup: bool) {
144 self.finalized = true;
145 self.needs_setup = needs_setup;
146
147 assert!(needs_setup || self.num_flags == 0);
149
150 if needs_setup && self.num_flags == 0 {
152 self.new_flag();
153 }
154 }
155
156 pub fn new_input(builder: Rc<RefCell<ExprBuilder>>) -> FieldVariable {
157 let mut borrowed = builder.borrow_mut();
158 let num_limbs = borrowed.num_limbs;
159 let limb_bits = borrowed.limb_bits;
160 borrowed.num_input += 1;
161 let (num_input, max_carry_bits) = (borrowed.num_input, borrowed.max_carry_bits);
162 drop(borrowed);
163 FieldVariable {
164 expr: SymbolicExpr::Input(num_input - 1),
165 builder: builder.clone(),
166 limb_max_abs: (1 << limb_bits) - 1,
167 max_overflow_bits: limb_bits,
168 expr_limbs: num_limbs,
169 max_carry_bits,
170 }
171 }
172
173 pub fn new_flag(&mut self) -> usize {
174 self.num_flags += 1;
175 self.num_flags - 1
176 }
177
178 pub fn needs_setup(&self) -> bool {
179 assert!(self.finalized); self.needs_setup
181 }
182
183 pub fn new_var(&mut self) -> (usize, SymbolicExpr) {
187 self.num_variables += 1;
188 self.constraints.push(SymbolicExpr::Input(0));
190 self.computes.push(SymbolicExpr::Input(0));
191 self.q_limbs.push(0);
192 self.carry_limbs.push(0);
193 (
194 self.num_variables - 1,
195 SymbolicExpr::Var(self.num_variables - 1),
196 )
197 }
198
199 pub fn new_const(builder: Rc<RefCell<ExprBuilder>>, value: BigUint) -> FieldVariable {
203 let mut borrowed = builder.borrow_mut();
204 let index = borrowed.constants.len();
205 let limb_bits = borrowed.limb_bits;
206 let num_limbs = borrowed.num_limbs;
207 let limbs = big_uint_to_num_limbs(&value, limb_bits, num_limbs);
208 let max_carry_bits = borrowed.max_carry_bits;
209 borrowed.constants.push((value.clone(), limbs));
210 drop(borrowed);
211
212 FieldVariable {
213 expr: SymbolicExpr::Const(index, value, num_limbs),
214 builder,
215 limb_max_abs: (1 << limb_bits) - 1,
216 max_overflow_bits: limb_bits,
217 expr_limbs: num_limbs,
218 max_carry_bits,
219 }
220 }
221
222 pub fn set_constraint(&mut self, index: usize, constraint: SymbolicExpr) {
223 let (q_limbs, carry_limbs) = constraint.constraint_limbs(
224 &self.prime,
225 self.limb_bits,
226 self.num_limbs,
227 &self.proper_max,
228 );
229 self.constraints[index] = constraint;
230 self.q_limbs[index] = q_limbs;
231 self.carry_limbs[index] = carry_limbs;
232 }
233
234 pub fn set_compute(&mut self, index: usize, compute: SymbolicExpr) {
235 self.computes[index] = compute;
236 }
237
238 pub fn proper_max(&self) -> &BigUint {
242 &self.proper_max
243 }
244}
245
246#[derive(Clone)]
247pub struct FieldExpr {
248 pub builder: ExprBuilder,
249
250 pub check_carry_mod_to_zero: CheckCarryModToZeroSubAir,
251
252 pub range_bus: VariableRangeCheckerBus,
253
254 pub setup_values: Vec<BigUint>,
256}
257
258impl FieldExpr {
259 pub fn new(
260 builder: ExprBuilder,
261 range_bus: VariableRangeCheckerBus,
262 needs_setup: bool,
263 ) -> Self {
264 let mut builder = builder;
265 builder.finalize(needs_setup);
266 let subair = CheckCarryModToZeroSubAir::new(
267 builder.prime.clone(),
268 builder.limb_bits,
269 range_bus.inner.index,
270 range_bus.range_max_bits,
271 );
272 FieldExpr {
273 builder,
274 check_carry_mod_to_zero: subair,
275 range_bus,
276 setup_values: vec![],
277 }
278 }
279
280 pub fn new_with_setup_values(
281 builder: ExprBuilder,
282 range_bus: VariableRangeCheckerBus,
283 needs_setup: bool,
284 setup_values: Vec<BigUint>,
285 ) -> Self {
286 let mut ret = Self::new(builder, range_bus, needs_setup);
287 ret.setup_values = setup_values;
288 ret
289 }
290}
291
292impl Deref for FieldExpr {
293 type Target = ExprBuilder;
294
295 fn deref(&self) -> &ExprBuilder {
296 &self.builder
297 }
298}
299
300impl<F: Field> BaseAirWithPublicValues<F> for FieldExpr {}
301impl<F: Field> PartitionedBaseAir<F> for FieldExpr {}
302impl<F: Field> BaseAir<F> for FieldExpr {
303 fn width(&self) -> usize {
304 assert!(self.builder.is_finalized());
305 self.num_limbs * (self.builder.num_input + self.builder.num_variables)
306 + self.builder.q_limbs.iter().sum::<usize>()
307 + self.builder.carry_limbs.iter().sum::<usize>()
308 + self.builder.num_flags
309 + 1 }
311}
312
313impl<AB: InteractionBuilder> Air<AB> for FieldExpr {
314 fn eval(&self, builder: &mut AB) {
315 let main = builder.main();
316 let local = main.row_slice(0);
317 SubAir::eval(self, builder, &local);
318 }
319}
320
/// Row constraints for a `FieldExpr`. The context is one row of main-trace cells, in the
/// column order decoded by `FieldExpr::load_vars`.
impl<AB: InteractionBuilder> SubAir<AB> for FieldExpr {
    type AirContext<'a>
        = &'a [AB::Var]
    where
        AB: 'a,
        AB::Var: 'a,
        AB::Expr: 'a;

    /// Emits all constraints for one row:
    /// 1. `is_valid` is boolean.
    /// 2. If the builder needs setup, `is_setup = is_valid - sum(flags)` is boolean. On a setup
    ///    row, the leading input limbs must equal the prime's limbs (padded to `num_limbs`),
    ///    followed by the limbs of each entry of `setup_values`.
    /// 3. Every flag is boolean.
    /// 4. Each constraint is evaluated over inputs/vars/constants/flags and passed to
    ///    `check_carry_mod_to_zero`, together with its carry and quotient columns.
    ///    This check is enabled by `is_valid`.
    /// 5. Every limb of every variable is passed to `range_check` with `limb_bits`, gated by
    ///    `is_valid`.
    ///
    /// # Panics
    ///
    /// Panics if the builder is not finalized. Also panics (via `zip_eq`) if the inputs provide
    /// fewer limbs than there are expected setup values.
    fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
    where
        AB::Var: 'a,
        AB::Expr: 'a,
    {
        assert!(self.builder.is_finalized());
        let FieldExprCols {
            is_valid,
            inputs,
            vars,
            q_limbs,
            carry_limbs,
            flags,
        } = self.load_vars(local);

        builder.assert_bool(is_valid);

        if self.builder.needs_setup() {
            // A valid row with every flag zero is the setup row.
            let is_setup = flags.iter().fold(is_valid.into(), |acc, &x| acc - x);
            builder.assert_bool(is_setup.clone());
            // Expected setup-row input limbs: the prime (padded to num_limbs), then each setup value.
            let expected = iter::empty()
                .chain({
                    let mut prime_limbs = self.builder.prime_limbs.clone();
                    prime_limbs.resize(self.builder.num_limbs, 0);
                    prime_limbs
                })
                .chain(self.setup_values.iter().flat_map(|x| {
                    big_uint_to_num_limbs(x, self.builder.limb_bits, self.builder.num_limbs)
                        .into_iter()
                }))
                .collect_vec();

            // Only the leading input limbs are compared. `inputs` is cloned because it is
            // consumed by `load_overflow` below.
            let reads: Vec<AB::Expr> = inputs
                .clone()
                .into_iter()
                .flatten()
                .map(Into::into)
                .take(expected.len())
                .collect();

            for (lhs, rhs) in zip_eq(&reads, expected) {
                builder
                    .when(is_setup.clone())
                    .assert_eq(lhs.clone(), AB::F::from_canonical_usize(rhs));
            }
        }

        let inputs = load_overflow::<AB>(inputs, self.limb_bits);
        let vars = load_overflow::<AB>(vars, self.limb_bits);
        // Constants are baked into the constraints as fixed limb values.
        let constants: Vec<_> = self
            .constants
            .iter()
            .map(|(_, limbs)| {
                let limbs_expr: Vec<_> = limbs
                    .iter()
                    .map(|limb| AB::Expr::from_canonical_usize(*limb))
                    .collect();
                OverflowInt::from_canonical_unsigned_limbs(limbs_expr, self.limb_bits)
            })
            .collect();

        for flag in flags.iter() {
            builder.assert_bool(*flag);
        }
        for i in 0..self.constraints.len() {
            let expr = self.constraints[i]
                .evaluate_overflow_expr::<AB>(&inputs, &vars, &constants, &flags);
            self.check_carry_mod_to_zero.eval(
                builder,
                (
                    expr,
                    CheckCarryModToZeroCols {
                        carries: carry_limbs[i].clone(),
                        quotient: q_limbs[i].clone(),
                    },
                    is_valid.into(),
                ),
            );
        }

        // Variables are witness values, so each limb is range-checked to limb_bits.
        for var in vars.iter() {
            for limb in var.limbs().iter() {
                range_check(
                    builder,
                    self.range_bus.inner.index,
                    self.range_bus.range_max_bits,
                    self.limb_bits,
                    limb.clone(),
                    is_valid,
                );
            }
        }
    }
}
431
/// A list of per-item limb columns (one inner `Vec` per input, variable, or constraint).
type Vecs<T> = Vec<Vec<T>>;

/// The cells of one `FieldExpr` row, grouped as decoded by `FieldExpr::load_vars`.
pub struct FieldExprCols<T> {
    /// Column 0. Asserted boolean in the AIR and set to one by trace generation.
    pub is_valid: T,
    /// `num_limbs` limbs for each input.
    pub inputs: Vecs<T>,
    /// `num_limbs` limbs for each variable.
    pub vars: Vecs<T>,
    /// Quotient limbs for each constraint (`builder.q_limbs[i]` cells each).
    pub q_limbs: Vecs<T>,
    /// Carry limbs for each constraint (`builder.carry_limbs[i]` cells each).
    pub carry_limbs: Vecs<T>,
    /// One cell per flag (`builder.num_flags` in total).
    pub flags: Vec<T>,
}
442
impl<F: PrimeField64> TraceSubRowGenerator<F> for FieldExpr {
    /// The context is a tuple of:
    /// - the range checker that receives lookup counts,
    /// - the input values,
    /// - the flag values.
    type TraceContext<'a> = (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>);
    type ColsMut<'a> = &'a mut [F];

    /// Fills `sub_row` with one row of `FieldExpr` trace: `is_valid = 1`, input limbs, computed
    /// variable limbs, quotient limbs, carry limbs, then flags.
    ///
    /// It also adds range-checker counts for every quotient limb, carry, and variable limb.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the builder is not finalized;
    /// - the input or flag counts don't match the builder;
    /// - a quotient or carry vector has an unexpected length;
    /// - `sub_row.len()` differs from the assembled row length.
    fn generate_subrow<'a>(
        &'a self,
        (range_checker, inputs, flags): (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>),
        sub_row: &'a mut [F],
    ) {
        assert!(self.builder.is_finalized());
        assert_eq!(inputs.len(), self.num_input);
        assert_eq!(self.num_variables, self.constraints.len());

        assert_eq!(flags.len(), self.builder.num_flags);

        let limb_bits = self.limb_bits;
        let mut vars = vec![BigUint::zero(); self.num_variables];

        // Signed big-integer copies of the inputs, used for exact constraint evaluation.
        let input_bigint = inputs
            .iter()
            .map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
            .collect::<Vec<BigInt>>();
        let mut vars_bigint = vec![BigInt::zero(); self.num_variables];

        // Limb (OverflowInt<isize>) representations, used to compute carries.
        let input_overflow = inputs
            .iter()
            .map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
            .collect::<Vec<_>>();
        let zero = OverflowInt::<isize>::from_canonical_unsigned_limbs(vec![0], limb_bits);
        let mut vars_overflow = vec![zero; self.num_variables];
        let prime_overflow = OverflowInt::<isize>::from_biguint(&self.prime, self.limb_bits, None);

        let constants: Vec<_> = self
            .constants
            .iter()
            .map(|(_, limbs)| {
                let limbs_isize: Vec<_> = limbs.iter().map(|i| *i as isize).collect();
                OverflowInt::from_canonical_unsigned_limbs(limbs_isize, self.limb_bits)
            })
            .collect();

        let mut all_q = vec![];
        let mut all_carry = vec![];
        // First pass: compute every variable in order; a compute may read earlier `vars`.
        for i in 0..self.constraints.len() {
            let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
            vars[i] = r.clone();
            vars_bigint[i] = BigInt::from_biguint(Sign::Plus, r);
            vars_overflow[i] =
                OverflowInt::<isize>::from_biguint(&vars[i], self.limb_bits, Some(self.num_limbs));
        }
        // Second pass: for each constraint, derive the quotient and carries that witness
        // `constraint == q * prime`.
        for i in 0..self.constraints.len() {
            let expr_bigint =
                self.constraints[i].evaluate_bigint(&input_bigint, &vars_bigint, &flags);
            let q = &expr_bigint / &self.prime_bigint;
            // The constraint must evaluate to an exact multiple of the prime.
            debug_assert_eq!(expr_bigint, &q * &self.prime_bigint);
            let q_limbs = big_int_to_num_limbs(&q, limb_bits, self.q_limbs[i]);
            assert_eq!(q_limbs.len(), self.q_limbs[i]);
            // Quotient limbs are signed. Shift by 2^limb_bits and check in limb_bits + 1 bits.
            for &q in q_limbs.iter() {
                range_checker.add_count((q + (1 << limb_bits)) as u32, limb_bits + 1);
            }
            let q_overflow = OverflowInt::from_canonical_signed_limbs(q_limbs.clone(), limb_bits);
            let expr = self.constraints[i].evaluate_overflow_isize(
                &input_overflow,
                &vars_overflow,
                &constants,
                &flags,
            );
            let expr = expr - q_overflow * prime_overflow.clone();
            let carries = expr.calculate_carries(limb_bits);
            assert_eq!(carries.len(), self.carry_limbs[i]);
            let max_overflow_bits = expr.max_overflow_bits();
            // Carries are signed. Shift by carry_min_abs before recording them in carry_bits bits.
            let (carry_min_abs, carry_bits) =
                get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
            for &carry in carries.iter() {
                range_checker.add_count((carry + carry_min_abs as isize) as u32, carry_bits);
            }
            all_q.push(vec_isize_to_f::<F>(q_limbs));
            all_carry.push(vec_isize_to_f::<F>(carries));
        }
        // Matches the AIR's limb_bits range check on every variable limb.
        for var in vars_overflow.iter() {
            for limb in var.limbs().iter() {
                range_checker.add_count(*limb as u32, limb_bits);
            }
        }

        let input_limbs = input_overflow
            .iter()
            .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
            .collect::<Vec<_>>();
        let vars_limbs = vars_overflow
            .iter()
            .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
            .collect::<Vec<_>>();

        // Column order must match FieldExpr::load_vars.
        sub_row.copy_from_slice(
            &[
                vec![F::ONE],
                input_limbs.concat(),
                vars_limbs.concat(),
                all_q.concat(),
                all_carry.concat(),
                flags.iter().map(|x| F::from_bool(*x)).collect::<Vec<_>>(),
            ]
            .concat(),
        );
    }
}
557
558impl FieldExpr {
559 pub fn canonical_num_limbs(&self) -> usize {
560 self.builder.num_limbs
561 }
562
563 pub fn canonical_limb_bits(&self) -> usize {
564 self.builder.limb_bits
565 }
566
567 pub fn execute(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
568 assert!(self.builder.is_finalized());
569
570 #[cfg(debug_assertions)]
571 {
572 let is_setup = self.builder.needs_setup() && flags.iter().all(|&x| !x);
573 if is_setup {
574 assert_eq!(inputs[0], self.builder.prime);
575 assert!(inputs.len() > self.setup_values.len());
577 for (expected, actual) in self.setup_values.iter().zip(inputs.iter().skip(1)) {
578 assert_eq!(expected, actual);
579 }
580 }
581 }
582
583 let mut vars = vec![BigUint::zero(); self.num_variables];
584 for i in 0..self.constraints.len() {
585 let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
586 vars[i] = r.clone();
587 }
588 vars
589 }
590
591 pub fn execute_with_output(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
592 let vars = self.execute(inputs, flags);
593 self.builder
594 .output_indices
595 .iter()
596 .map(|i| vars[*i].clone())
597 .collect()
598 }
599
600 pub fn load_vars<T: Clone>(&self, arr: &[T]) -> FieldExprCols<T> {
601 assert!(self.builder.is_finalized());
602 let is_valid = arr[0].clone();
603 let mut idx = 1;
604 let mut inputs = vec![];
605 for _ in 0..self.num_input {
606 inputs.push(arr[idx..idx + self.num_limbs].to_vec());
607 idx += self.num_limbs;
608 }
609 let mut vars = vec![];
610 for _ in 0..self.num_variables {
611 vars.push(arr[idx..idx + self.num_limbs].to_vec());
612 idx += self.num_limbs;
613 }
614 let mut q_limbs = vec![];
615 for q in self.q_limbs.iter() {
616 q_limbs.push(arr[idx..idx + q].to_vec());
617 idx += q;
618 }
619 let mut carry_limbs = vec![];
620 for c in self.carry_limbs.iter() {
621 carry_limbs.push(arr[idx..idx + c].to_vec());
622 idx += c;
623 }
624 let flags = arr[idx..idx + self.num_flags].to_vec();
625 FieldExprCols {
626 is_valid,
627 inputs,
628 vars,
629 q_limbs,
630 carry_limbs,
631 flags,
632 }
633 }
634}
635
636fn load_overflow<AB: AirBuilder>(
637 arr: Vecs<AB::Var>,
638 limb_bits: usize,
639) -> Vec<OverflowInt<AB::Expr>> {
640 let mut result = vec![];
641 for x in arr.into_iter() {
642 let limbs: Vec<AB::Expr> = x.iter().map(|x| (*x).into()).collect();
643 result.push(OverflowInt::<AB::Expr>::from_canonical_unsigned_limbs(
644 limbs, limb_bits,
645 ));
646 }
647 result
648}