openvm_circuit_primitives/is_equal_array/
mod.rs1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{p3_air::AirBuilder, p3_field::Field};
4
5use crate::{SubAir, TraceSubRowGenerator};
6
7#[cfg(test)]
8mod tests;
9
/// Input/output values for [`IsEqArraySubAir`]: the two arrays being compared, the equality
/// flag, and the enabling condition.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct IsEqArrayIo<T, const NUM: usize> {
    /// Left-hand array being compared.
    pub x: [T; NUM],
    /// Right-hand array being compared.
    pub y: [T; NUM],
    /// Boolean output. It is always constrained to be 0 or 1, and it is constrained to equal
    /// `x == y` when `condition` is nonzero.
    pub out: T,
    /// Enables the equality check. When it is zero, the sum constraint is switched off.
    /// An all-zero assignment of the IO and auxiliary values then satisfies the remaining
    /// constraints.
    pub condition: T,
}
21
22#[repr(C)]
23#[derive(AlignedBorrow, Clone, Copy, Debug)]
24pub struct IsEqArrayAuxCols<T, const NUM: usize> {
25 pub diff_inv_marker: [T; NUM],
28}
29
/// Sub-AIR constraining `out == (x == y)` for two arrays of `NUM` elements each.
///
/// It is stateless: all inputs come from the eval / trace-generation contexts.
#[derive(Clone, Copy, Debug)]
pub struct IsEqArraySubAir<const NUM: usize>;
32
33impl<AB: AirBuilder, const NUM: usize> SubAir<AB> for IsEqArraySubAir<NUM> {
34 type AirContext<'a>
36 = (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM])
37 where
38 AB::Expr: 'a,
39 AB::Var: 'a,
40 AB: 'a;
41
42 fn eval<'a>(
44 &'a self,
45 builder: &'a mut AB,
46 (io, diff_inv_marker): (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM]),
47 ) where
48 AB::Var: 'a,
49 AB::Expr: 'a,
50 {
51 let mut sum = io.out.clone();
52 for (x_i, y_i, inv_marker_i) in izip!(io.x, io.y, diff_inv_marker) {
57 sum += (x_i.clone() - y_i.clone()) * inv_marker_i;
58 builder.assert_zero(io.out.clone() * (x_i - y_i));
59 }
60 builder.when(io.condition).assert_one(sum);
61 builder.assert_bool(io.out);
62 }
63}
64
65impl<F: Field, const NUM: usize> TraceSubRowGenerator<F> for IsEqArraySubAir<NUM> {
66 type TraceContext<'a> = (&'a [F; NUM], &'a [F; NUM]);
68 type ColsMut<'a> = (&'a mut [F; NUM], &'a mut F);
70
71 #[inline(always)]
72 fn generate_subrow<'a>(
73 &'a self,
74 (x, y): (&'a [F; NUM], &'a [F; NUM]),
75 (diff_inv_marker, out): (&'a mut [F; NUM], &'a mut F),
76 ) {
77 let mut is_eq = true;
78 for (x_i, y_i, diff_inv_marker_i) in izip!(x, y, diff_inv_marker) {
79 if x_i != y_i && is_eq {
80 is_eq = false;
81 *diff_inv_marker_i = (*x_i - *y_i).inverse();
82 } else {
83 *diff_inv_marker_i = F::ZERO;
84 }
85 }
86 *out = F::from_bool(is_eq);
87 }
88}