openvm_mod_circuit_builder/
field_variable.rs

1use std::{
2    cell::RefCell,
3    cmp::{max, min},
4    ops::{Add, Div, Mul, Sub},
5    rc::Rc,
6};
7
8use openvm_circuit_primitives::bigint::check_carry_to_zero::get_carry_max_abs_and_bits;
9use openvm_stark_backend::p3_util::log2_ceil_usize;
10
11use super::{ExprBuilder, SymbolicExpr};
12
13#[derive(Clone)]
14pub struct FieldVariable {
15    // 1. This will be "reset" to Var(n), when calling save on it.
16    // 2. This is an expression to "compute" (instead of to "constrain")
17    // But it will NOT have division, as it will be auto save and reset.
18    // For example, if we want to compute d = a * b + c, the expr here will be a * b + c
19    // So this is not a constraint that should be equal to zero (a * b + c - d is the constraint).
20    pub expr: SymbolicExpr,
21
22    pub builder: Rc<RefCell<ExprBuilder>>,
23
24    // Limb related information when evaluated as an OverflowInt (vector of limbs).
25    // Max abs of each limb.
26    pub limb_max_abs: usize,
27    // All limbs should be within [-2^max_overflow_bits, 2^max_overflow_bits)
28    // This is log2_ceil(limb_max_abs)
29    pub max_overflow_bits: usize,
30    // Number of limbs to represent the expression.
31    pub expr_limbs: usize,
32
33    // This is the same for all FieldVariable, but we might use different values at runtime,
34    // so store it here for easy configuration.
35    pub max_carry_bits: usize,
36}
37
38impl FieldVariable {
39    // Returns the index of the new variable.
40    // There should be no division in the expression.
41    /// This function is idempotent, i.e., if you already saved, then saving again does nothing.
42    pub fn save(&mut self) -> usize {
43        if let SymbolicExpr::Var(var_id) = self.expr {
44            // If self.expr is already a Var, no need to save
45            return var_id;
46        }
47        let mut builder = self.builder.borrow_mut();
48
49        // Introduce a new variable to replace self.expr.
50        let (new_var_idx, new_var) = builder.new_var();
51        // self.expr - new_var = 0
52        let new_constraint =
53            SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(new_var.clone()));
54        // limbs information.
55        builder.set_constraint(new_var_idx, new_constraint);
56        builder.set_compute(new_var_idx, self.expr.clone());
57
58        self.expr = new_var;
59        self.limb_max_abs = (1 << builder.limb_bits) - 1;
60        self.max_overflow_bits = builder.limb_bits;
61        self.expr_limbs = builder.num_limbs;
62
63        builder.num_variables - 1
64    }
65
66    pub fn save_output(&mut self) {
67        let index = self.save();
68        let mut builder = self.builder.borrow_mut();
69        builder.output_indices.push(index);
70    }
71
72    pub fn canonical_limb_bits(&self) -> usize {
73        self.builder.borrow().limb_bits
74    }
75
76    fn get_q_limbs(expr: SymbolicExpr, builder: &ExprBuilder) -> usize {
77        let constraint_expr = SymbolicExpr::Sub(
78            Box::new(expr),
79            Box::new(SymbolicExpr::Var(builder.num_variables)),
80        );
81        let (q_limbs, _) = constraint_expr.constraint_limbs(
82            &builder.prime,
83            builder.limb_bits,
84            builder.num_limbs,
85            builder.proper_max(),
86        );
87        q_limbs
88    }
89
90    fn save_if_overflow(
91        a: &mut FieldVariable, // will save this variable if overflow
92        expr: SymbolicExpr,    /* the "compute" expression of the result variable. Note that we
93                                * need to check if constraint overflows */
94        limb_max_abs: usize, // The max abs of limbs of compute expression.
95    ) {
96        if let SymbolicExpr::Var(_) = a.expr {
97            return;
98        }
99        let builder = a.builder.borrow();
100        let canonical_limb_bits = builder.limb_bits;
101        let q_limbs = FieldVariable::get_q_limbs(expr, &builder);
102        let canonical_limb_max_abs = (1 << canonical_limb_bits) - 1;
103
104        // The constraint equation is expr - new_var - qp, and we need to check if it overflows.
105        let limb_max_abs = limb_max_abs
106            + canonical_limb_max_abs  // new var
107            + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, builder.num_limbs); // qp
108        drop(builder);
109
110        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
111        let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, canonical_limb_bits);
112        if carry_bits > a.max_carry_bits {
113            a.save();
114        }
115    }
116
117    // TODO[Lun-Kai]: rethink about how should auto-save work.
118    // This implementation requires self and other to be mutable, and might actually mutate them.
119    // This might surprise the caller or introduce hard bug if the caller clone the FieldVariable
120    // and then call this.
121    pub fn add(&mut self, other: &mut FieldVariable) -> FieldVariable {
122        assert!(Rc::ptr_eq(&self.builder, &other.builder));
123        let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
124        FieldVariable::save_if_overflow(
125            self,
126            SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
127            limb_max_fn(self, other),
128        );
129        // Do again to check if the other also needs to be saved.
130        FieldVariable::save_if_overflow(
131            other,
132            SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
133            limb_max_fn(self, other),
134        );
135
136        let limb_max_abs = limb_max_fn(self, other);
137        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
138        FieldVariable {
139            expr: SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
140            builder: self.builder.clone(),
141            limb_max_abs,
142            max_overflow_bits,
143            expr_limbs: max(self.expr_limbs, other.expr_limbs),
144            max_carry_bits: self.max_carry_bits,
145        }
146    }
147
148    pub fn sub(&mut self, other: &mut FieldVariable) -> FieldVariable {
149        assert!(Rc::ptr_eq(&self.builder, &other.builder));
150        let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
151        FieldVariable::save_if_overflow(
152            self,
153            SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
154            limb_max_fn(self, other),
155        );
156        // Do again to check if the other also needs to be saved.
157        FieldVariable::save_if_overflow(
158            other,
159            SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
160            limb_max_fn(self, other),
161        );
162
163        let limb_max_abs = limb_max_fn(self, other);
164        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
165        FieldVariable {
166            expr: SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
167            builder: self.builder.clone(),
168            limb_max_abs,
169            max_overflow_bits,
170            expr_limbs: max(self.expr_limbs, other.expr_limbs),
171            max_carry_bits: self.max_carry_bits,
172        }
173    }
174
175    pub fn mul(&mut self, other: &mut FieldVariable) -> FieldVariable {
176        assert!(Rc::ptr_eq(&self.builder, &other.builder));
177        let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| {
178            a.limb_max_abs * b.limb_max_abs * min(a.expr_limbs, b.expr_limbs)
179        };
180        FieldVariable::save_if_overflow(
181            self,
182            SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
183            limb_max_fn(self, other),
184        );
185        // Do again to check if the other also needs to be saved.
186        FieldVariable::save_if_overflow(
187            other,
188            SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
189            limb_max_fn(self, other),
190        );
191
192        let limb_max_abs = limb_max_fn(self, other);
193        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
194        FieldVariable {
195            expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
196            builder: self.builder.clone(),
197            limb_max_abs,
198            max_overflow_bits,
199            expr_limbs: self.expr_limbs + other.expr_limbs - 1,
200            max_carry_bits: self.max_carry_bits,
201        }
202    }
203
204    pub fn square(&mut self) -> FieldVariable {
205        let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
206        FieldVariable::save_if_overflow(
207            self,
208            SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
209            limb_max_abs,
210        );
211
212        let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
213        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
214        FieldVariable {
215            expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
216            builder: self.builder.clone(),
217            limb_max_abs,
218            max_overflow_bits,
219            expr_limbs: self.expr_limbs * 2 - 1,
220            max_carry_bits: self.max_carry_bits,
221        }
222    }
223
224    pub fn int_add(&mut self, scalar: isize) -> FieldVariable {
225        let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
226        FieldVariable::save_if_overflow(
227            self,
228            SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
229            limb_max_abs,
230        );
231
232        let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
233        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
234        FieldVariable {
235            expr: SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
236            builder: self.builder.clone(),
237            limb_max_abs,
238            max_overflow_bits,
239            expr_limbs: self.expr_limbs,
240            max_carry_bits: self.max_carry_bits,
241        }
242    }
243
244    pub fn int_mul(&mut self, scalar: isize) -> FieldVariable {
245        let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
246        FieldVariable::save_if_overflow(
247            self,
248            SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
249            limb_max_abs,
250        );
251
252        let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
253        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
254        FieldVariable {
255            expr: SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
256            builder: self.builder.clone(),
257            limb_max_abs,
258            max_overflow_bits,
259            expr_limbs: self.expr_limbs,
260            max_carry_bits: self.max_carry_bits,
261        }
262    }
263
264    // expr cannot have division, so auto-save a new variable.
265    // Note that division by zero will panic.
266    pub fn div(&mut self, other: &mut FieldVariable) -> FieldVariable {
267        assert!(Rc::ptr_eq(&self.builder, &other.builder));
268        let builder = self.builder.borrow();
269        let prime = builder.prime.clone();
270        let limb_bits = builder.limb_bits;
271        let num_limbs = builder.num_limbs;
272        let proper_max = builder.proper_max().clone();
273        drop(builder);
274
275        // This is a dummy variable, will be replaced later so the index within it doesn't matter.
276        // We use this to check if we need to save self/other first.
277        let fake_var = SymbolicExpr::Var(0);
278
279        // Constraint: other.expr * new_var - self.expr = 0 (mod p)
280        let new_constraint = SymbolicExpr::Sub(
281            Box::new(SymbolicExpr::Mul(
282                Box::new(other.expr.clone()),
283                Box::new(fake_var.clone()),
284            )),
285            Box::new(self.expr.clone()),
286        );
287        let carry_bits =
288            new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
289        if carry_bits > self.max_carry_bits {
290            self.save();
291        }
292        // Do it again to check if other needs to be saved.
293        let new_constraint = SymbolicExpr::Sub(
294            Box::new(SymbolicExpr::Mul(
295                Box::new(other.expr.clone()),
296                Box::new(fake_var.clone()),
297            )),
298            Box::new(self.expr.clone()),
299        );
300        let carry_bits =
301            new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
302        if carry_bits > self.max_carry_bits {
303            other.save();
304        }
305
306        let mut builder = self.builder.borrow_mut();
307        let (new_var_idx, new_var) = builder.new_var();
308        let new_constraint = SymbolicExpr::Sub(
309            Box::new(SymbolicExpr::Mul(
310                Box::new(other.expr.clone()),
311                Box::new(new_var.clone()),
312            )),
313            Box::new(self.expr.clone()),
314        );
315        builder.set_constraint(new_var_idx, new_constraint);
316        // Only compute can have division.
317        let compute = SymbolicExpr::Div(Box::new(self.expr.clone()), Box::new(other.expr.clone()));
318        builder.set_compute(new_var_idx, compute);
319        drop(builder);
320
321        FieldVariable::from_var(self.builder.clone(), new_var)
322    }
323
324    pub fn from_var(builder: Rc<RefCell<ExprBuilder>>, var: SymbolicExpr) -> FieldVariable {
325        let borrowed_builder = builder.borrow();
326        let max_carry_bits = borrowed_builder.max_carry_bits;
327        assert!(
328            matches!(var, SymbolicExpr::Var(_)),
329            "Expected var to be of type SymbolicExpr::Var"
330        );
331        let num_limbs = borrowed_builder.num_limbs;
332        let canonical_limb_bits = borrowed_builder.limb_bits;
333        drop(borrowed_builder);
334        FieldVariable {
335            expr: var,
336            builder,
337            limb_max_abs: (1 << canonical_limb_bits) - 1,
338            max_overflow_bits: canonical_limb_bits,
339            expr_limbs: num_limbs,
340            max_carry_bits,
341        }
342    }
343
344    pub fn select(flag_id: usize, a: &FieldVariable, b: &FieldVariable) -> FieldVariable {
345        assert!(Rc::ptr_eq(&a.builder, &b.builder));
346        let limb_max_abs = max(a.limb_max_abs, b.limb_max_abs);
347        let max_overflow_bits = max(a.max_overflow_bits, b.max_overflow_bits);
348        let expr_limbs = max(a.expr_limbs, b.expr_limbs);
349        FieldVariable {
350            expr: SymbolicExpr::Select(flag_id, Box::new(a.expr.clone()), Box::new(b.expr.clone())),
351            builder: a.builder.clone(),
352            limb_max_abs,
353            max_overflow_bits,
354            expr_limbs,
355            max_carry_bits: a.max_carry_bits,
356        }
357    }
358}
359
360impl Add<&mut FieldVariable> for &mut FieldVariable {
361    type Output = FieldVariable;
362
363    fn add(self, rhs: &mut FieldVariable) -> Self::Output {
364        self.add(rhs)
365    }
366}
367
368impl Add<FieldVariable> for FieldVariable {
369    type Output = FieldVariable;
370
371    fn add(mut self, mut rhs: FieldVariable) -> Self::Output {
372        let x = &mut self;
373        x.add(&mut rhs)
374    }
375}
376
377impl Sub<FieldVariable> for FieldVariable {
378    type Output = FieldVariable;
379
380    fn sub(mut self, mut rhs: FieldVariable) -> Self::Output {
381        let x = &mut self;
382        x.sub(&mut rhs)
383    }
384}
385
386impl Sub<&mut FieldVariable> for &mut FieldVariable {
387    type Output = FieldVariable;
388
389    fn sub(self, rhs: &mut FieldVariable) -> Self::Output {
390        self.sub(rhs)
391    }
392}
393
394impl Mul<FieldVariable> for FieldVariable {
395    type Output = FieldVariable;
396
397    fn mul(mut self, mut rhs: FieldVariable) -> Self::Output {
398        let x = &mut self;
399        x.mul(&mut rhs)
400    }
401}
402
403impl Mul<&mut FieldVariable> for &mut FieldVariable {
404    type Output = FieldVariable;
405
406    fn mul(self, rhs: &mut FieldVariable) -> Self::Output {
407        FieldVariable::mul(self, rhs)
408    }
409}
410
411// Note that division by zero will panic.
412impl Div<FieldVariable> for FieldVariable {
413    type Output = FieldVariable;
414
415    fn div(mut self, mut rhs: FieldVariable) -> Self::Output {
416        let x = &mut self;
417        x.div(&mut rhs)
418    }
419}
420
421impl Div<&mut FieldVariable> for &mut FieldVariable {
422    type Output = FieldVariable;
423
424    fn div(self, rhs: &mut FieldVariable) -> Self::Output {
425        FieldVariable::div(self, rhs)
426    }
427}