use alloc::rc::Rc;
use core::cmp;
use core::fmt::Debug;
use core::iter::{Product, Sum};
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use p3_field::{Field, FieldAlgebra};
use crate::symbolic_variable::SymbolicVariable;
/// A symbolic expression tree over values of type `F`.
///
/// Leaves are variables, row-selector markers, or constants; interior nodes
/// are `Add`/`Sub`/`Neg`/`Mul` operations holding their operands behind `Rc`.
/// Every operation node also stores a `degree_multiple` value next to its
/// operands, so it can be read back without walking the subtree.
#[derive(Clone, Debug)]
pub enum SymbolicExpression<F> {
/// Leaf wrapping a [`SymbolicVariable`]; its degree multiple is whatever the
/// variable itself reports.
Variable(SymbolicVariable<F>),
/// Leaf marker named for the first-row selector.
// NOTE(review): presumably nonzero only on the first trace row — confirm
// against the code that evaluates this enum.
IsFirstRow,
/// Leaf marker named for the last-row selector.
// NOTE(review): presumably nonzero only on the last trace row — confirm
// against the code that evaluates this enum.
IsLastRow,
/// Leaf marker named for the transition selector.
// NOTE(review): presumably nonzero on every row except the last — confirm
// against the code that evaluates this enum.
IsTransition,
/// Leaf holding a literal value of type `F`.
Constant(F),
/// Sum of two sub-expressions, `x + y`.
Add {
/// Left operand.
x: Rc<Self>,
/// Right operand.
y: Rc<Self>,
/// Degree multiple stored for this node.
degree_multiple: usize,
},
/// Difference of two sub-expressions, `x - y`.
Sub {
/// Left operand (minuend).
x: Rc<Self>,
/// Right operand (subtrahend).
y: Rc<Self>,
/// Degree multiple stored for this node.
degree_multiple: usize,
},
/// Negation of a sub-expression, `-x`.
Neg {
/// The negated operand.
x: Rc<Self>,
/// Degree multiple stored for this node.
degree_multiple: usize,
},
/// Product of two sub-expressions, `x * y`.
Mul {
/// Left operand.
x: Rc<Self>,
/// Right operand.
y: Rc<Self>,
/// Degree multiple stored for this node.
degree_multiple: usize,
},
}
impl<F> SymbolicExpression<F> {
pub const fn degree_multiple(&self) -> usize {
match self {
SymbolicExpression::Variable(v) => v.degree_multiple(),
SymbolicExpression::IsFirstRow => 1,
SymbolicExpression::IsLastRow => 1,
SymbolicExpression::IsTransition => 0,
SymbolicExpression::Constant(_) => 0,
SymbolicExpression::Add {
degree_multiple, ..
} => *degree_multiple,
SymbolicExpression::Sub {
degree_multiple, ..
} => *degree_multiple,
SymbolicExpression::Neg {
degree_multiple, ..
} => *degree_multiple,
SymbolicExpression::Mul {
degree_multiple, ..
} => *degree_multiple,
}
}
}
impl<F: Field> Default for SymbolicExpression<F> {
    /// The default expression is the constant zero of `F`.
    fn default() -> Self {
        // `Self::ZERO` is defined as `Self::Constant(F::ZERO)`.
        Self::ZERO
    }
}
impl<F: Field> From<F> for SymbolicExpression<F> {
fn from(value: F) -> Self {
Self::Constant(value)
}
}
/// `FieldAlgebra` implementation in which every constant and every
/// conversion yields a `Constant` leaf holding the corresponding `F` value.
impl<F: Field> FieldAlgebra for SymbolicExpression<F> {
    type F = F;

    const ZERO: Self = Self::Constant(F::ZERO);
    const ONE: Self = Self::Constant(F::ONE);
    const TWO: Self = Self::Constant(F::TWO);
    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);

    /// Wraps `f` in a `Constant` leaf.
    #[inline]
    fn from_f(f: Self::F) -> Self {
        Self::Constant(f)
    }

    // Each constructor below converts its argument with the same-named
    // constructor on `F`, then lifts the result through `From<F>`, which
    // produces a `Constant` leaf.

    fn from_bool(b: bool) -> Self {
        F::from_bool(b).into()
    }

    fn from_canonical_u8(n: u8) -> Self {
        F::from_canonical_u8(n).into()
    }

    fn from_canonical_u16(n: u16) -> Self {
        F::from_canonical_u16(n).into()
    }

    fn from_canonical_u32(n: u32) -> Self {
        F::from_canonical_u32(n).into()
    }

    fn from_canonical_u64(n: u64) -> Self {
        F::from_canonical_u64(n).into()
    }

    fn from_canonical_usize(n: usize) -> Self {
        F::from_canonical_usize(n).into()
    }

    fn from_wrapped_u32(n: u32) -> Self {
        F::from_wrapped_u32(n).into()
    }

    fn from_wrapped_u64(n: u64) -> Self {
        F::from_wrapped_u64(n).into()
    }
}
impl<F: Field, T> Add<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    type Output = Self;

    /// Builds the expression `self + rhs`.
    ///
    /// Two constants are folded into a single `Constant`. Any other pair
    /// becomes an `Add` node whose degree multiple is the larger of the two
    /// operands' degree multiples.
    fn add(self, rhs: T) -> Self {
        let rhs: Self = rhs.into();
        match (self, rhs) {
            (Self::Constant(a), Self::Constant(b)) => Self::Constant(a + b),
            (left, right) => Self::Add {
                // Read the degrees first: the operands are moved into `Rc`s below.
                degree_multiple: cmp::max(left.degree_multiple(), right.degree_multiple()),
                x: Rc::new(left),
                y: Rc::new(right),
            },
        }
    }
}
impl<F: Field, T> AddAssign<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    /// Replaces `self` with the expression `self + rhs`.
    ///
    /// The current value is moved out of `self` instead of being cloned;
    /// `self` briefly holds the `Default` value (constant zero) until the
    /// result is written back.
    fn add_assign(&mut self, rhs: T) {
        let lhs = core::mem::take(self);
        *self = lhs + rhs;
    }
}
impl<F: Field, T> Sum<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    /// Adds up every item of `iter`, left to right.
    ///
    /// An empty iterator yields `Self::ZERO`. Otherwise the first item seeds
    /// the accumulator, so no extra zero term is added to the expression.
    fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
        let mut terms = iter.map(Into::<Self>::into);
        match terms.next() {
            Some(first) => terms.fold(first, |acc, term| acc + term),
            None => Self::ZERO,
        }
    }
}
impl<F: Field, T> Sub<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    type Output = Self;

    /// Builds the expression `self - rhs`.
    ///
    /// Two constants are folded into a single `Constant`. Any other pair
    /// becomes a `Sub` node whose degree multiple is the larger of the two
    /// operands' degree multiples.
    fn sub(self, rhs: T) -> Self {
        let subtrahend: Self = rhs.into();
        match (self, subtrahend) {
            (Self::Constant(a), Self::Constant(b)) => Self::Constant(a - b),
            (minuend, subtrahend) => {
                let degree_multiple =
                    cmp::max(minuend.degree_multiple(), subtrahend.degree_multiple());
                let x = Rc::new(minuend);
                let y = Rc::new(subtrahend);
                Self::Sub {
                    x,
                    y,
                    degree_multiple,
                }
            }
        }
    }
}
impl<F: Field, T> SubAssign<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    /// Replaces `self` with the expression `self - rhs`.
    ///
    /// The current value is moved out of `self` instead of being cloned;
    /// `self` briefly holds the `Default` value (constant zero) until the
    /// result is written back.
    fn sub_assign(&mut self, rhs: T) {
        let lhs = core::mem::take(self);
        *self = lhs - rhs;
    }
}
impl<F: Field> Neg for SymbolicExpression<F> {
    type Output = Self;

    /// Builds the expression `-self`.
    ///
    /// A constant is negated directly. Anything else is wrapped in a `Neg`
    /// node that carries the operand's degree multiple unchanged.
    fn neg(self) -> Self {
        if let Self::Constant(value) = self {
            return Self::Constant(-value);
        }
        // Read the degree before `self` is moved into the `Rc`.
        let degree_multiple = self.degree_multiple();
        Self::Neg {
            x: Rc::new(self),
            degree_multiple,
        }
    }
}
impl<F: Field, T> Mul<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    type Output = Self;

    /// Builds the expression `self * rhs`.
    ///
    /// Two constants are folded into a single `Constant`. Any other pair
    /// becomes a `Mul` node whose degree multiple is the sum of the two
    /// operands' degree multiples.
    fn mul(self, rhs: T) -> Self {
        let rhs: Self = rhs.into();
        match (self, rhs) {
            (Self::Constant(a), Self::Constant(b)) => Self::Constant(a * b),
            (left, right) => {
                // The `+` inside a `Mul` impl is intentional: the stored
                // degree multiple of a product is the sum of its factors'.
                #[allow(clippy::suspicious_arithmetic_impl)]
                let degree_multiple = left.degree_multiple() + right.degree_multiple();
                let (x, y) = (Rc::new(left), Rc::new(right));
                Self::Mul {
                    x,
                    y,
                    degree_multiple,
                }
            }
        }
    }
}
impl<F: Field, T> MulAssign<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    /// Replaces `self` with the expression `self * rhs`.
    ///
    /// The current value is moved out of `self` instead of being cloned;
    /// `self` briefly holds the `Default` value (constant zero) until the
    /// result is written back.
    fn mul_assign(&mut self, rhs: T) {
        let lhs = core::mem::take(self);
        *self = lhs * rhs;
    }
}
impl<F: Field, T> Product<T> for SymbolicExpression<F>
where
    T: Into<Self>,
{
    /// Multiplies every item of `iter` together, left to right.
    ///
    /// An empty iterator yields `Self::ONE`. Otherwise the first item seeds
    /// the accumulator, so no extra unit factor is added to the expression.
    fn product<I: Iterator<Item = T>>(iter: I) -> Self {
        let mut factors = iter.map(Into::<Self>::into);
        match factors.next() {
            Some(first) => factors.fold(first, |acc, factor| acc * factor),
            None => Self::ONE,
        }
    }
}