openvm_circuit_primitives/is_equal_array/
mod.rs
1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{p3_air::AirBuilder, p3_field::Field};
4
5use crate::{SubAir, TraceSubRowGenerator};
6
7#[cfg(test)]
8mod tests;
9
/// Input/output expressions for [`IsEqArraySubAir`].
///
/// The sub-AIR compares the arrays `x` and `y` element-wise and reports the result in `out`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct IsEqArrayIo<T, const NUM: usize> {
    /// Left-hand array of the comparison.
    pub x: [T; NUM],
    /// Right-hand array, compared position-by-position against `x`.
    pub y: [T; NUM],
    /// Boolean result: 1 when `x[i] == y[i]` for every `i`, 0 otherwise.
    pub out: T,
    /// Selector for the constraint that pins `out` to the correct value.
    ///
    /// When `condition` is zero, `IsEqArraySubAir` still enforces that `out` is boolean and
    /// that `out * (x[i] - y[i]) == 0`, but no longer forces `out == 1` when the arrays are
    /// equal. An all-zero assignment therefore satisfies the constraints in that case.
    pub condition: T,
}
21
22#[repr(C)]
23#[derive(AlignedBorrow, Clone, Copy, Debug)]
24pub struct IsEqArrayAuxCols<T, const NUM: usize> {
25 pub diff_inv_marker: [T; NUM],
28}
29
/// Sub-AIR checking whether two arrays of `NUM` elements are equal.
///
/// When `condition` is non-zero, it constrains `out` to be 1 if `x == y` element-wise and 0
/// otherwise. It uses one inverse-marker witness column per element.
#[derive(Debug, Clone, Copy)]
pub struct IsEqArraySubAir<const NUM: usize>;
32
33impl<AB: AirBuilder, const NUM: usize> SubAir<AB> for IsEqArraySubAir<NUM> {
34 type AirContext<'a>
36 = (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM])
37 where
38 AB::Expr: 'a,
39 AB::Var: 'a,
40 AB: 'a;
41
42 fn eval<'a>(
44 &'a self,
45 builder: &'a mut AB,
46 (io, diff_inv_marker): (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM]),
47 ) where
48 AB::Var: 'a,
49 AB::Expr: 'a,
50 {
51 let mut sum = io.out.clone();
52 for (x_i, y_i, inv_marker_i) in izip!(io.x, io.y, diff_inv_marker) {
56 sum += (x_i.clone() - y_i.clone()) * inv_marker_i;
57 builder.assert_zero(io.out.clone() * (x_i - y_i));
58 }
59 builder.when(io.condition).assert_one(sum);
60 builder.assert_bool(io.out);
61 }
62}
63
64impl<F: Field, const NUM: usize> TraceSubRowGenerator<F> for IsEqArraySubAir<NUM> {
65 type TraceContext<'a> = (&'a [F; NUM], &'a [F; NUM]);
67 type ColsMut<'a> = (&'a mut [F; NUM], &'a mut F);
69
70 #[inline(always)]
71 fn generate_subrow<'a>(
72 &'a self,
73 (x, y): (&'a [F; NUM], &'a [F; NUM]),
74 (diff_inv_marker, out): (&'a mut [F; NUM], &'a mut F),
75 ) {
76 let mut is_eq = true;
77 for (x_i, y_i, diff_inv_marker_i) in izip!(x, y, diff_inv_marker) {
78 if x_i != y_i && is_eq {
79 is_eq = false;
80 *diff_inv_marker_i = (*x_i - *y_i).inverse();
81 } else {
82 *diff_inv_marker_i = F::ZERO;
83 }
84 }
85 *out = F::from_bool(is_eq);
86 }
87}