// openvm_circuit_primitives/is_equal_array/mod.rs

1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{p3_air::AirBuilder, p3_field::Field};
4
5use crate::{SubAir, TraceSubRowGenerator};
6
7#[cfg(test)]
8mod tests;
9
/// Input/output wiring for [`IsEqArraySubAir`]: two arrays `x` and `y`, the claimed equality
/// bit `out`, and an enable flag `condition`.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct IsEqArrayIo<T, const NUM: usize> {
    /// Left-hand array being compared.
    pub x: [T; NUM],
    /// Right-hand array being compared.
    pub y: [T; NUM],
    /// The boolean output, constrained to equal `(x == y)` when `condition != 0`.
    ///
    /// Independently of `condition`, `out` is always constrained to be boolean, and
    /// `out == 1` always forces `x[i] == y[i]` for every `i`.
    pub out: T,
    /// Enables the full `out == (x == y)` check.
    ///
    /// When `condition == 0`, only the check that forces `out == 1` for equal arrays is
    /// skipped; the unconditional constraints on `out` described above still apply. In
    /// particular, `out = 0` (for example, setting all trace values to zero) always passes.
    pub condition: T,
}
21
/// Auxiliary column layout holding the `diff_inv_marker` values consumed by
/// [`IsEqArraySubAir`].
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug)]
pub struct IsEqArrayAuxCols<T, const NUM: usize> {
    // As written by trace generation, `diff_inv_marker` is filled with 0 except at the lowest
    // index i such that `x[i] != y[i]`. If such an `i` exists, `diff_inv_marker[i]` is the
    // inverse of `x[i] - y[i]`.
    // Note: the AIR itself only checks `out + sum_i (x[i] - y[i]) * diff_inv_marker[i] == 1`
    // when `condition != 0`, so other marker assignments can also satisfy the constraints.
    pub diff_inv_marker: [T; NUM],
}
29
/// Sub-AIR constraining `out == (x == y)` for two length-`NUM` arrays, gated by `condition`.
///
/// Carries no runtime state; the array length `NUM` is its only parameter.
#[derive(Clone, Copy, Debug)]
pub struct IsEqArraySubAir<const NUM: usize>;
32
33impl<AB: AirBuilder, const NUM: usize> SubAir<AB> for IsEqArraySubAir<NUM> {
34    /// `(io, diff_inv_marker)`
35    type AirContext<'a>
36        = (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM])
37    where
38        AB::Expr: 'a,
39        AB::Var: 'a,
40        AB: 'a;
41
42    /// Constrain that out == (x == y) when condition != 0
43    fn eval<'a>(
44        &'a self,
45        builder: &'a mut AB,
46        (io, diff_inv_marker): (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM]),
47    ) where
48        AB::Var: 'a,
49        AB::Expr: 'a,
50    {
51        let mut sum = io.out.clone();
52        // If x == y: then sum == 1 implies out = 1.
53        // If x != y: then out * (x[i] - y[i]) == 0 implies out = 0.
54        //            to get the sum == 1 to be satisfied, we set diff_inv_marker[i] = (x[i] - y[i])^{-1} at the first index i such that x[i] != y[i].
55        for (x_i, y_i, inv_marker_i) in izip!(io.x, io.y, diff_inv_marker) {
56            sum += (x_i.clone() - y_i.clone()) * inv_marker_i;
57            builder.assert_zero(io.out.clone() * (x_i - y_i));
58        }
59        builder.when(io.condition).assert_one(sum);
60        builder.assert_bool(io.out);
61    }
62}
63
64impl<F: Field, const NUM: usize> TraceSubRowGenerator<F> for IsEqArraySubAir<NUM> {
65    /// (x, y)
66    type TraceContext<'a> = (&'a [F; NUM], &'a [F; NUM]);
67    /// (diff_inv_marker, out)
68    type ColsMut<'a> = (&'a mut [F; NUM], &'a mut F);
69
70    #[inline(always)]
71    fn generate_subrow<'a>(
72        &'a self,
73        (x, y): (&'a [F; NUM], &'a [F; NUM]),
74        (diff_inv_marker, out): (&'a mut [F; NUM], &'a mut F),
75    ) {
76        let mut is_eq = true;
77        for (x_i, y_i, diff_inv_marker_i) in izip!(x, y, diff_inv_marker) {
78            if x_i != y_i && is_eq {
79                is_eq = false;
80                *diff_inv_marker_i = (*x_i - *y_i).inverse();
81            } else {
82                *diff_inv_marker_i = F::ZERO;
83            }
84        }
85        *out = F::from_bool(is_eq);
86    }
87}