1use std::{
2 cell::RefCell,
3 cmp::{max, min},
4 ops::{Add, Div, Mul, Sub},
5 rc::Rc,
6};
7
8use openvm_circuit_primitives::bigint::check_carry_to_zero::get_carry_max_abs_and_bits;
9use openvm_stark_backend::p3_util::log2_ceil_usize;
10
11use super::{ExprBuilder, SymbolicExpr};
12
/// A symbolic expression registered with a shared [`ExprBuilder`], together with the limb
/// bounds used to decide when the expression must be saved into a fresh builder variable.
///
/// Cloning copies `expr` and the bounds; the `builder` is shared (it is an `Rc`).
#[derive(Clone)]
pub struct FieldVariable {
    /// The symbolic expression this value represents.
    pub expr: SymbolicExpr,

    /// Builder shared by every `FieldVariable` of the same circuit; new variables,
    /// constraints and compute expressions are registered on it.
    pub builder: Rc<RefCell<ExprBuilder>>,

    /// Upper bound on the absolute value of each limb of `expr`.
    pub limb_max_abs: usize,
    /// Bits needed to hold `limb_max_abs` (`limb_bits` for a saved variable).
    pub max_overflow_bits: usize,
    /// Number of limbs `expr` may span.
    pub expr_limbs: usize,

    /// Carry-bit budget: an operand is saved before an operation whose constraint would
    /// need more carry bits than this.
    pub max_carry_bits: usize,
}
37
38impl FieldVariable {
39 pub fn save(&mut self) -> usize {
43 if let SymbolicExpr::Var(var_id) = self.expr {
44 return var_id;
46 }
47 let mut builder = self.builder.borrow_mut();
48
49 let (new_var_idx, new_var) = builder.new_var();
51 let new_constraint =
53 SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(new_var.clone()));
54 builder.set_constraint(new_var_idx, new_constraint);
56 builder.set_compute(new_var_idx, self.expr.clone());
57
58 self.expr = new_var;
59 self.limb_max_abs = (1 << builder.limb_bits) - 1;
60 self.max_overflow_bits = builder.limb_bits;
61 self.expr_limbs = builder.num_limbs;
62
63 builder.num_variables - 1
64 }
65
66 pub fn save_output(&mut self) {
67 let index = self.save();
68 let mut builder = self.builder.borrow_mut();
69 builder.output_indices.push(index);
70 }
71
72 pub fn canonical_limb_bits(&self) -> usize {
73 self.builder.borrow().limb_bits
74 }
75
76 fn get_q_limbs(expr: SymbolicExpr, builder: &ExprBuilder) -> usize {
77 let constraint_expr = SymbolicExpr::Sub(
78 Box::new(expr),
79 Box::new(SymbolicExpr::Var(builder.num_variables)),
80 );
81 let (q_limbs, _) = constraint_expr.constraint_limbs(
82 &builder.prime,
83 builder.limb_bits,
84 builder.num_limbs,
85 builder.proper_max(),
86 );
87 q_limbs
88 }
89
90 fn save_if_overflow(
91 a: &mut FieldVariable, expr: SymbolicExpr, limb_max_abs: usize, ) {
95 if let SymbolicExpr::Var(_) = a.expr {
96 return;
97 }
98 let builder = a.builder.borrow();
99 let canonical_limb_bits = builder.limb_bits;
100 let q_limbs = FieldVariable::get_q_limbs(expr, &builder);
101 let canonical_limb_max_abs = (1 << canonical_limb_bits) - 1;
102
103 let limb_max_abs = limb_max_abs
105 + canonical_limb_max_abs + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, builder.num_limbs); drop(builder);
108
109 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
110 let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, canonical_limb_bits);
111 if carry_bits > a.max_carry_bits {
112 a.save();
113 }
114 }
115
116 pub fn add(&mut self, other: &mut FieldVariable) -> FieldVariable {
120 assert!(Rc::ptr_eq(&self.builder, &other.builder));
121 let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
122 FieldVariable::save_if_overflow(
123 self,
124 SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
125 limb_max_fn(self, other),
126 );
127 FieldVariable::save_if_overflow(
129 other,
130 SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
131 limb_max_fn(self, other),
132 );
133
134 let limb_max_abs = limb_max_fn(self, other);
135 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
136 FieldVariable {
137 expr: SymbolicExpr::Add(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
138 builder: self.builder.clone(),
139 limb_max_abs,
140 max_overflow_bits,
141 expr_limbs: max(self.expr_limbs, other.expr_limbs),
142 max_carry_bits: self.max_carry_bits,
143 }
144 }
145
146 pub fn sub(&mut self, other: &mut FieldVariable) -> FieldVariable {
147 assert!(Rc::ptr_eq(&self.builder, &other.builder));
148 let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| a.limb_max_abs + b.limb_max_abs;
149 FieldVariable::save_if_overflow(
150 self,
151 SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
152 limb_max_fn(self, other),
153 );
154 FieldVariable::save_if_overflow(
156 other,
157 SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
158 limb_max_fn(self, other),
159 );
160
161 let limb_max_abs = limb_max_fn(self, other);
162 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
163 FieldVariable {
164 expr: SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
165 builder: self.builder.clone(),
166 limb_max_abs,
167 max_overflow_bits,
168 expr_limbs: max(self.expr_limbs, other.expr_limbs),
169 max_carry_bits: self.max_carry_bits,
170 }
171 }
172
173 pub fn mul(&mut self, other: &mut FieldVariable) -> FieldVariable {
174 assert!(Rc::ptr_eq(&self.builder, &other.builder));
175 let limb_max_fn = |a: &FieldVariable, b: &FieldVariable| {
176 a.limb_max_abs * b.limb_max_abs * min(a.expr_limbs, b.expr_limbs)
177 };
178 FieldVariable::save_if_overflow(
179 self,
180 SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
181 limb_max_fn(self, other),
182 );
183 FieldVariable::save_if_overflow(
185 other,
186 SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
187 limb_max_fn(self, other),
188 );
189
190 let limb_max_abs = limb_max_fn(self, other);
191 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
192 FieldVariable {
193 expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(other.expr.clone())),
194 builder: self.builder.clone(),
195 limb_max_abs,
196 max_overflow_bits,
197 expr_limbs: self.expr_limbs + other.expr_limbs - 1,
198 max_carry_bits: self.max_carry_bits,
199 }
200 }
201
202 pub fn square(&mut self) -> FieldVariable {
203 let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
204 FieldVariable::save_if_overflow(
205 self,
206 SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
207 limb_max_abs,
208 );
209
210 let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
211 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
212 FieldVariable {
213 expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
214 builder: self.builder.clone(),
215 limb_max_abs,
216 max_overflow_bits,
217 expr_limbs: self.expr_limbs * 2 - 1,
218 max_carry_bits: self.max_carry_bits,
219 }
220 }
221
222 pub fn int_add(&mut self, scalar: isize) -> FieldVariable {
223 let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
224 FieldVariable::save_if_overflow(
225 self,
226 SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
227 limb_max_abs,
228 );
229
230 let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
231 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
232 FieldVariable {
233 expr: SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
234 builder: self.builder.clone(),
235 limb_max_abs,
236 max_overflow_bits,
237 expr_limbs: self.expr_limbs,
238 max_carry_bits: self.max_carry_bits,
239 }
240 }
241
242 pub fn int_mul(&mut self, scalar: isize) -> FieldVariable {
243 let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
244 FieldVariable::save_if_overflow(
245 self,
246 SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
247 limb_max_abs,
248 );
249
250 let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
251 let max_overflow_bits = log2_ceil_usize(limb_max_abs);
252 FieldVariable {
253 expr: SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
254 builder: self.builder.clone(),
255 limb_max_abs,
256 max_overflow_bits,
257 expr_limbs: self.expr_limbs,
258 max_carry_bits: self.max_carry_bits,
259 }
260 }
261
262 pub fn div(&mut self, other: &mut FieldVariable) -> FieldVariable {
265 assert!(Rc::ptr_eq(&self.builder, &other.builder));
266 let builder = self.builder.borrow();
267 let prime = builder.prime.clone();
268 let limb_bits = builder.limb_bits;
269 let num_limbs = builder.num_limbs;
270 let proper_max = builder.proper_max().clone();
271 drop(builder);
272
273 let fake_var = SymbolicExpr::Var(0);
276
277 let new_constraint = SymbolicExpr::Sub(
279 Box::new(SymbolicExpr::Mul(
280 Box::new(other.expr.clone()),
281 Box::new(fake_var.clone()),
282 )),
283 Box::new(self.expr.clone()),
284 );
285 let carry_bits =
286 new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
287 if carry_bits > self.max_carry_bits {
288 self.save();
289 }
290 let new_constraint = SymbolicExpr::Sub(
292 Box::new(SymbolicExpr::Mul(
293 Box::new(other.expr.clone()),
294 Box::new(fake_var.clone()),
295 )),
296 Box::new(self.expr.clone()),
297 );
298 let carry_bits =
299 new_constraint.constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs, &proper_max);
300 if carry_bits > self.max_carry_bits {
301 other.save();
302 }
303
304 let mut builder = self.builder.borrow_mut();
305 let (new_var_idx, new_var) = builder.new_var();
306 let new_constraint = SymbolicExpr::Sub(
307 Box::new(SymbolicExpr::Mul(
308 Box::new(other.expr.clone()),
309 Box::new(new_var.clone()),
310 )),
311 Box::new(self.expr.clone()),
312 );
313 builder.set_constraint(new_var_idx, new_constraint);
314 let compute = SymbolicExpr::Div(Box::new(self.expr.clone()), Box::new(other.expr.clone()));
316 builder.set_compute(new_var_idx, compute);
317 drop(builder);
318
319 FieldVariable::from_var(self.builder.clone(), new_var)
320 }
321
322 pub fn from_var(builder: Rc<RefCell<ExprBuilder>>, var: SymbolicExpr) -> FieldVariable {
323 let borrowed_builder = builder.borrow();
324 let max_carry_bits = borrowed_builder.max_carry_bits;
325 assert!(
326 matches!(var, SymbolicExpr::Var(_)),
327 "Expected var to be of type SymbolicExpr::Var"
328 );
329 let num_limbs = borrowed_builder.num_limbs;
330 let canonical_limb_bits = borrowed_builder.limb_bits;
331 drop(borrowed_builder);
332 FieldVariable {
333 expr: var,
334 builder,
335 limb_max_abs: (1 << canonical_limb_bits) - 1,
336 max_overflow_bits: canonical_limb_bits,
337 expr_limbs: num_limbs,
338 max_carry_bits,
339 }
340 }
341
342 pub fn select(flag_id: usize, a: &FieldVariable, b: &FieldVariable) -> FieldVariable {
343 assert!(Rc::ptr_eq(&a.builder, &b.builder));
344 let limb_max_abs = max(a.limb_max_abs, b.limb_max_abs);
345 let max_overflow_bits = max(a.max_overflow_bits, b.max_overflow_bits);
346 let expr_limbs = max(a.expr_limbs, b.expr_limbs);
347 FieldVariable {
348 expr: SymbolicExpr::Select(flag_id, Box::new(a.expr.clone()), Box::new(b.expr.clone())),
349 builder: a.builder.clone(),
350 limb_max_abs,
351 max_overflow_bits,
352 expr_limbs,
353 max_carry_bits: a.max_carry_bits,
354 }
355 }
356}
357
358impl Add<&mut FieldVariable> for &mut FieldVariable {
359 type Output = FieldVariable;
360
361 fn add(self, rhs: &mut FieldVariable) -> Self::Output {
362 self.add(rhs)
363 }
364}
365
366impl Add<FieldVariable> for FieldVariable {
367 type Output = FieldVariable;
368
369 fn add(mut self, mut rhs: FieldVariable) -> Self::Output {
370 let x = &mut self;
371 x.add(&mut rhs)
372 }
373}
374
375impl Sub<FieldVariable> for FieldVariable {
376 type Output = FieldVariable;
377
378 fn sub(mut self, mut rhs: FieldVariable) -> Self::Output {
379 let x = &mut self;
380 x.sub(&mut rhs)
381 }
382}
383
384impl Sub<&mut FieldVariable> for &mut FieldVariable {
385 type Output = FieldVariable;
386
387 fn sub(self, rhs: &mut FieldVariable) -> Self::Output {
388 self.sub(rhs)
389 }
390}
391
392impl Mul<FieldVariable> for FieldVariable {
393 type Output = FieldVariable;
394
395 fn mul(mut self, mut rhs: FieldVariable) -> Self::Output {
396 let x = &mut self;
397 x.mul(&mut rhs)
398 }
399}
400
401impl Mul<&mut FieldVariable> for &mut FieldVariable {
402 type Output = FieldVariable;
403
404 fn mul(self, rhs: &mut FieldVariable) -> Self::Output {
405 FieldVariable::mul(self, rhs)
406 }
407}
408
409impl Div<FieldVariable> for FieldVariable {
411 type Output = FieldVariable;
412
413 fn div(mut self, mut rhs: FieldVariable) -> Self::Output {
414 let x = &mut self;
415 x.div(&mut rhs)
416 }
417}
418
419impl Div<&mut FieldVariable> for &mut FieldVariable {
420 type Output = FieldVariable;
421
422 fn div(self, rhs: &mut FieldVariable) -> Self::Output {
423 FieldVariable::div(self, rhs)
424 }
425}