1use std::{
2 cell::RefCell,
3 cmp::{max, min},
4 ops::{Add, Div, Mul, Sub},
5 rc::Rc,
6};
7
8use openvm_circuit_primitives::bigint::check_carry_to_zero::get_carry_max_abs_and_bits;
9use openvm_stark_backend::p3_util::log2_ceil_usize;
10
11use super::{ExprBuilder, SymbolicExpr};
12
13#[derive(Clone)]
14pub struct FieldVariable {
15 pub expr: SymbolicExpr,
21
22 pub builder: Rc<RefCell<ExprBuilder>>,
23
24 pub limb_max_abs: usize,
27 pub max_overflow_bits: usize,
30 pub expr_limbs: usize,
32
33 pub max_carry_bits: usize,
36}
37
38impl FieldVariable {
39 pub fn save(&mut self) -> usize {
43 if let SymbolicExpr::Var(var_id) = self.expr {
44 return var_id;
46 }
47 let mut builder = self.builder.borrow_mut();
48
49 let (new_var_idx, new_var) = builder.new_var();
51 let new_constraint =
53 SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(new_var.clone()));
54 builder.set_constraint(new_var_idx, new_constraint);
56 builder.set_compute(new_var_idx, self.expr.clone());
57
58 self.expr = new_var;
59 self.limb_max_abs = (1 << builder.limb_bits) - 1;
60 self.max_overflow_bits = builder.limb_bits;
61 self.expr_limbs = builder.num_limbs;
62
63 builder.num_variables - 1
64 }
65
66 pub fn save_output(&mut self) {
67 let index = self.save();
68 let mut builder = self.builder.borrow_mut();
69 builder.output_indices.push(index);
70 }
71
72 pub fn canonical_limb_bits(&self) -> usize {
73 self.builder.borrow().limb_bits
74 }
75
76 fn get_q_limbs(expr: SymbolicExpr, builder: &ExprBuilder) -> usize {
77 let constraint_expr = SymbolicExpr::Sub(
78 Box::new(expr),
79 Box::new(SymbolicExpr::Var(builder.num_variables)),
80 );
81 let (q_limbs, _) = constraint_expr.constraint_limbs(
82 &builder.prime,
83 builder.limb_bits,
84 builder.num_limbs,
85 builder.proper_max(),
86 );
87 q_limbs
88 }
89
90 fn save_if_overflow(
91 a: &mut FieldVariable, expr: SymbolicExpr, limb_max_abs: usize, ) {
96 if let SymbolicExpr::Var(_) = a.expr {
97 return;
98 }
99 let builder = a.builder.borrow();
100 let canonical_limb_bits = builder.limb_bits;
101 let q_limbs = FieldVariable::get_q_limbs(expr, &builder);
102 let canonical_limb_max_abs = (1 << canonical_limb_bits) - 1;
103
104 let limb_max_abs = limb_max_abs
106 + canonical_limb_max_abs + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, builder.num_limbs); drop(builder);
109
110 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
111 let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, canonical_limb_bits);
112 if carry_bits > a.max_carry_bits {
113 a.save();
114 }
115 }
116
117 pub fn add(&mut self, other: &mut FieldVariable) -> FieldVariable {
122 assert!(Rc::ptr_eq(&self.builder, &other.builder));
123 let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
124 FieldVariable::save_if_overflow(
125 self,
126 SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
127 limb_max_fn(self, other),
128 );
129 FieldVariable::save_if_overflow(
131 other,
132 SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
133 limb_max_fn(self, other),
134 );
135
136 let limb_max_abs = limb_max_fn(self, other);
137 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
138 FieldVariable {
139 expr: SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
140 builder: self.builder.clone(),
141 limb_max_abs,
142 max_overflow_bits,
143 expr_limbs: max(self.expr_limbs, other.expr_limbs),
144 max_carry_bits: self.max_carry_bits,
145 }
146 }
147
148 pub fn sub(&mut self, other: &mut FieldVariable) -> FieldVariable {
149 assert!(Rc::ptr_eq(&self.builder, &other.builder));
150 let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
151 FieldVariable::save_if_overflow(
152 self,
153 SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
154 limb_max_fn(self, other),
155 );
156 FieldVariable::save_if_overflow(
158 other,
159 SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
160 limb_max_fn(self, other),
161 );
162
163 let limb_max_abs = limb_max_fn(self, other);
164 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
165 FieldVariable {
166 expr: SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
167 builder: self.builder.clone(),
168 limb_max_abs,
169 max_overflow_bits,
170 expr_limbs: max(self.expr_limbs, other.expr_limbs),
171 max_carry_bits: self.max_carry_bits,
172 }
173 }
174
175 pub fn mul(&mut self, other: &mut FieldVariable) -> FieldVariable {
176 assert!(Rc::ptr_eq(&self.builder, &other.builder));
177 let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| {
178 a.limb_max_abs * b.limb_max_abs * min(a.expr_limbs, b.expr_limbs)
179 };
180 FieldVariable::save_if_overflow(
181 self,
182 SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
183 limb_max_fn(self, other),
184 );
185 FieldVariable::save_if_overflow(
187 other,
188 SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
189 limb_max_fn(self, other),
190 );
191
192 let limb_max_abs = limb_max_fn(self, other);
193 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
194 FieldVariable {
195 expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
196 builder: self.builder.clone(),
197 limb_max_abs,
198 max_overflow_bits,
199 expr_limbs: self.expr_limbs + other.expr_limbs - 1,
200 max_carry_bits: self.max_carry_bits,
201 }
202 }
203
204 pub fn square(&mut self) -> FieldVariable {
205 let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
206 FieldVariable::save_if_overflow(
207 self,
208 SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
209 limb_max_abs,
210 );
211
212 let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
213 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
214 FieldVariable {
215 expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
216 builder: self.builder.clone(),
217 limb_max_abs,
218 max_overflow_bits,
219 expr_limbs: self.expr_limbs * 2 - 1,
220 max_carry_bits: self.max_carry_bits,
221 }
222 }
223
224 pub fn int_add(&mut self, scalar: isize) -> FieldVariable {
225 let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
226 FieldVariable::save_if_overflow(
227 self,
228 SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
229 limb_max_abs,
230 );
231
232 let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
233 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
234 FieldVariable {
235 expr: SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
236 builder: self.builder.clone(),
237 limb_max_abs,
238 max_overflow_bits,
239 expr_limbs: self.expr_limbs,
240 max_carry_bits: self.max_carry_bits,
241 }
242 }
243
244 pub fn int_mul(&mut self, scalar: isize) -> FieldVariable {
245 let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
246 FieldVariable::save_if_overflow(
247 self,
248 SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
249 limb_max_abs,
250 );
251
252 let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
253 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
254 FieldVariable {
255 expr: SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
256 builder: self.builder.clone(),
257 limb_max_abs,
258 max_overflow_bits,
259 expr_limbs: self.expr_limbs,
260 max_carry_bits: self.max_carry_bits,
261 }
262 }
263
264 pub fn div(&mut self, other: &mut FieldVariable) -> FieldVariable {
267 assert!(Rc::ptr_eq(&self.builder, &other.builder));
268 let builder = self.builder.borrow();
269 let prime = builder.prime.clone();
270 let limb_bits = builder.limb_bits;
271 let num_limbs = builder.num_limbs;
272 let proper_max = builder.proper_max().clone();
273 drop(builder);
274
275 let fake_var = SymbolicExpr::Var(0);
278
279 let new_constraint = SymbolicExpr::Sub(
281 Box::new(SymbolicExpr::Mul(
282 Box::new(other.expr.clone()),
283 Box::new(fake_var.clone()),
284 )),
285 Box::new(self.expr.clone()),
286 );
287 let carry_bits =
288 new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
289 if carry_bits > self.max_carry_bits {
290 self.save();
291 }
292 let new_constraint = SymbolicExpr::Sub(
294 Box::new(SymbolicExpr::Mul(
295 Box::new(other.expr.clone()),
296 Box::new(fake_var.clone()),
297 )),
298 Box::new(self.expr.clone()),
299 );
300 let carry_bits =
301 new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
302 if carry_bits > self.max_carry_bits {
303 other.save();
304 }
305
306 let mut builder = self.builder.borrow_mut();
307 let (new_var_idx, new_var) = builder.new_var();
308 let new_constraint = SymbolicExpr::Sub(
309 Box::new(SymbolicExpr::Mul(
310 Box::new(other.expr.clone()),
311 Box::new(new_var.clone()),
312 )),
313 Box::new(self.expr.clone()),
314 );
315 builder.set_constraint(new_var_idx, new_constraint);
316 let compute = SymbolicExpr::Div(Box::new(self.expr.clone()), Box::new(other.expr.clone()));
318 builder.set_compute(new_var_idx, compute);
319 drop(builder);
320
321 FieldVariable::from_var(self.builder.clone(), new_var)
322 }
323
324 pub fn from_var(builder: Rc<RefCell<ExprBuilder>>, var: SymbolicExpr) -> FieldVariable {
325 let borrowed_builder = builder.borrow();
326 let max_carry_bits = borrowed_builder.max_carry_bits;
327 assert!(
328 matches!(var, SymbolicExpr::Var(_)),
329 "Expected var to be of type SymbolicExpr::Var"
330 );
331 let num_limbs = borrowed_builder.num_limbs;
332 let canonical_limb_bits = borrowed_builder.limb_bits;
333 drop(borrowed_builder);
334 FieldVariable {
335 expr: var,
336 builder,
337 limb_max_abs: (1 << canonical_limb_bits) - 1,
338 max_overflow_bits: canonical_limb_bits,
339 expr_limbs: num_limbs,
340 max_carry_bits,
341 }
342 }
343
344 pub fn select(flag_id: usize, a: &FieldVariable, b: &FieldVariable) -> FieldVariable {
345 assert!(Rc::ptr_eq(&a.builder, &b.builder));
346 let limb_max_abs = max(a.limb_max_abs, b.limb_max_abs);
347 let max_overflow_bits = max(a.max_overflow_bits, b.max_overflow_bits);
348 let expr_limbs = max(a.expr_limbs, b.expr_limbs);
349 FieldVariable {
350 expr: SymbolicExpr::Select(flag_id, Box::new(a.expr.clone()), Box::new(b.expr.clone())),
351 builder: a.builder.clone(),
352 limb_max_abs,
353 max_overflow_bits,
354 expr_limbs,
355 max_carry_bits: a.max_carry_bits,
356 }
357 }
358}
359
360impl Add<&mut FieldVariable> for &mut FieldVariable {
361 type Output = FieldVariable;
362
363 fn add(self, rhs: &mut FieldVariable) -> Self::Output {
364 self.add(rhs)
365 }
366}
367
368impl Add<FieldVariable> for FieldVariable {
369 type Output = FieldVariable;
370
371 fn add(mut self, mut rhs: FieldVariable) -> Self::Output {
372 let x = &mut self;
373 x.add(&mut rhs)
374 }
375}
376
377impl Sub<FieldVariable> for FieldVariable {
378 type Output = FieldVariable;
379
380 fn sub(mut self, mut rhs: FieldVariable) -> Self::Output {
381 let x = &mut self;
382 x.sub(&mut rhs)
383 }
384}
385
386impl Sub<&mut FieldVariable> for &mut FieldVariable {
387 type Output = FieldVariable;
388
389 fn sub(self, rhs: &mut FieldVariable) -> Self::Output {
390 self.sub(rhs)
391 }
392}
393
394impl Mul<FieldVariable> for FieldVariable {
395 type Output = FieldVariable;
396
397 fn mul(mut self, mut rhs: FieldVariable) -> Self::Output {
398 let x = &mut self;
399 x.mul(&mut rhs)
400 }
401}
402
403impl Mul<&mut FieldVariable> for &mut FieldVariable {
404 type Output = FieldVariable;
405
406 fn mul(self, rhs: &mut FieldVariable) -> Self::Output {
407 FieldVariable::mul(self, rhs)
408 }
409}
410
411impl Div<FieldVariable> for FieldVariable {
413 type Output = FieldVariable;
414
415 fn div(mut self, mut rhs: FieldVariable) -> Self::Output {
416 let x = &mut self;
417 x.div(&mut rhs)
418 }
419}
420
421impl Div<&mut FieldVariable> for &mut FieldVariable {
422 type Output = FieldVariable;
423
424 fn div(self, rhs: &mut FieldVariable) -> Self::Output {
425 FieldVariable::div(self, rhs)
426 }
427}