use std::{
cell::RefCell,
cmp::{max, min},
ops::{Add, Div, Mul, Sub},
rc::Rc,
};
use openvm_circuit_primitives::bigint::check_carry_to_zero::get_carry_max_abs_and_bits;
use openvm_stark_backend::p3_util::log2_ceil_usize;
use super::{ExprBuilder, SymbolicExpr};
/// A field element expressed symbolically on top of a shared [`ExprBuilder`].
///
/// Besides the expression itself, each value tracks conservative bounds on the size of its
/// limbs. Arithmetic uses these bounds to decide when an intermediate expression must be
/// materialized into a builder variable so that carries stay within the range checker.
#[derive(Clone)]
pub struct FieldVariable {
    /// Symbolic expression this value stands for; a `SymbolicExpr::Var` once it has been saved.
    pub expr: SymbolicExpr,
    /// Builder that owns the variables and constraints. Operands combined by binary
    /// operations are asserted to share the same builder (`Rc::ptr_eq`).
    pub builder: Rc<RefCell<ExprBuilder>>,
    /// Upper bound on the absolute value of any limb of `expr`.
    pub limb_max_abs: usize,
    /// Bit width of `limb_max_abs` (`log2_ceil_usize(limb_max_abs)` for derived expressions,
    /// the builder's `limb_bits` for variables).
    pub max_overflow_bits: usize,
    /// Number of limbs in `expr` (e.g. a product of `a`- and `b`-limb values has `a + b - 1`).
    pub expr_limbs: usize,
    /// Carry-bit budget: when the estimated carry bits of a constraint exceed this, operands
    /// are saved into fresh variables first.
    pub range_checker_bits: usize,
}
impl FieldVariable {
    /// Replaces this value's expression with a builder variable.
    ///
    /// If `self.expr` is already a `SymbolicExpr::Var`, nothing is allocated and that
    /// variable's index is returned. Otherwise a new variable `v` is allocated, the constraint
    /// `expr - v` and the compute expression `expr` are registered for it, and `self` is
    /// rewritten to refer to `v` with the same limb bounds that [`FieldVariable::from_var`]
    /// assigns.
    ///
    /// Returns the index of the variable `self` refers to afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the builder's `RefCell` is already borrowed.
    pub fn save(&mut self) -> usize {
        if let SymbolicExpr::Var(var_id) = self.expr {
            return var_id;
        }
        let mut builder = self.builder.borrow_mut();
        let (new_var_idx, new_var) = builder.new_var();
        // Constraint: expr - new_var = 0.
        let new_constraint =
            SymbolicExpr::Sub(Box::new(self.expr.clone()), Box::new(new_var.clone()));
        builder.set_constraint(new_var_idx, new_constraint);
        builder.set_compute(new_var_idx, self.expr.clone());
        self.expr = new_var;
        // A variable's limbs are bounded by the builder's canonical limb width.
        self.limb_max_abs = (1 << builder.limb_bits) - 1;
        self.max_overflow_bits = builder.limb_bits;
        self.expr_limbs = builder.num_limbs;
        // Return the index the builder actually handed out and registered the constraint
        // under, instead of re-deriving it from `num_variables`.
        new_var_idx
    }

    /// Saves this value (see [`FieldVariable::save`]) and appends its variable index to the
    /// builder's `output_indices`.
    pub fn save_output(&mut self) {
        let index = self.save();
        let mut builder = self.builder.borrow_mut();
        builder.output_indices.push(index);
    }

    /// Returns the builder's canonical limb width in bits.
    pub fn canonical_limb_bits(&self) -> usize {
        self.builder.borrow().limb_bits
    }

    /// Returns the quotient limb count reported by `constraint_limbs` for the constraint
    /// `expr - v`.
    ///
    /// `v` is `Var(builder.num_variables)`, a stand-in for the variable that saving the
    /// result would allocate next. The second value returned by `constraint_limbs` is ignored.
    fn get_q_limbs(expr: SymbolicExpr, builder: &ExprBuilder) -> usize {
        let constraint_expr = SymbolicExpr::Sub(
            Box::new(expr),
            Box::new(SymbolicExpr::Var(builder.num_variables)),
        );
        let (q_limbs, _) =
            constraint_expr.constraint_limbs(&builder.prime, builder.limb_bits, builder.num_limbs);
        q_limbs
    }

    /// Saves `a` if the constraint for a result built from it could overflow the range checker.
    ///
    /// * `a` — operand to save on overflow (no-op if it is already a variable).
    /// * `expr` — the compute expression of the prospective result.
    /// * `limb_max_abs` — bound on the absolute value of `expr`'s limbs.
    fn save_if_overflow(a: &mut FieldVariable, expr: SymbolicExpr, limb_max_abs: usize) {
        if let SymbolicExpr::Var(_) = a.expr {
            return;
        }
        // Scope the shared borrow so it is released before `a.save()` borrows mutably.
        let (canonical_limb_bits, constraint_limb_max_abs) = {
            let builder = a.builder.borrow();
            let canonical_limb_bits = builder.limb_bits;
            let q_limbs = FieldVariable::get_q_limbs(expr, &builder);
            let canonical_limb_max_abs = (1 << canonical_limb_bits) - 1;
            // Bound for the constraint: limbs of `expr`, plus one canonical limb for the new
            // variable, plus a product of a `q_limbs`-limb value with a `num_limbs`-limb value
            // (presumably the `q * p` term).
            let bound = limb_max_abs
                + canonical_limb_max_abs
                + canonical_limb_max_abs * canonical_limb_max_abs * min(q_limbs, builder.num_limbs);
            (canonical_limb_bits, bound)
        };
        let max_overflow_bits = log2_ceil_usize(constraint_limb_max_abs);
        let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, canonical_limb_bits);
        if carry_bits > a.range_checker_bits {
            a.save();
        }
    }

    /// Shared implementation of `add`, `sub` and `mul`.
    ///
    /// Arguments:
    /// * `make_expr` combines the two operand expressions.
    /// * `limb_max_fn` bounds the absolute value of the combined expression's limbs.
    /// * `expr_limbs_fn` gives its limb count.
    ///
    /// `self` is checked (and possibly saved) first. `other` is then checked against the
    /// current `self`. Bounds for the result are computed from the operands after any saves.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` were created from different builders.
    fn binary_op(
        &mut self,
        other: &mut FieldVariable,
        make_expr: fn(Box<SymbolicExpr>, Box<SymbolicExpr>) -> SymbolicExpr,
        limb_max_fn: fn(&FieldVariable, &FieldVariable) -> usize,
        expr_limbs_fn: fn(&FieldVariable, &FieldVariable) -> usize,
    ) -> FieldVariable {
        assert!(Rc::ptr_eq(&self.builder, &other.builder));
        let combine = |a: &FieldVariable, b: &FieldVariable| {
            make_expr(Box::new(a.expr.clone()), Box::new(b.expr.clone()))
        };

        let expr = combine(self, other);
        let limb_max_abs = limb_max_fn(self, other);
        FieldVariable::save_if_overflow(self, expr, limb_max_abs);

        // Check again for `other`, using `self` as it is now (it may just have been saved).
        let expr = combine(self, other);
        let limb_max_abs = limb_max_fn(self, other);
        FieldVariable::save_if_overflow(other, expr, limb_max_abs);

        let limb_max_abs = limb_max_fn(self, other);
        FieldVariable {
            expr: combine(self, other),
            builder: self.builder.clone(),
            limb_max_abs,
            max_overflow_bits: log2_ceil_usize(limb_max_abs),
            expr_limbs: expr_limbs_fn(self, other),
            range_checker_bits: self.range_checker_bits,
        }
    }

    /// Returns a value representing `self + other`.
    ///
    /// Either operand may first be saved in place if the resulting constraint's estimated
    /// carry bits would exceed `range_checker_bits`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` were created from different builders.
    pub fn add(&mut self, other: &mut FieldVariable) -> FieldVariable {
        self.binary_op(
            other,
            SymbolicExpr::Add,
            |a, b| a.limb_max_abs + b.limb_max_abs,
            |a, b| max(a.expr_limbs, b.expr_limbs),
        )
    }

    /// Returns a value representing `self - other`.
    ///
    /// Either operand may first be saved in place if the resulting constraint's estimated
    /// carry bits would exceed `range_checker_bits`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` were created from different builders.
    pub fn sub(&mut self, other: &mut FieldVariable) -> FieldVariable {
        self.binary_op(
            other,
            SymbolicExpr::Sub,
            |a, b| a.limb_max_abs + b.limb_max_abs,
            |a, b| max(a.expr_limbs, b.expr_limbs),
        )
    }

    /// Returns a value representing `self * other`.
    ///
    /// Each product limb is a sum of at most `min(expr_limbs)` limb products, which gives
    /// the bound used here. Either operand may first be saved in place on overflow.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` were created from different builders.
    pub fn mul(&mut self, other: &mut FieldVariable) -> FieldVariable {
        self.binary_op(
            other,
            SymbolicExpr::Mul,
            |a, b| a.limb_max_abs * b.limb_max_abs * min(a.expr_limbs, b.expr_limbs),
            |a, b| a.expr_limbs + b.expr_limbs - 1,
        )
    }

    /// Returns a value representing `self * self`; `self` may first be saved on overflow.
    pub fn square(&mut self) -> FieldVariable {
        let expr = SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone()));
        let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
        FieldVariable::save_if_overflow(self, expr, limb_max_abs);

        // `self` may have just been saved, so recompute the bound from its current state.
        let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
        FieldVariable {
            expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
            builder: self.builder.clone(),
            limb_max_abs,
            max_overflow_bits: log2_ceil_usize(limb_max_abs),
            expr_limbs: self.expr_limbs * 2 - 1,
            range_checker_bits: self.range_checker_bits,
        }
    }

    /// Returns a value representing `self + scalar`; `self` may first be saved on overflow.
    pub fn int_add(&mut self, scalar: isize) -> FieldVariable {
        let expr = SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar);
        let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
        FieldVariable::save_if_overflow(self, expr, limb_max_abs);

        // `self` may have just been saved, so recompute the bound from its current state.
        let limb_max_abs = self.limb_max_abs + scalar.unsigned_abs();
        FieldVariable {
            expr: SymbolicExpr::IntAdd(Box::new(self.expr.clone()), scalar),
            builder: self.builder.clone(),
            limb_max_abs,
            max_overflow_bits: log2_ceil_usize(limb_max_abs),
            expr_limbs: self.expr_limbs,
            range_checker_bits: self.range_checker_bits,
        }
    }

    /// Returns a value representing `self * scalar`; `self` may first be saved on overflow.
    pub fn int_mul(&mut self, scalar: isize) -> FieldVariable {
        let expr = SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar);
        let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
        FieldVariable::save_if_overflow(self, expr, limb_max_abs);

        // `self` may have just been saved, so recompute the bound from its current state.
        let limb_max_abs = self.limb_max_abs * scalar.unsigned_abs();
        FieldVariable {
            expr: SymbolicExpr::IntMul(Box::new(self.expr.clone()), scalar),
            builder: self.builder.clone(),
            limb_max_abs,
            max_overflow_bits: log2_ceil_usize(limb_max_abs),
            expr_limbs: self.expr_limbs,
            range_checker_bits: self.range_checker_bits,
        }
    }

    /// Returns a fresh variable `q` representing `self / other`.
    ///
    /// `q` is constrained by `other * q - self` and computed as `self / other`. Before `q` is
    /// allocated, `self` and then `other` may be saved in place if the estimated carry bits
    /// of that constraint exceed `range_checker_bits`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` were created from different builders.
    pub fn div(&mut self, other: &mut FieldVariable) -> FieldVariable {
        assert!(Rc::ptr_eq(&self.builder, &other.builder));
        let (prime, limb_bits, num_limbs) = {
            let builder = self.builder.borrow();
            (builder.prime.clone(), builder.limb_bits, builder.num_limbs)
        };

        // Builds the division constraint `den * quotient - num`.
        let division_constraint =
            |num: &SymbolicExpr, den: &SymbolicExpr, quotient: &SymbolicExpr| {
                SymbolicExpr::Sub(
                    Box::new(SymbolicExpr::Mul(
                        Box::new(den.clone()),
                        Box::new(quotient.clone()),
                    )),
                    Box::new(num.clone()),
                )
            };
        // Stand-in for the quotient, used only to estimate carry bits before the real
        // variable is allocated.
        let fake_var = SymbolicExpr::Var(0);

        let carry_bits = division_constraint(&self.expr, &other.expr, &fake_var)
            .constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs);
        if carry_bits > self.range_checker_bits {
            self.save();
        }
        // Re-estimate with the (possibly saved) `self` to decide whether `other` needs saving.
        let carry_bits = division_constraint(&self.expr, &other.expr, &fake_var)
            .constraint_carry_bits_with_pq(&prime, limb_bits, num_limbs);
        if carry_bits > self.range_checker_bits {
            other.save();
        }

        let quotient = {
            let mut builder = self.builder.borrow_mut();
            let (new_var_idx, new_var) = builder.new_var();
            let constraint = division_constraint(&self.expr, &other.expr, &new_var);
            builder.set_constraint(new_var_idx, constraint);
            let compute =
                SymbolicExpr::Div(Box::new(self.expr.clone()), Box::new(other.expr.clone()));
            builder.set_compute(new_var_idx, compute);
            new_var
        };
        FieldVariable::from_var(self.builder.clone(), quotient)
    }

    /// Wraps an existing builder variable with canonical limb bounds.
    ///
    /// # Panics
    ///
    /// Panics if `var` is not a `SymbolicExpr::Var`.
    pub fn from_var(builder: Rc<RefCell<ExprBuilder>>, var: SymbolicExpr) -> FieldVariable {
        let borrowed_builder = builder.borrow();
        let range_checker_bits = borrowed_builder.range_checker_bits;
        assert!(
            matches!(var, SymbolicExpr::Var(_)),
            "Expected var to be of type SymbolicExpr::Var"
        );
        let num_limbs = borrowed_builder.num_limbs;
        let canonical_limb_bits = borrowed_builder.limb_bits;
        drop(borrowed_builder);
        FieldVariable {
            expr: var,
            builder,
            limb_max_abs: (1 << canonical_limb_bits) - 1,
            max_overflow_bits: canonical_limb_bits,
            expr_limbs: num_limbs,
            range_checker_bits,
        }
    }

    /// Returns `SymbolicExpr::Select(flag_id, a, b)`, with bounds taken as the worst case of
    /// the two branches (either one may be selected).
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` were created from different builders.
    pub fn select(flag_id: usize, a: &FieldVariable, b: &FieldVariable) -> FieldVariable {
        assert!(Rc::ptr_eq(&a.builder, &b.builder));
        FieldVariable {
            expr: SymbolicExpr::Select(flag_id, Box::new(a.expr.clone()), Box::new(b.expr.clone())),
            builder: a.builder.clone(),
            limb_max_abs: max(a.limb_max_abs, b.limb_max_abs),
            max_overflow_bits: max(a.max_overflow_bits, b.max_overflow_bits),
            expr_limbs: max(a.expr_limbs, b.expr_limbs),
            range_checker_bits: a.range_checker_bits,
        }
    }
}
impl Add<&mut FieldVariable> for &mut FieldVariable {
type Output = FieldVariable;
fn add(self, rhs: &mut FieldVariable) -> Self::Output {
self.add(rhs)
}
}
impl Add<FieldVariable> for FieldVariable {
type Output = FieldVariable;
fn add(mut self, mut rhs: FieldVariable) -> Self::Output {
let x = &mut self;
x.add(&mut rhs)
}
}
impl Sub<FieldVariable> for FieldVariable {
type Output = FieldVariable;
fn sub(mut self, mut rhs: FieldVariable) -> Self::Output {
let x = &mut self;
x.sub(&mut rhs)
}
}
impl Sub<&mut FieldVariable> for &mut FieldVariable {
type Output = FieldVariable;
fn sub(self, rhs: &mut FieldVariable) -> Self::Output {
self.sub(rhs)
}
}
impl Mul<FieldVariable> for FieldVariable {
type Output = FieldVariable;
fn mul(mut self, mut rhs: FieldVariable) -> Self::Output {
let x = &mut self;
x.mul(&mut rhs)
}
}
impl Mul<&mut FieldVariable> for &mut FieldVariable {
    type Output = FieldVariable;

    /// Forwards to the inherent `FieldVariable::mul`; either operand may be saved in place.
    fn mul(self, other: &mut FieldVariable) -> Self::Output {
        let product = FieldVariable::mul(self, other);
        product
    }
}
impl Div<FieldVariable> for FieldVariable {
type Output = FieldVariable;
fn div(mut self, mut rhs: FieldVariable) -> Self::Output {
let x = &mut self;
x.div(&mut rhs)
}
}
impl Div<&mut FieldVariable> for &mut FieldVariable {
    type Output = FieldVariable;

    /// Forwards to the inherent `FieldVariable::div`; either operand may be saved in place.
    fn div(self, other: &mut FieldVariable) -> Self::Output {
        let quotient = FieldVariable::div(self, other);
        quotient
    }
}