openvm_mod_circuit_builder/field_variable.rs

1use std::{
2    cell::RefCell,
3    cmp::{max, min},
4    ops::{Add, Div, Mul, Sub},
5    rc::Rc,
6};
7
8use openvm_circuit_primitives::bigint::check_carry_to_zero::get_carry_max_abs_and_bits;
9use openvm_stark_backend::p3_util::log2_ceil_usize;
10
11use super::{ExprBuilder, SymbolicExpr};
12
13#[derive(Clone)]
14pub struct FieldVariable {
15    // 1. This will be "reset" to Var(n), when calling save on it.
16    // 2. This is an expression to "compute" (instead of to "constrain")
17    // But it will NOT have division, as it will be auto save and reset.
18    // For example, if we want to compute d = a * b + c, the expr here will be a * b + c
19    // So this is not a constraint that should be equal to zero (a * b + c - d is the constraint).
20    pub expr: SymbolicExpr,
21
22    pub builder: Rc<RefCell<ExprBuilder>>,
23
24    // Limb related information when evaluated as an OverflowInt (vector of limbs).
25    // Max abs of each limb.
26    pub limb_max_abs: usize,
27    // All limbs should be within [-2^max_overflow_bits, 2^max_overflow_bits)
28    // This is log2_ceil(limb_max_abs)
29    pub max_overflow_bits: usize,
30    // Number of limbs to represent the expression.
31    pub expr_limbs: usize,
32
33    // This is the same for all FieldVariable, but we might use different values at runtime,
34    // so store it here for easy configuration.
35    pub max_carry_bits: usize,
36}
37
38impl FieldVariable {
39    // Returns the index of the new variable.
40    // There should be no division in the expression.
41    /// This function is idempotent, i.e., if you already saved, then saving again does nothing.
42    pub fn save(&mut self) -> usize {
43        if let SymbolicExpr::Var(var_id) = self.expr {
44            // If self.expr is already a Var, no need to save
45            return var_id;
46        }
47        let mut builder = self.builder.borrow_mut();
48
49        // Introduce a new variable to replace self.expr.
50        let (new_var_idx, new_var) = builder.new_var();
51        // self.expr - new_var = 0
52        let new_constraint =
53            SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(new_var.clone()));
54        // limbs information.
55        builder.set_constraint(new_var_idx, new_constraint);
56        builder.set_compute(new_var_idx, self.expr.clone());
57
58        self.expr = new_var;
59        self.limb_max_abs = (1 << builder.limb_bits) - 1;
60        self.max_overflow_bits = builder.limb_bits;
61        self.expr_limbs = builder.num_limbs;
62
63        builder.num_variables - 1
64    }
65
66    pub fn save_output(&mut self) {
67        let index = self.save();
68        let mut builder = self.builder.borrow_mut();
69        builder.output_indices.push(index);
70    }
71
72    pub fn canonical_limb_bits(&self) -> usize {
73        self.builder.borrow().limb_bits
74    }
75
76    fn get_q_limbs(expr: SymbolicExpr, builder: &ExprBuilder) -> usize {
77        let constraint_expr = SymbolicExpr::Sub(
78            Box::new(expr),
79            Box::new(SymbolicExpr::Var(builder.num_variables)),
80        );
81        let (q_limbs, _) = constraint_expr.constraint_limbs(
82            &builder.prime,
83            builder.limb_bits,
84            builder.num_limbs,
85            builder.proper_max(),
86        );
87        q_limbs
88    }
89
90    fn save_if_overflow(
91        a: &mut FieldVariable, // will save this variable if overflow
92        expr: SymbolicExpr, // the "compute" expression of the result variable. Note that we need to check if constraint overflows
93        limb_max_abs: usize, // The max abs of limbs of compute expression.
94    ) {
95        if let SymbolicExpr::Var(_) = a.expr {
96            return;
97        }
98        let builder = a.builder.borrow();
99        let canonical_limb_bits = builder.limb_bits;
100        let q_limbs = FieldVariable::get_q_limbs(expr, &builder);
101        let canonical_limb_max_abs = (1 << canonical_limb_bits) - 1;
102
103        // The constraint equation is expr - new_var - qp, and we need to check if it overflows.
104        let limb_max_abs = limb_max_abs
105            + canonical_limb_max_abs  // new var
106            + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, builder.num_limbs); // qp
107        drop(builder);
108
109        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
110        let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, canonical_limb_bits);
111        if carry_bits > a.max_carry_bits {
112            a.save();
113        }
114    }
115
116    // TODO[Lun-Kai]: rethink about how should auto-save work.
117    // This implementation requires self and other to be mutable, and might actually mutate them.
118    // This might surprise the caller or introduce hard bug if the caller clone the FieldVariable and then call this.
119    pub fn add(&mut self, other: &mut FieldVariable) -> FieldVariable {
120        assert!(Rc::ptr_eq(&self.builder, &other.builder));
121        let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
122        FieldVariable::save_if_overflow(
123            self,
124            SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
125            limb_max_fn(self, other),
126        );
127        // Do again to check if the other also needs to be saved.
128        FieldVariable::save_if_overflow(
129            other,
130            SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
131            limb_max_fn(self, other),
132        );
133
134        let limb_max_abs = limb_max_fn(self, other);
135        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
136        FieldVariable {
137            expr: SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
138            builder: self.builder.clone(),
139            limb_max_abs,
140            max_overflow_bits,
141            expr_limbs: max(self.expr_limbs, other.expr_limbs),
142            max_carry_bits: self.max_carry_bits,
143        }
144    }
145
146    pub fn sub(&mut self, other: &mut FieldVariable) -> FieldVariable {
147        assert!(Rc::ptr_eq(&self.builder, &other.builder));
148        let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
149        FieldVariable::save_if_overflow(
150            self,
151            SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
152            limb_max_fn(self, other),
153        );
154        // Do again to check if the other also needs to be saved.
155        FieldVariable::save_if_overflow(
156            other,
157            SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
158            limb_max_fn(self, other),
159        );
160
161        let limb_max_abs = limb_max_fn(self, other);
162        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
163        FieldVariable {
164            expr: SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
165            builder: self.builder.clone(),
166            limb_max_abs,
167            max_overflow_bits,
168            expr_limbs: max(self.expr_limbs, other.expr_limbs),
169            max_carry_bits: self.max_carry_bits,
170        }
171    }
172
173    pub fn mul(&mut self, other: &mut FieldVariable) -> FieldVariable {
174        assert!(Rc::ptr_eq(&self.builder, &other.builder));
175        let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| {
176            a.limb_max_abs * b.limb_max_abs * min(a.expr_limbs, b.expr_limbs)
177        };
178        FieldVariable::save_if_overflow(
179            self,
180            SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
181            limb_max_fn(self, other),
182        );
183        // Do again to check if the other also needs to be saved.
184        FieldVariable::save_if_overflow(
185            other,
186            SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
187            limb_max_fn(self, other),
188        );
189
190        let limb_max_abs = limb_max_fn(self, other);
191        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
192        FieldVariable {
193            expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
194            builder: self.builder.clone(),
195            limb_max_abs,
196            max_overflow_bits,
197            expr_limbs: self.expr_limbs + other.expr_limbs - 1,
198            max_carry_bits: self.max_carry_bits,
199        }
200    }
201
202    pub fn square(&mut self) -> FieldVariable {
203        let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
204        FieldVariable::save_if_overflow(
205            self,
206            SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
207            limb_max_abs,
208        );
209
210        let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
211        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
212        FieldVariable {
213            expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
214            builder: self.builder.clone(),
215            limb_max_abs,
216            max_overflow_bits,
217            expr_limbs: self.expr_limbs * 2 - 1,
218            max_carry_bits: self.max_carry_bits,
219        }
220    }
221
222    pub fn int_add(&mut self, scalar: isize) -> FieldVariable {
223        let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
224        FieldVariable::save_if_overflow(
225            self,
226            SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
227            limb_max_abs,
228        );
229
230        let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
231        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
232        FieldVariable {
233            expr: SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
234            builder: self.builder.clone(),
235            limb_max_abs,
236            max_overflow_bits,
237            expr_limbs: self.expr_limbs,
238            max_carry_bits: self.max_carry_bits,
239        }
240    }
241
242    pub fn int_mul(&mut self, scalar: isize) -> FieldVariable {
243        let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
244        FieldVariable::save_if_overflow(
245            self,
246            SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
247            limb_max_abs,
248        );
249
250        let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
251        let max_overflow_bits = log2_ceil_usize(limb_max_abs);
252        FieldVariable {
253            expr: SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
254            builder: self.builder.clone(),
255            limb_max_abs,
256            max_overflow_bits,
257            expr_limbs: self.expr_limbs,
258            max_carry_bits: self.max_carry_bits,
259        }
260    }
261
262    // expr cannot have division, so auto-save a new variable.
263    // Note that division by zero will panic.
264    pub fn div(&mut self, other: &mut FieldVariable) -> FieldVariable {
265        assert!(Rc::ptr_eq(&self.builder, &other.builder));
266        let builder = self.builder.borrow();
267        let prime = builder.prime.clone();
268        let limb_bits = builder.limb_bits;
269        let num_limbs = builder.num_limbs;
270        let proper_max = builder.proper_max().clone();
271        drop(builder);
272
273        // This is a dummy variable, will be replaced later so the index within it doesn't matter.
274        // We use this to check if we need to save self/other first.
275        let fake_var = SymbolicExpr::Var(0);
276
277        // Constraint: other.expr * new_var - self.expr = 0 (mod p)
278        let new_constraint = SymbolicExpr::Sub(
279            Box::new(SymbolicExpr::Mul(
280                Box::new(other.expr.clone()),
281                Box::new(fake_var.clone()),
282            )),
283            Box::new(self.expr.clone()),
284        );
285        let carry_bits =
286            new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
287        if carry_bits > self.max_carry_bits {
288            self.save();
289        }
290        // Do it again to check if other needs to be saved.
291        let new_constraint = SymbolicExpr::Sub(
292            Box::new(SymbolicExpr::Mul(
293                Box::new(other.expr.clone()),
294                Box::new(fake_var.clone()),
295            )),
296            Box::new(self.expr.clone()),
297        );
298        let carry_bits =
299            new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
300        if carry_bits > self.max_carry_bits {
301            other.save();
302        }
303
304        let mut builder = self.builder.borrow_mut();
305        let (new_var_idx, new_var) = builder.new_var();
306        let new_constraint = SymbolicExpr::Sub(
307            Box::new(SymbolicExpr::Mul(
308                Box::new(other.expr.clone()),
309                Box::new(new_var.clone()),
310            )),
311            Box::new(self.expr.clone()),
312        );
313        builder.set_constraint(new_var_idx, new_constraint);
314        // Only compute can have division.
315        let compute = SymbolicExpr::Div(Box::new(self.expr.clone()), Box::new(other.expr.clone()));
316        builder.set_compute(new_var_idx, compute);
317        drop(builder);
318
319        FieldVariable::from_var(self.builder.clone(), new_var)
320    }
321
322    pub fn from_var(builder: Rc<RefCell<ExprBuilder>>, var: SymbolicExpr) -> FieldVariable {
323        let borrowed_builder = builder.borrow();
324        let max_carry_bits = borrowed_builder.max_carry_bits;
325        assert!(
326            matches!(var, SymbolicExpr::Var(_)),
327            "Expected var to be of type SymbolicExpr::Var"
328        );
329        let num_limbs = borrowed_builder.num_limbs;
330        let canonical_limb_bits = borrowed_builder.limb_bits;
331        drop(borrowed_builder);
332        FieldVariable {
333            expr: var,
334            builder,
335            limb_max_abs: (1 << canonical_limb_bits) - 1,
336            max_overflow_bits: canonical_limb_bits,
337            expr_limbs: num_limbs,
338            max_carry_bits,
339        }
340    }
341
342    pub fn select(flag_id: usize, a: &FieldVariable, b: &FieldVariable) -> FieldVariable {
343        assert!(Rc::ptr_eq(&a.builder, &b.builder));
344        let limb_max_abs = max(a.limb_max_abs, b.limb_max_abs);
345        let max_overflow_bits = max(a.max_overflow_bits, b.max_overflow_bits);
346        let expr_limbs = max(a.expr_limbs, b.expr_limbs);
347        FieldVariable {
348            expr: SymbolicExpr::Select(flag_id, Box::new(a.expr.clone()), Box::new(b.expr.clone())),
349            builder: a.builder.clone(),
350            limb_max_abs,
351            max_overflow_bits,
352            expr_limbs,
353            max_carry_bits: a.max_carry_bits,
354        }
355    }
356}
357
358impl Add<&mut FieldVariable> for &mut FieldVariable {
359    type Output = FieldVariable;
360
361    fn add(self, rhs: &mut FieldVariable) -> Self::Output {
362        self.add(rhs)
363    }
364}
365
366impl Add<FieldVariable> for FieldVariable {
367    type Output = FieldVariable;
368
369    fn add(mut self, mut rhs: FieldVariable) -> Self::Output {
370        let x = &mut self;
371        x.add(&mut rhs)
372    }
373}
374
375impl Sub<FieldVariable> for FieldVariable {
376    type Output = FieldVariable;
377
378    fn sub(mut self, mut rhs: FieldVariable) -> Self::Output {
379        let x = &mut self;
380        x.sub(&mut rhs)
381    }
382}
383
384impl Sub<&mut FieldVariable> for &mut FieldVariable {
385    type Output = FieldVariable;
386
387    fn sub(self, rhs: &mut FieldVariable) -> Self::Output {
388        self.sub(rhs)
389    }
390}
391
392impl Mul<FieldVariable> for FieldVariable {
393    type Output = FieldVariable;
394
395    fn mul(mut self, mut rhs: FieldVariable) -> Self::Output {
396        let x = &mut self;
397        x.mul(&mut rhs)
398    }
399}
400
401impl Mul<&mut FieldVariable> for &mut FieldVariable {
402    type Output = FieldVariable;
403
404    fn mul(self, rhs: &mut FieldVariable) -> Self::Output {
405        FieldVariable::mul(self, rhs)
406    }
407}
408
409// Note that division by zero will panic.
410impl Div<FieldVariable> for FieldVariable {
411    type Output = FieldVariable;
412
413    fn div(mut self, mut rhs: FieldVariable) -> Self::Output {
414        let x = &mut self;
415        x.div(&mut rhs)
416    }
417}
418
419impl Div<&mut FieldVariable> for &mut FieldVariable {
420    type Output = FieldVariable;
421
422    fn div(self, rhs: &mut FieldVariable) -> Self::Output {
423        FieldVariable::div(self, rhs)
424    }
425}