// openvm_mod_circuit_builder/builder.rs

1use std::{cell::RefCell, cmp::min, iter, ops::Deref, rc::Rc};
2
3use itertools::{zip_eq, Itertools};
4use num_bigint::{BigInt, BigUint, Sign};
5use num_traits::{One, Zero};
6use openvm_circuit_primitives::{
7    bigint::{
8        check_carry_mod_to_zero::{CheckCarryModToZeroCols, CheckCarryModToZeroSubAir},
9        check_carry_to_zero::get_carry_max_abs_and_bits,
10        utils::*,
11        OverflowInt,
12    },
13    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
14    SubAir, TraceSubRowGenerator,
15};
16use openvm_stark_backend::{
17    interaction::InteractionBuilder,
18    p3_air::{Air, AirBuilder, BaseAir},
19    p3_field::{Field, FieldAlgebra, PrimeField64},
20    p3_matrix::Matrix,
21    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
22};
23
24use super::{FieldVariable, SymbolicExpr};
25
26#[derive(Clone)]
27pub struct ExprBuilderConfig {
28    pub modulus: BigUint,
29    pub num_limbs: usize,
30    pub limb_bits: usize,
31}
32
33impl ExprBuilderConfig {
34    pub fn check_valid(&self) {
35        assert!(self.modulus.bits() <= (self.num_limbs * self.limb_bits) as u64);
36    }
37}
38
/// Collects the inputs, flags, constants, variables and constraints of a modular-arithmetic
/// expression. It is finalized and wrapped into a [`FieldExpr`] AIR afterwards.
#[derive(Clone)]
pub struct ExprBuilder {
    /// The prime field modulus.
    pub prime: BigUint,
    /// Same value as `prime`, stored as a `BigInt` because computing the quotient needs it.
    pub prime_bigint: BigInt,
    /// Limb decomposition of `prime` with `limb_bits` bits per limb. It may hold fewer than
    /// `num_limbs` limbs; it is padded to `num_limbs` where needed.
    pub prime_limbs: Vec<usize>,

    /// Number of inputs allocated so far via [`ExprBuilder::new_input`].
    pub num_input: usize,
    /// Number of boolean flags allocated so far via [`ExprBuilder::new_flag`].
    pub num_flags: usize,

    /// Number of variables. This should equal the number of constraints; it is declared
    /// separately to be explicit.
    pub num_variables: usize,

    /// Constants registered via [`ExprBuilder::new_const`], as `(value, limbs)` pairs.
    pub constants: Vec<(BigUint, Vec<usize>)>,

    /// The number of bits in a canonical representation of a limb.
    pub limb_bits: usize,
    /// Number of limbs in canonical representation of the bigint field element.
    pub num_limbs: usize,
    /// `2^{num_limbs * limb_bits} - 1`: the largest value with a proper representation.
    proper_max: BigUint,
    /// The max bits that we can range check.
    pub range_checker_bits: usize,
    /// The max bits that carries are allowed to have.
    pub max_carry_bits: usize,

    /// The number of limbs of the quotient for each constraint.
    pub q_limbs: Vec<usize>,
    /// The number of limbs of the carries for each constraint.
    pub carry_limbs: Vec<usize>,

    /// The constraints that should evaluate to zero mod p (this does not include the
    /// `- p * q` part).
    pub constraints: Vec<SymbolicExpr>,

    /// The equations to compute the newly introduced variables. For trace generation only.
    pub computes: Vec<SymbolicExpr>,

    /// Indices of the variables returned by `FieldExpr::execute_with_output`.
    pub output_indices: Vec<usize>,

    /// Flag for debug mode.
    debug: bool,

    /// Whether the builder has been finalized. Only after finalize can we do
    /// `generate_subrow`, `eval`, etc.
    finalized: bool,

    /// Whether the chip has a setup opcode.
    ///
    /// Setup is a special op that verifies the modulus is correct. Some chips don't need it
    /// because the modulus is hardcoded (e.g. the pairing ones). For chips that need setup,
    /// the setup selector is derived as `setup = is_valid - sum(all_flags)`. Therefore, when
    /// the chip supports only one opcode, the user won't explicitly create a flag for it,
    /// and a default flag is created on finalizing.
    needs_setup: bool,
}
92
/// Number of bits in the BabyBear modulus.
///
/// `ExprBuilder::new` uses it to bound `max_carry_bits` so constraints don't overflow.
const MODULUS_BITS: usize = 31;
95
96impl ExprBuilder {
97    pub fn new(config: ExprBuilderConfig, range_checker_bits: usize) -> Self {
98        let prime_bigint = BigInt::from_biguint(Sign::Plus, config.modulus.clone());
99        let proper_max = (BigUint::one() << (config.num_limbs * config.limb_bits)) - BigUint::one();
100        // Max carry bits to ensure constraints don't overflow
101        let max_carry_bits = MODULUS_BITS - config.limb_bits - 2;
102        // sanity
103        assert!(config.limb_bits + 2 < MODULUS_BITS);
104        Self {
105            prime: config.modulus.clone(),
106            prime_bigint,
107            prime_limbs: big_uint_to_limbs(&config.modulus, config.limb_bits),
108            num_input: 0,
109            num_flags: 0,
110            limb_bits: config.limb_bits,
111            num_limbs: config.num_limbs,
112            proper_max,
113            range_checker_bits,
114            max_carry_bits: min(max_carry_bits, range_checker_bits),
115            num_variables: 0,
116            constants: vec![],
117            q_limbs: vec![],
118            carry_limbs: vec![],
119            constraints: vec![],
120            computes: vec![],
121            output_indices: vec![],
122            debug: false,
123            finalized: false,
124            needs_setup: false,
125        }
126    }
127
128    // This can be used to debug, when we only want to print something in a specific chip.
129    pub fn set_debug(&mut self) {
130        self.debug = true;
131    }
132
133    #[allow(unused)]
134    fn debug_print(&self, msg: &str) {
135        if self.debug {
136            println!("{}", msg);
137        }
138    }
139
140    pub fn is_finalized(&self) -> bool {
141        self.finalized
142    }
143
144    pub fn finalize(&mut self, needs_setup: bool) {
145        self.finalized = true;
146        self.needs_setup = needs_setup;
147
148        // We don't support multi-op chip that doesn't need setup right now.
149        assert!(needs_setup || self.num_flags == 0);
150
151        // setup the default flag if needed
152        if needs_setup && self.num_flags == 0 {
153            self.new_flag();
154        }
155    }
156
157    pub fn new_input(builder: Rc<RefCell<ExprBuilder>>) -> FieldVariable {
158        let mut borrowed = builder.borrow_mut();
159        let num_limbs = borrowed.num_limbs;
160        let limb_bits = borrowed.limb_bits;
161        borrowed.num_input += 1;
162        let (num_input, max_carry_bits) = (borrowed.num_input, borrowed.max_carry_bits);
163        drop(borrowed);
164        FieldVariable {
165            expr: SymbolicExpr::Input(num_input - 1),
166            builder: builder.clone(),
167            limb_max_abs: (1 << limb_bits) - 1,
168            max_overflow_bits: limb_bits,
169            expr_limbs: num_limbs,
170            max_carry_bits,
171        }
172    }
173
174    pub fn new_flag(&mut self) -> usize {
175        self.num_flags += 1;
176        self.num_flags - 1
177    }
178
179    pub fn needs_setup(&self) -> bool {
180        assert!(self.finalized); // Should only be used after finalize.
181        self.needs_setup
182    }
183
184    // Below functions are used when adding variables and constraints manually, need to be careful.
185    // Number of variables, constraints and computes should be consistent,
186    // so there should be same number of calls to the new_var, add_constraint and add_compute.
187    pub fn new_var(&mut self) -> (usize, SymbolicExpr) {
188        self.num_variables += 1;
189        // Allocate space for the new variable, to make sure they are corresponding to the same
190        // variable index.
191        self.constraints.push(SymbolicExpr::Input(0));
192        self.computes.push(SymbolicExpr::Input(0));
193        self.q_limbs.push(0);
194        self.carry_limbs.push(0);
195        (
196            self.num_variables - 1,
197            SymbolicExpr::Var(self.num_variables - 1),
198        )
199    }
200
201    /// Creates a new constant (compile-time known) FieldVariable from `value` where
202    /// the big integer `value` is decomposed into `num_limbs` limbs of `limb_bits` bits,
203    /// with `num_limbs, limb_bits` specified by the builder config.
204    pub fn new_const(builder: Rc<RefCell<ExprBuilder>>, value: BigUint) -> FieldVariable {
205        let mut borrowed = builder.borrow_mut();
206        let index = borrowed.constants.len();
207        let limb_bits = borrowed.limb_bits;
208        let num_limbs = borrowed.num_limbs;
209        let limbs = big_uint_to_num_limbs(&value, limb_bits, num_limbs);
210        let max_carry_bits = borrowed.max_carry_bits;
211        borrowed.constants.push((value.clone(), limbs));
212        drop(borrowed);
213
214        FieldVariable {
215            expr: SymbolicExpr::Const(index, value, num_limbs),
216            builder,
217            limb_max_abs: (1 << limb_bits) - 1,
218            max_overflow_bits: limb_bits,
219            expr_limbs: num_limbs,
220            max_carry_bits,
221        }
222    }
223
224    pub fn set_constraint(&mut self, index: usize, constraint: SymbolicExpr) {
225        let (q_limbs, carry_limbs) = constraint.constraint_limbs(
226            &self.prime,
227            self.limb_bits,
228            self.num_limbs,
229            &self.proper_max,
230        );
231        self.constraints[index] = constraint;
232        self.q_limbs[index] = q_limbs;
233        self.carry_limbs[index] = carry_limbs;
234    }
235
236    pub fn set_compute(&mut self, index: usize, compute: SymbolicExpr) {
237        self.computes[index] = compute;
238    }
239
240    /// Returns `proper_max = 2^{num_limbs * limb_bits} - 1` as a precomputed value.
241    /// Any proper representation of a positive big integer using `num_limbs` limbs with
242    /// `limb_bits` bits each will be `<= proper_max`.
243    pub fn proper_max(&self) -> &BigUint {
244        &self.proper_max
245    }
246}
247
/// AIR and trace sub-row generator for the expression collected by a finalized
/// [`ExprBuilder`].
#[derive(Clone)]
pub struct FieldExpr {
    /// The finalized expression builder.
    pub builder: ExprBuilder,

    /// Sub-AIR that checks each constraint is zero modulo the prime, using per-constraint
    /// quotient and carry columns.
    pub check_carry_mod_to_zero: CheckCarryModToZeroSubAir,

    /// Bus of the variable range checker used for the limb range checks.
    pub range_bus: VariableRangeCheckerBus,

    /// Any values other than the prime modulus that need to be checked at setup.
    pub setup_values: Vec<BigUint>,
}
259
260impl FieldExpr {
261    pub fn new(
262        builder: ExprBuilder,
263        range_bus: VariableRangeCheckerBus,
264        needs_setup: bool,
265    ) -> Self {
266        let mut builder = builder;
267        builder.finalize(needs_setup);
268        let subair = CheckCarryModToZeroSubAir::new(
269            builder.prime.clone(),
270            builder.limb_bits,
271            range_bus.inner.index,
272            range_bus.range_max_bits,
273        );
274        FieldExpr {
275            builder,
276            check_carry_mod_to_zero: subair,
277            range_bus,
278            setup_values: vec![],
279        }
280    }
281
282    pub fn new_with_setup_values(
283        builder: ExprBuilder,
284        range_bus: VariableRangeCheckerBus,
285        needs_setup: bool,
286        setup_values: Vec<BigUint>,
287    ) -> Self {
288        let mut ret = Self::new(builder, range_bus, needs_setup);
289        ret.setup_values = setup_values;
290        ret
291    }
292
293    pub fn num_inputs(&self) -> usize {
294        self.builder.num_input
295    }
296
297    pub fn num_vars(&self) -> usize {
298        self.builder.num_variables
299    }
300
301    pub fn num_flags(&self) -> usize {
302        self.builder.num_flags
303    }
304
305    pub fn output_indices(&self) -> &[usize] {
306        &self.builder.output_indices
307    }
308}
309
310impl Deref for FieldExpr {
311    type Target = ExprBuilder;
312
313    fn deref(&self) -> &ExprBuilder {
314        &self.builder
315    }
316}
317
318impl<F: Field> BaseAirWithPublicValues<F> for FieldExpr {}
319impl<F: Field> PartitionedBaseAir<F> for FieldExpr {}
320impl<F: Field> BaseAir<F> for FieldExpr {
321    fn width(&self) -> usize {
322        assert!(self.builder.is_finalized());
323        self.num_limbs * (self.builder.num_input + self.builder.num_variables)
324            + self.builder.q_limbs.iter().sum::<usize>()
325            + self.builder.carry_limbs.iter().sum::<usize>()
326            + self.builder.num_flags
327            + 1 // is_valid
328    }
329}
330
331impl<AB: InteractionBuilder> Air<AB> for FieldExpr {
332    fn eval(&self, builder: &mut AB) {
333        let main = builder.main();
334        let local = main.row_slice(0);
335        SubAir::eval(self, builder, &local);
336    }
337}
338
impl<AB: InteractionBuilder> SubAir<AB> for FieldExpr {
    /// The sub-row slice owned by the expression builder.
    type AirContext<'a>
        = &'a [AB::Var]
    where
        AB: 'a,
        AB::Var: 'a,
        AB::Expr: 'a;

    /// Constrains one sub-row (laid out as read by `load_vars`).
    ///
    /// Constraints added, in this order:
    /// - `is_valid` is boolean;
    /// - when setup is needed: the derived setup selector is boolean, and on setup rows the
    ///   leading input limbs equal the prime's limbs followed by the setup values' limbs;
    /// - every flag is boolean;
    /// - every constraint is checked to be zero mod p via `check_carry_mod_to_zero`;
    /// - every variable limb is range checked to `limb_bits` bits.
    ///
    /// Panics if the builder has not been finalized.
    fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
    where
        AB::Var: 'a,
        AB::Expr: 'a,
    {
        assert!(self.builder.is_finalized());
        let FieldExprCols {
            is_valid,
            inputs,
            vars,
            q_limbs,
            carry_limbs,
            flags,
        } = self.load_vars(local);

        builder.assert_bool(is_valid);

        if self.builder.needs_setup() {
            // Setup selector: is_valid - sum(flags). It is 1 exactly on valid rows with no
            // flag set.
            let is_setup = flags.iter().fold(is_valid.into(), |acc, &x| acc - x);
            builder.assert_bool(is_setup.clone());
            // TODO[jpw]: currently we enforce at the program code level that:
            // - a valid program must call the correct setup opcodes to be correct
            // - it would be better if we can constraint this in the circuit, however this has the
            //   challenge that when the same chip is used across continuation segments, only the
            //   first segment will have setup called

            // Expected setup limbs: the prime padded to `num_limbs` limbs, followed by the
            // limbs of each setup value.
            let expected = iter::empty()
                .chain({
                    let mut prime_limbs = self.builder.prime_limbs.clone();
                    prime_limbs.resize(self.builder.num_limbs, 0);
                    prime_limbs
                })
                .chain(self.setup_values.iter().flat_map(|x| {
                    big_uint_to_num_limbs(x, self.builder.limb_bits, self.builder.num_limbs)
                        .into_iter()
                }))
                .collect_vec();

            // The first `expected.len()` input limbs, flattened across inputs. `inputs` is
            // cloned because it is consumed again by `load_overflow` below.
            let reads: Vec<AB::Expr> = inputs
                .clone()
                .into_iter()
                .flatten()
                .map(Into::into)
                .take(expected.len())
                .collect();

            // zip_eq panics if the inputs hold fewer limbs than expected.
            for (lhs, rhs) in zip_eq(&reads, expected) {
                builder
                    .when(is_setup.clone())
                    .assert_eq(lhs.clone(), AB::F::from_canonical_usize(rhs));
            }
        }

        let inputs = load_overflow::<AB>(inputs, self.limb_bits);
        let vars = load_overflow::<AB>(vars, self.limb_bits);
        let constants: Vec<_> = self
            .constants
            .iter()
            .map(|(_, limbs)| {
                let limbs_expr: Vec<_> = limbs
                    .iter()
                    .map(|limb| AB::Expr::from_canonical_usize(*limb))
                    .collect();
                OverflowInt::from_canonical_unsigned_limbs(limbs_expr, self.limb_bits)
            })
            .collect();

        for flag in flags.iter() {
            builder.assert_bool(*flag);
        }
        // Each constraint i uses its own quotient and carry columns. `is_valid` is passed
        // through to the sub-AIR (presumably to enable the check only on valid rows).
        for i in 0..self.constraints.len() {
            let expr = self.constraints[i]
                .evaluate_overflow_expr::<AB>(&inputs, &vars, &constants, &flags);
            self.check_carry_mod_to_zero.eval(
                builder,
                (
                    expr,
                    CheckCarryModToZeroCols {
                        carries: carry_limbs[i].clone(),
                        quotient: q_limbs[i].clone(),
                    },
                    is_valid.into(),
                ),
            );
        }

        // Variable limbs are range checked to `limb_bits` bits. Inputs are not range checked
        // here (NOTE(review): presumably done by the caller; confirm).
        for var in vars.iter() {
            for limb in var.limbs().iter() {
                range_check(
                    builder,
                    self.range_bus.inner.index,
                    self.range_bus.range_max_bits,
                    self.limb_bits,
                    limb.clone(),
                    is_valid,
                );
            }
        }
    }
}
448
/// A list of limb groups, e.g. one inner `Vec` per input, variable or constraint.
type Vecs<T> = Vec<Vec<T>>;

/// Column view of a [`FieldExpr`] sub-row, as produced by `FieldExpr::load_vars`.
///
/// Fields are listed in the order their columns appear in the sub-row.
pub struct FieldExprCols<T> {
    /// The boolean `is_valid` column (the first column).
    pub is_valid: T,
    /// One group of `num_limbs` limbs per input.
    pub inputs: Vecs<T>,
    /// One group of `num_limbs` limbs per variable.
    pub vars: Vecs<T>,
    /// Quotient limbs: one group per constraint, sized by the builder's `q_limbs[i]`.
    pub q_limbs: Vecs<T>,
    /// Carry limbs: one group per constraint, sized by the builder's `carry_limbs[i]`.
    pub carry_limbs: Vecs<T>,
    /// One boolean column per flag (the last columns).
    pub flags: Vec<T>,
}
459
impl<F: PrimeField64> TraceSubRowGenerator<F> for FieldExpr {
    type TraceContext<'a> = (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>);
    type ColsMut<'a> = &'a mut [F];

    /// Computes all variables from `inputs` and `flags` and writes the full sub-row.
    ///
    /// The columns written are, in order: `is_valid = 1`, input limbs, variable limbs,
    /// quotient limbs, carry limbs and flags (the layout `load_vars` reads back). Also
    /// records the range-check counts for the quotient, carry and variable limbs in
    /// `range_checker`.
    ///
    /// Panics if the builder is not finalized, if the input/flag counts don't match the
    /// builder, if a quotient/carry limb-count estimate is wrong, or if `sub_row` has the
    /// wrong length (`copy_from_slice`).
    fn generate_subrow<'a>(
        &'a self,
        (range_checker, inputs, flags): (&'a VariableRangeCheckerChip, Vec<BigUint>, Vec<bool>),
        sub_row: &'a mut [F],
    ) {
        assert!(self.builder.is_finalized());
        assert_eq!(inputs.len(), self.num_input);
        assert_eq!(self.num_variables, self.constraints.len());

        assert_eq!(flags.len(), self.builder.num_flags);

        let limb_bits = self.limb_bits;
        let mut vars = vec![BigUint::zero(); self.num_variables];

        // BigInt type is required for computing the quotient.
        let input_bigint = inputs
            .iter()
            .map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
            .collect::<Vec<BigInt>>();
        let mut vars_bigint = vec![BigInt::zero(); self.num_variables];

        // OverflowInt type is required for computing the carries.
        let input_overflow = inputs
            .iter()
            .map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
            .collect::<Vec<_>>();
        let zero = OverflowInt::<isize>::from_canonical_unsigned_limbs(vec![0], limb_bits);
        let mut vars_overflow = vec![zero; self.num_variables];
        // Note: in cases where the prime fits in less limbs than `num_limbs`, we use the smaller
        // number of limbs.
        let prime_overflow = OverflowInt::<isize>::from_biguint(&self.prime, self.limb_bits, None);

        let constants: Vec<_> = self
            .constants
            .iter()
            .map(|(_, limbs)| {
                let limbs_isize: Vec<_> = limbs.iter().map(|i| *i as isize).collect();
                OverflowInt::from_canonical_unsigned_limbs(limbs_isize, self.limb_bits)
            })
            .collect();

        let mut all_q = vec![];
        let mut all_carry = vec![];
        // Compute each variable in index order. A compute only sees the variables computed
        // before it; later ones are still zero at that point.
        for i in 0..self.constraints.len() {
            let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
            vars[i] = r.clone();
            vars_bigint[i] = BigInt::from_biguint(Sign::Plus, r);
            vars_overflow[i] =
                OverflowInt::<isize>::from_biguint(&vars[i], self.limb_bits, Some(self.num_limbs));
        }
        // We need to have all variables computed first because, e.g. constraints[2] might need
        // variables[3].
        for i in 0..self.constraints.len() {
            // expr = q * p
            let expr_bigint =
                self.constraints[i].evaluate_bigint(&input_bigint, &vars_bigint, &flags);
            let q = &expr_bigint / &self.prime_bigint;
            // If this is not true then the evaluated constraint is not divisible by p.
            debug_assert_eq!(expr_bigint, &q * &self.prime_bigint);
            let q_limbs = big_int_to_num_limbs(&q, limb_bits, self.q_limbs[i]);
            assert_eq!(q_limbs.len(), self.q_limbs[i]); // If this fails, the q_limbs estimate is wrong.
            // Quotient limbs may be negative (q is a BigInt), so each one is shifted by
            // 2^limb_bits and counted as a (limb_bits + 1)-bit value.
            for &q in q_limbs.iter() {
                range_checker.add_count((q + (1 << limb_bits)) as u32, limb_bits + 1);
            }
            let q_overflow = OverflowInt::from_canonical_signed_limbs(q_limbs.clone(), limb_bits);
            // compute carries of (expr - q * p)
            let expr = self.constraints[i].evaluate_overflow_isize(
                &input_overflow,
                &vars_overflow,
                &constants,
                &flags,
            );
            let expr = expr - q_overflow * prime_overflow.clone();
            let carries = expr.calculate_carries(limb_bits);
            assert_eq!(carries.len(), self.carry_limbs[i]); // If this fails, the carry limbs estimate is wrong.
            let max_overflow_bits = expr.max_overflow_bits();
            let (carry_min_abs, carry_bits) =
                get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
            // Carries are offset by `carry_min_abs` before being counted as `carry_bits`-bit
            // values.
            for &carry in carries.iter() {
                range_checker.add_count((carry + carry_min_abs as isize) as u32, carry_bits);
            }
            all_q.push(vec_isize_to_f::<F>(q_limbs));
            all_carry.push(vec_isize_to_f::<F>(carries));
        }
        // Mirrors the `limb_bits` range checks on variable limbs in `SubAir::eval`.
        for var in vars_overflow.iter() {
            for limb in var.limbs().iter() {
                range_checker.add_count(*limb as u32, limb_bits);
            }
        }

        let input_limbs = input_overflow
            .iter()
            .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
            .collect::<Vec<_>>();
        let vars_limbs = vars_overflow
            .iter()
            .map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
            .collect::<Vec<_>>();

        // Column order must match `load_vars`.
        sub_row.copy_from_slice(
            &[
                vec![F::ONE],
                input_limbs.concat(),
                vars_limbs.concat(),
                all_q.concat(),
                all_carry.concat(),
                flags.iter().map(|x| F::from_bool(*x)).collect::<Vec<_>>(),
            ]
            .concat(),
        );
    }
}
576
577impl FieldExpr {
578    pub fn canonical_num_limbs(&self) -> usize {
579        self.builder.num_limbs
580    }
581
582    pub fn canonical_limb_bits(&self) -> usize {
583        self.builder.limb_bits
584    }
585
586    pub fn execute(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
587        assert!(self.builder.is_finalized());
588
589        #[cfg(debug_assertions)]
590        {
591            let is_setup = self.builder.needs_setup() && flags.iter().all(|&x| !x);
592            if is_setup {
593                assert_eq!(inputs[0], self.builder.prime);
594                // Check that inputs.iter().skip(1) has all the setup values as a prefix
595                assert!(inputs.len() > self.setup_values.len());
596                for (expected, actual) in self.setup_values.iter().zip(inputs.iter().skip(1)) {
597                    assert_eq!(expected, actual);
598                }
599            }
600        }
601
602        let mut vars = vec![BigUint::zero(); self.num_variables];
603        for i in 0..self.constraints.len() {
604            let r = self.computes[i].compute(&inputs, &vars, &flags, &self.prime);
605            vars[i] = r.clone();
606        }
607        vars
608    }
609
610    pub fn execute_with_output(&self, inputs: Vec<BigUint>, flags: Vec<bool>) -> Vec<BigUint> {
611        let vars = self.execute(inputs, flags);
612        self.builder
613            .output_indices
614            .iter()
615            .map(|i| vars[*i].clone())
616            .collect()
617    }
618
619    pub fn load_vars<T: Clone>(&self, arr: &[T]) -> FieldExprCols<T> {
620        assert!(self.builder.is_finalized());
621        let is_valid = arr[0].clone();
622        let mut idx = 1;
623        let mut inputs = vec![];
624        for _ in 0..self.num_input {
625            inputs.push(arr[idx..idx + self.num_limbs].to_vec());
626            idx += self.num_limbs;
627        }
628        let mut vars = vec![];
629        for _ in 0..self.num_variables {
630            vars.push(arr[idx..idx + self.num_limbs].to_vec());
631            idx += self.num_limbs;
632        }
633        let mut q_limbs = vec![];
634        for q in self.q_limbs.iter() {
635            q_limbs.push(arr[idx..idx + q].to_vec());
636            idx += q;
637        }
638        let mut carry_limbs = vec![];
639        for c in self.carry_limbs.iter() {
640            carry_limbs.push(arr[idx..idx + c].to_vec());
641            idx += c;
642        }
643        let flags = arr[idx..idx + self.num_flags].to_vec();
644        FieldExprCols {
645            is_valid,
646            inputs,
647            vars,
648            q_limbs,
649            carry_limbs,
650            flags,
651        }
652    }
653}
654
655fn load_overflow<AB: AirBuilder>(
656    arr: Vecs<AB::Var>,
657    limb_bits: usize,
658) -> Vec<OverflowInt<AB::Expr>> {
659    let mut result = vec![];
660    for x in arr.into_iter() {
661        let limbs: Vec<AB::Expr> = x.iter().map(|x| (*x).into()).collect();
662        result.push(OverflowInt::<AB::Expr>::from_canonical_unsigned_limbs(
663            limbs, limb_bits,
664        ));
665    }
666    result
667}