openvm_mod_circuit_builder/
builder.rs

1use std::{cell::RefCell, cmp::min, iter, ops::Deref, rc::Rc};
2
3use itertools::{zip_eq, Itertools};
4use num_bigint::{BigInt, BigUint, Sign};
5use num_traits::{One, Zero};
6use openvm_circuit_primitives::{
7    bigint::{
8        check_carry_mod_to_zero::{CheckCarryModToZeroCols, CheckCarryModToZeroSubAir},
9        check_carry_to_zero::get_carry_max_abs_and_bits,
10        utils::*,
11        OverflowInt,
12    },
13    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
14    SubAir, TraceSubRowGenerator,
15};
16use openvm_stark_backend::{
17    interaction::InteractionBuilder,
18    p3_air::{Air, AirBuilder, BaseAir},
19    p3_field::{Field, FieldAlgebra, PrimeField64},
20    p3_matrix::Matrix,
21    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
22};
23
24use super::{FieldVariable, SymbolicExpr};
25
26#[derive(Clone)]
27pub struct ExprBuilderConfig {
28    pub modulus: BigUint,
29    pub num_limbs: usize,
30    pub limb_bits: usize,
31}
32
33impl ExprBuilderConfig {
34    pub fn check_valid(&self) {
35        assert!(self.modulus.bits() <= (self.num_limbs * self.limb_bits) as u64);
36    }
37}
38
39#[derive(Clone)]
40pub struct ExprBuilder {
41    // The prime field.
42    pub prime: BigUint,
43    // Same value, but we need BigInt for computing the quotient.
44    pub prime_bigint: BigInt,
45    pub prime_limbs: Vec<usize>,
46
47    pub num_input: usize,
48    pub num_flags: usize,
49
50    // This should be equal to number of constraints, but declare it to be explicit.
51    pub num_variables: usize,
52
53    pub constants: Vec<(BigUint, Vec<usize>)>, // value and limbs
54
55    /// The number of bits in a canonical representation of a limb.
56    pub limb_bits: usize,
57    /// Number of limbs in canonical representation of the bigint field element.
58    pub num_limbs: usize,
59    proper_max: BigUint,
60    // The max bits that we can range check.
61    pub range_checker_bits: usize,
62    // The max bits that carries are allowed to have.
63    pub max_carry_bits: usize,
64
65    // The number of limbs of the quotient for each constraint.
66    pub q_limbs: Vec<usize>,
67    // The number of limbs of the carries for each constraint.
68    pub carry_limbs: Vec<usize>,
69
70    // The constraints that should be evaluated to zero mod p (doesn't include - p * q part).
71    pub constraints: Vec<SymbolicExpr>,
72
73    // The equations to compute the newly introduced variables. For trace gen only.
74    pub computes: Vec<SymbolicExpr>,
75
76    pub output_indices: Vec<usize>,
77
78    /// flag for debug mode
79    debug: bool,
80
81    /// Whether the builder has been finalized. Only after finalize, we can do generate_subrow and eval etc.
82    finalized: bool,
83
84    // Setup opcode is a special op that verifies the modulus is correct.
85    // There are some chips that don't need it because we hardcode the modulus. E.g. the pairing ones.
86    // For those chips need setup, setup is derived: setup = is_valid - sum(all_flags)
87    // Therefore when the chip only supports one opcode, user won't explicitly create a flag for it
88    // and we will create a default flag for it on finalizing.
89    needs_setup: bool,
90}
91
/// Number of bits in the BabyBear modulus.
///
/// Used to bound limb and carry widths so that constraint arithmetic does not overflow.
const MODULUS_BITS: usize = 31;
94
95impl ExprBuilder {
96    pub fn new(config: ExprBuilderConfig, range_checker_bits: usize) -> Self {
97        let prime_bigint = BigInt::from_biguint(Sign::Plus, config.modulus.clone());
98        let proper_max = (BigUint::one() << (config.num_limbs * config.limb_bits)) - BigUint::one();
99        // Max carry bits to ensure constraints don't overflow
100        let max_carry_bits = MODULUS_BITS - config.limb_bits - 2;
101        // sanity
102        assert!(config.limb_bits + 2 < MODULUS_BITS);
103        Self {
104            prime: config.modulus.clone(),
105            prime_bigint,
106            prime_limbs: big_uint_to_limbs(&config.modulus, config.limb_bits),
107            num_input: 0,
108            num_flags: 0,
109            limb_bits: config.limb_bits,
110            num_limbs: config.num_limbs,
111            proper_max,
112            range_checker_bits,
113            max_carry_bits: min(max_carry_bits, range_checker_bits),
114            num_variables: 0,
115            constants: vec![],
116            q_limbs: vec![],
117            carry_limbs: vec![],
118            constraints: vec![],
119            computes: vec![],
120            output_indices: vec![],
121            debug: false,
122            finalized: false,
123            needs_setup: false,
124        }
125    }
126
127    // This can be used to debug, when we only want to print something in a specific chip.
128    pub fn set_debug(&mut self) {
129        self.debug = true;
130    }
131
132    #[allow(unused)]
133    fn debug_print(&self, msg: &str) {
134        if self.debug {
135            println!("{}", msg);
136        }
137    }
138
139    pub fn is_finalized(&self) -> bool {
140        self.finalized
141    }
142
143    pub fn finalize(&mut self, needs_setup: bool) {
144        self.finalized = true;
145        self.needs_setup = needs_setup;
146
147        // We don't support multi-op chip that doesn't need setup right now.
148        assert!(needs_setup || self.num_flags == 0);
149
150        // setup the default flag if needed
151        if needs_setup && self.num_flags == 0 {
152            self.new_flag();
153        }
154    }
155
156    pub fn new_input(builder: Rc<RefCell<ExprBuilder>>) -> FieldVariable {
157        let mut borrowed = builder.borrow_mut();
158        let num_limbs = borrowed.num_limbs;
159        let limb_bits = borrowed.limb_bits;
160        borrowed.num_input += 1;
161        let (num_input, max_carry_bits) = (borrowed.num_input, borrowed.max_carry_bits);
162        drop(borrowed);
163        FieldVariable {
164            expr: SymbolicExpr::Input(num_input - 1),
165            builder: builder.clone(),
166            limb_max_abs: (1 << limb_bits) - 1,
167            max_overflow_bits: limb_bits,
168            expr_limbs: num_limbs,
169            max_carry_bits,
170        }
171    }
172
173    pub fn new_flag(&mut self) -> usize {
174        self.num_flags += 1;
175        self.num_flags - 1
176    }
177
178    pub fn needs_setup(&self) -> bool {
179        assert!(self.finalized); // Should only be used after finalize.
180        self.needs_setup
181    }
182
183    // Below functions are used when adding variables and constraints manually, need to be careful.
184    // Number of variables, constraints and computes should be consistent,
185    // so there should be same number of calls to the new_var, add_constraint and add_compute.
186    pub fn new_var(&mut self) -> (usize, SymbolicExpr) {
187        self.num_variables += 1;
188        // Allocate space for the new variable, to make sure they are corresponding to the same variable index.
189        self.constraints.push(SymbolicExpr::Input(0));
190        self.computes.push(SymbolicExpr::Input(0));
191        self.q_limbs.push(0);
192        self.carry_limbs.push(0);
193        (
194            self.num_variables - 1,
195            SymbolicExpr::Var(self.num_variables - 1),
196        )
197    }
198
199    /// Creates a new constant (compile-time known) FieldVariable from `value` where
200    /// the big integer `value` is decomposed into `num_limbs` limbs of `limb_bits` bits,
201    /// with `num_limbs, limb_bits` specified by the builder config.
202    pub fn new_const(builder: Rc<RefCell<ExprBuilder>>, value: BigUint) -> FieldVariable {
203        let mut borrowed = builder.borrow_mut();
204        let index = borrowed.constants.len();
205        let limb_bits = borrowed.limb_bits;
206        let num_limbs = borrowed.num_limbs;
207        let limbs = big_uint_to_num_limbs(&value, limb_bits, num_limbs);
208        let max_carry_bits = borrowed.max_carry_bits;
209        borrowed.constants.push((value.clone(), limbs));
210        drop(borrowed);
211
212        FieldVariable {
213            expr: SymbolicExpr::Const(index, value, num_limbs),
214            builder,
215            limb_max_abs: (1 << limb_bits) - 1,
216            max_overflow_bits: limb_bits,
217            expr_limbs: num_limbs,
218            max_carry_bits,
219        }
220    }
221
222    pub fn set_constraint(&mut self, index: usize, constraint: SymbolicExpr) {
223        let (q_limbs, carry_limbs) = constraint.constraint_limbs(
224            &self.prime,
225            self.limb_bits,
226            self.num_limbs,
227            &self.proper_max,
228        );
229        self.constraints[index] = constraint;
230        self.q_limbs[index] = q_limbs;
231        self.carry_limbs[index] = carry_limbs;
232    }
233
234    pub fn set_compute(&mut self, index: usize, compute: SymbolicExpr) {
235        self.computes[index] = compute;
236    }
237
238    /// Returns `proper_max = 2^{num_limbs * limb_bits} - 1` as a precomputed value.
239    /// Any proper representation of a positive big integer using `num_limbs` limbs with
240    /// `limb_bits` bits each will be `<= proper_max`.
241    pub fn proper_max(&self) -> &BigUint {
242        &self.proper_max
243    }
244}
245
246#[derive(Clone)]
247pub struct FieldExpr {
248    pub builder: ExprBuilder,
249
250    pub check_carry_mod_to_zero: CheckCarryModToZeroSubAir,
251
252    pub range_bus: VariableRangeCheckerBus,
253
254    // any values other than the prime modulus that need to be checked at setup
255    pub setup_values: Vec<BigUint>,
256}
257
258impl FieldExpr {
259    pub fn new(
260        builder: ExprBuilder,
261        range_bus: VariableRangeCheckerBus,
262        needs_setup: bool,
263    ) -> Self {
264        let mut builder = builder;
265        builder.finalize(needs_setup);
266        let subair = CheckCarryModToZeroSubAir::new(
267            builder.prime.clone(),
268            builder.limb_bits,
269            range_bus.inner.index,
270            range_bus.range_max_bits,
271        );
272        FieldExpr {
273            builder,
274            check_carry_mod_to_zero: subair,
275            range_bus,
276            setup_values: vec![],
277        }
278    }
279
280    pub fn new_with_setup_values(
281        builder: ExprBuilder,
282        range_bus: VariableRangeCheckerBus,
283        needs_setup: bool,
284        setup_values: Vec<BigUint>,
285    ) -> Self {
286        let mut ret = Self::new(builder, range_bus, needs_setup);
287        ret.setup_values = setup_values;
288        ret
289    }
290}
291
292impl Deref for FieldExpr {
293    type Target = ExprBuilder;
294
295    fn deref(&self) -> &ExprBuilder {
296        &self.builder
297    }
298}
299
300impl<F: Field> BaseAirWithPublicValues<F> for FieldExpr {}
301impl<F: Field> PartitionedBaseAir<F> for FieldExpr {}
302impl<F: Field> BaseAir<F> for FieldExpr {
303    fn width(&self) -> usize {
304        assert!(self.builder.is_finalized());
305        self.num_limbs * (self.builder.num_input + self.builder.num_variables)
306            + self.builder.q_limbs.iter().sum::<usize>()
307            + self.builder.carry_limbs.iter().sum::<usize>()
308            + self.builder.num_flags
309            + 1 // is_valid
310    }
311}
312
313impl<AB: InteractionBuilder> Air<AB> for FieldExpr {
314    fn eval(&self, builder: &mut AB) {
315        let main = builder.main();
316        let local = main.row_slice(0);
317        SubAir::eval(self, builder, &local);
318    }
319}
320
321impl<AB: InteractionBuilder> SubAir<AB> for FieldExpr {
322    /// The sub-row slice owned by the expression builder.
323    type AirContext<'a>
324        = &'a [AB::Var]
325    where
326        AB: 'a,
327        AB::Var: 'a,
328        AB::Expr: 'a;
329
330    fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
331    where
332        AB::Var: 'a,
333        AB::Expr: 'a,
334    {
335        assert!(self.builder.is_finalized());
336        let FieldExprCols {
337            is_valid,
338            inputs,
339            vars,
340            q_limbs,
341            carry_limbs,
342            flags,
343        } = self.load_vars(local);
344
345        builder.assert_bool(is_valid);
346
347        if self.builder.needs_setup() {
348            let is_setup = flags.iter().fold(is_valid.into(), |acc, &x| acc - x);
349            builder.assert_bool(is_setup.clone());
350            // TODO[jpw]: currently we enforce at the program code level that:
351            // - a valid program must call the correct setup opcodes to be correct
352            // - it would be better if we can constraint this in the circuit,
353            //   however this has the challenge that when the same chip is used
354            //   across continuation segments,
355            //   only the first segment will have setup called
356
357            let expected = iter::empty()
358                .chain({
359                    let mut prime_limbs = self.builder.prime_limbs.clone();
360                    prime_limbs.resize(self.builder.num_limbs, 0);
361                    prime_limbs
362                })
363                .chain(self.setup_values.iter().flat_map(|x| {
364                    big_uint_to_num_limbs(x, self.builder.limb_bits, self.builder.num_limbs)
365                        .into_iter()
366                }))
367                .collect_vec();
368
369            let reads: Vec<AB::Expr> = inputs
370                .clone()
371                .into_iter()
372                .flatten()
373                .map(Into::into)
374                .take(expected.len())
375                .collect();
376
377            for (lhs, rhs) in zip_eq(&reads, expected) {
378                builder
379                    .when(is_setup.clone())
380                    .assert_eq(lhs.clone(), AB::F::from_canonical_usize(rhs));
381            }
382        }
383
384        let inputs = load_overflow::<AB>(inputs, self.limb_bits);
385        let vars = load_overflow::<AB>(vars, self.limb_bits);
386        let constants: Vec<_> = self
387            .constants
388            .iter()
389            .map(|(_, limbs)| {
390                let limbs_expr: Vec<_> = limbs
391                    .iter()
392                    .map(|limb| AB::Expr::from_canonical_usize(*limb))
393                    .collect();
394                OverflowInt::from_canonical_unsigned_limbs(limbs_expr, self.limb_bits)
395            })
396            .collect();
397
398        for flag in flags.iter() {
399            builder.assert_bool(*flag);
400        }
401        for i in 0..self.constraints.len() {
402            let expr = self.constraints[i]
403                .evaluate_overflow_expr::<AB>(&inputs, &vars, &constants, &flags);
404            self.check_carry_mod_to_zero.eval(
405                builder,
406                (
407                    expr,
408                    CheckCarryModToZeroCols {
409                        carries: carry_limbs[i].clone(),
410                        quotient: q_limbs[i].clone(),
411                    },
412                    is_valid.into(),
413                ),
414            );
415        }
416
417        for var in vars.iter() {
418            for limb in var.limbs().iter() {
419                range_check(
420                    builder,
421                    self.range_bus.inner.index,
422                    self.range_bus.range_max_bits,
423                    self.limb_bits,
424                    limb.clone(),
425                    is_valid,
426                );
427            }
428        }
429    }
430}
431
/// A list of limb vectors: one inner `Vec` per input, variable, or constraint.
type Vecs<T> = Vec<Vec<T>>;

/// Typed view of one row of a field-expression trace, in column order.
pub struct FieldExprCols<T> {
    /// Row-enable column (set to one by trace generation).
    pub is_valid: T,
    /// Limbs of each input.
    pub inputs: Vecs<T>,
    /// Limbs of each introduced variable.
    pub vars: Vecs<T>,
    /// Quotient limbs, one vector per constraint.
    pub q_limbs: Vecs<T>,
    /// Carry limbs, one vector per constraint.
    pub carry_limbs: Vecs<T>,
    /// One column per opcode flag.
    pub flags: Vec<T>,
}
442
443impl<F: PrimeField64> TraceSubRowGenerator<F> for FieldExpr {
444    type TraceContext<'a> = (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>);
445    type ColsMut<'a> = &'a mut [F];
446
447    fn generate_subrow<'a>(
448        &'a self,
449        (range_checker, inputs, flags): (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>),
450        sub_row: &'a mut [F],
451    ) {
452        assert!(self.builder.is_finalized());
453        assert_eq!(inputs.len(), self.num_input);
454        assert_eq!(self.num_variables, self.constraints.len());
455
456        assert_eq!(flags.len(), self.builder.num_flags);
457
458        let limb_bits = self.limb_bits;
459        let mut vars = vec![BigUint::zero(); self.num_variables];
460
461        // BigInt type is required for computing the quotient.
462        let input_bigint = inputs
463            .iter()
464            .map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
465            .collect::<Vec<BigInt>>();
466        let mut vars_bigint = vec![BigInt::zero(); self.num_variables];
467
468        // OverflowInt type is required for computing the carries.
469        let input_overflow = inputs
470            .iter()
471            .map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
472            .collect::<Vec<_>>();
473        let zero = OverflowInt::<isize>::from_canonical_unsigned_limbs(vec![0], limb_bits);
474        let mut vars_overflow = vec![zero; self.num_variables];
475        // Note: in cases where the prime fits in less limbs than `num_limbs`, we use the smaller number of limbs.
476        let prime_overflow = OverflowInt::<isize>::from_biguint(&self.prime, self.limb_bits, None);
477
478        let constants: Vec<_> = self
479            .constants
480            .iter()
481            .map(|(_, limbs)| {
482                let limbs_isize: Vec<_> = limbs.iter().map(|i| *i as isize).collect();
483                OverflowInt::from_canonical_unsigned_limbs(limbs_isize, self.limb_bits)
484            })
485            .collect();
486
487        let mut all_q = vec![];
488        let mut all_carry = vec![];
489        for i in 0..self.constraints.len() {
490            let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
491            vars[i] = r.clone();
492            vars_bigint[i] = BigInt::from_biguint(Sign::Plus, r);
493            vars_overflow[i] =
494                OverflowInt::<isize>::from_biguint(&vars[i], self.limb_bits, Some(self.num_limbs));
495        }
496        // We need to have all variables computed first because, e.g. constraints[2] might need variables[3].
497        for i in 0..self.constraints.len() {
498            // expr = q * p
499            let expr_bigint =
500                self.constraints[i].evaluate_bigint(&input_bigint, &vars_bigint, &flags);
501            let q = &expr_bigint / &self.prime_bigint;
502            // If this is not true then the evaluated constraint is not divisible by p.
503            debug_assert_eq!(expr_bigint, &q * &self.prime_bigint);
504            let q_limbs = big_int_to_num_limbs(&q, limb_bits, self.q_limbs[i]);
505            assert_eq!(q_limbs.len(), self.q_limbs[i]); // If this fails, the q_limbs estimate is wrong.
506            for &q in q_limbs.iter() {
507                range_checker.add_count((q + (1 << limb_bits)) as u32, limb_bits + 1);
508            }
509            let q_overflow = OverflowInt::from_canonical_signed_limbs(q_limbs.clone(), limb_bits);
510            // compute carries of (expr - q * p)
511            let expr = self.constraints[i].evaluate_overflow_isize(
512                &input_overflow,
513                &vars_overflow,
514                &constants,
515                &flags,
516            );
517            let expr = expr - q_overflow * prime_overflow.clone();
518            let carries = expr.calculate_carries(limb_bits);
519            assert_eq!(carries.len(), self.carry_limbs[i]); // If this fails, the carry limbs estimate is wrong.
520            let max_overflow_bits = expr.max_overflow_bits();
521            let (carry_min_abs, carry_bits) =
522                get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
523            for &carry in carries.iter() {
524                range_checker.add_count((carry + carry_min_abs as isize) as u32, carry_bits);
525            }
526            all_q.push(vec_isize_to_f::<F>(q_limbs));
527            all_carry.push(vec_isize_to_f::<F>(carries));
528        }
529        for var in vars_overflow.iter() {
530            for limb in var.limbs().iter() {
531                range_checker.add_count(*limb as u32, limb_bits);
532            }
533        }
534
535        let input_limbs = input_overflow
536            .iter()
537            .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
538            .collect::<Vec<_>>();
539        let vars_limbs = vars_overflow
540            .iter()
541            .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
542            .collect::<Vec<_>>();
543
544        sub_row.copy_from_slice(
545            &[
546                vec![F::ONE],
547                input_limbs.concat(),
548                vars_limbs.concat(),
549                all_q.concat(),
550                all_carry.concat(),
551                flags.iter().map(|x| F::from_bool(*x)).collect::<Vec<_>>(),
552            ]
553            .concat(),
554        );
555    }
556}
557
558impl FieldExpr {
559    pub fn canonical_num_limbs(&self) -> usize {
560        self.builder.num_limbs
561    }
562
563    pub fn canonical_limb_bits(&self) -> usize {
564        self.builder.limb_bits
565    }
566
567    pub fn execute(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
568        assert!(self.builder.is_finalized());
569
570        #[cfg(debug_assertions)]
571        {
572            let is_setup = self.builder.needs_setup() && flags.iter().all(|&x| !x);
573            if is_setup {
574                assert_eq!(inputs[0], self.builder.prime);
575                // Check that inputs.iter().skip(1) has all the setup values as a prefix
576                assert!(inputs.len() > self.setup_values.len());
577                for (expected, actual) in self.setup_values.iter().zip(inputs.iter().skip(1)) {
578                    assert_eq!(expected, actual);
579                }
580            }
581        }
582
583        let mut vars = vec![BigUint::zero(); self.num_variables];
584        for i in 0..self.constraints.len() {
585            let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
586            vars[i] = r.clone();
587        }
588        vars
589    }
590
591    pub fn execute_with_output(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
592        let vars = self.execute(inputs, flags);
593        self.builder
594            .output_indices
595            .iter()
596            .map(|i| vars[*i].clone())
597            .collect()
598    }
599
600    pub fn load_vars<T: Clone>(&self, arr: &[T]) -> FieldExprCols<T> {
601        assert!(self.builder.is_finalized());
602        let is_valid = arr[0].clone();
603        let mut idx = 1;
604        let mut inputs = vec![];
605        for _ in 0..self.num_input {
606            inputs.push(arr[idx..idx + self.num_limbs].to_vec());
607            idx += self.num_limbs;
608        }
609        let mut vars = vec![];
610        for _ in 0..self.num_variables {
611            vars.push(arr[idx..idx + self.num_limbs].to_vec());
612            idx += self.num_limbs;
613        }
614        let mut q_limbs = vec![];
615        for q in self.q_limbs.iter() {
616            q_limbs.push(arr[idx..idx + q].to_vec());
617            idx += q;
618        }
619        let mut carry_limbs = vec![];
620        for c in self.carry_limbs.iter() {
621            carry_limbs.push(arr[idx..idx + c].to_vec());
622            idx += c;
623        }
624        let flags = arr[idx..idx + self.num_flags].to_vec();
625        FieldExprCols {
626            is_valid,
627            inputs,
628            vars,
629            q_limbs,
630            carry_limbs,
631            flags,
632        }
633    }
634}
635
636fn load_overflow<AB: AirBuilder>(
637    arr: Vecs<AB::Var>,
638    limb_bits: usize,
639) -> Vec<OverflowInt<AB::Expr>> {
640    let mut result = vec![];
641    for x in arr.into_iter() {
642        let limbs: Vec<AB::Expr> = x.iter().map(|x| (*x).into()).collect();
643        result.push(OverflowInt::<AB::Expr>::from_canonical_unsigned_limbs(
644            limbs, limb_bits,
645        ));
646    }
647    result
648}