// openvm_circuit_primitives/is_equal_array/mod.rs
use itertools::izip;
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::{p3_air::AirBuilder, p3_field::Field};
use crate::{SubAir, TraceSubRowGenerator};
#[cfg(test)]
mod tests;
/// Inputs and output of [`IsEqArraySubAir`].
///
/// `x` and `y` are compared elementwise. In `eval`, `out` is always constrained to be
/// boolean and to be zero whenever some `x[i] != y[i]`. The constraint that forces
/// `out` to one when every pair is equal is gated by `condition` (via `builder.when`).
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct IsEqArrayIo<T, const NUM: usize> {
    /// Left-hand array of the comparison.
    pub x: [T; NUM],
    /// Right-hand array of the comparison.
    pub y: [T; NUM],
    /// Equality indicator: 1 iff `x == y` elementwise, enforced fully only when
    /// `condition` is active.
    pub out: T,
    /// Gate for the "sum equals one" constraint in `eval`.
    /// NOTE(review): presumably boolean — confirm with callers.
    pub condition: T,
}
/// Auxiliary columns used by [`IsEqArraySubAir`] to prove inequality.
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug)]
pub struct IsEqArrayAuxCols<T, const NUM: usize> {
    /// Per-index markers. Trace generation sets the entry at the first index where
    /// `x` and `y` differ to `(x[i] - y[i])^{-1}` and every other entry to zero.
    /// This makes `sum_i (x[i] - y[i]) * diff_inv_marker[i] == 1` whenever the
    /// arrays differ.
    pub diff_inv_marker: [T; NUM],
}
/// Sub-AIR that constrains whether two length-`NUM` arrays are equal elementwise.
///
/// The type carries no data; the array length is fixed by the `NUM` const parameter.
#[derive(Clone, Copy, Debug)]
pub struct IsEqArraySubAir<const NUM: usize>;
impl<AB: AirBuilder, const NUM: usize> SubAir<AB> for IsEqArraySubAir<NUM> {
type AirContext<'a>
= (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM])
where
AB::Expr: 'a,
AB::Var: 'a,
AB: 'a;
fn eval<'a>(
&'a self,
builder: &'a mut AB,
(io, diff_inv_marker): (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM]),
) where
AB::Var: 'a,
AB::Expr: 'a,
{
let mut sum = io.out.clone();
for (x_i, y_i, inv_marker_i) in izip!(io.x, io.y, diff_inv_marker) {
sum += (x_i.clone() - y_i.clone()) * inv_marker_i;
builder.assert_zero(io.out.clone() * (x_i - y_i));
}
builder.when(io.condition).assert_one(sum);
builder.assert_bool(io.out);
}
}
impl<F: Field, const NUM: usize> TraceSubRowGenerator<F> for IsEqArraySubAir<NUM> {
    /// The two arrays being compared.
    type TraceContext<'a> = (&'a [F; NUM], &'a [F; NUM]);
    /// The marker columns and the `out` cell to be filled in.
    type ColsMut<'a> = (&'a mut [F; NUM], &'a mut F);

    /// Fills `diff_inv_marker` and `out` for the pair `(x, y)`.
    ///
    /// Every marker entry is overwritten. The entry at the first index where `x` and `y`
    /// differ receives `(x[i] - y[i])^{-1}`, and all others receive zero. `out` is set
    /// to one when no such index exists, and to zero otherwise.
    #[inline(always)]
    fn generate_subrow<'a>(
        &'a self,
        (x, y): (&'a [F; NUM], &'a [F; NUM]),
        (diff_inv_marker, out): (&'a mut [F; NUM], &'a mut F),
    ) {
        diff_inv_marker.fill(F::ZERO);

        let first_mismatch = x.iter().zip(y.iter()).position(|(a, b)| a != b);
        if let Some(idx) = first_mismatch {
            // The difference is nonzero at a mismatch, so the inverse is well-defined.
            diff_inv_marker[idx] = (x[idx] - y[idx]).inverse();
        }

        *out = F::from_bool(first_mismatch.is_none());
    }
}