// openvm_circuit_primitives/is_equal_array/mod.rs

1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{p3_air::AirBuilder, p3_field::Field};
4
5use crate::{SubAir, TraceSubRowGenerator};
6
7#[cfg(test)]
8mod tests;
9
/// Input/output expressions for [`IsEqArraySubAir`].
///
/// When `condition != 0`, the sub-AIR forces `out` to be `1` if `x` and `y` are equal
/// element-wise and `0` otherwise.
///
/// NOTE: this struct is `#[repr(C)]`; the field order defines the memory layout and must not be
/// changed.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct IsEqArrayIo<T, const NUM: usize> {
    /// Left-hand array being compared.
    pub x: [T; NUM],
    /// Right-hand array being compared.
    pub y: [T; NUM],
    /// The boolean output, constrained to equal (x == y) when `condition != 0`.
    ///
    /// Independently of `condition`, `out` is always constrained to be boolean and
    /// `out * (x[i] - y[i]) == 0` is enforced for every `i`, so `out == 1` always implies
    /// `x == y`.
    pub out: T,
    /// Gates only the constraint that forces `out == 1` when `x == y`.
    ///
    /// When `condition == 0`, `out` may be `0` even if `x == y`. The booleanity of `out` and
    /// `out * (x[i] - y[i]) == 0` are still enforced. Setting all trace values to zero satisfies
    /// every constraint.
    pub condition: T,
}
21
22#[repr(C)]
23#[derive(AlignedBorrow, Clone, Copy, Debug)]
24pub struct IsEqArrayAuxCols<T, const NUM: usize> {
25    // `diff_inv_marker` is filled with 0 except at the lowest index i such that
26    // `x[i] != y[i]`. If such an `i` exists `diff_inv_marker[i]` is the inverse of `x[i] - y[i]`.
27    pub diff_inv_marker: [T; NUM],
28}
29
/// Stateless sub-AIR that, whenever `condition != 0`, constrains `out` to equal `x == y`
/// (element-wise over `NUM` entries).
///
/// All inputs are supplied at evaluation / trace-generation time; the struct itself holds no data.
#[derive(Debug, Clone, Copy)]
pub struct IsEqArraySubAir<const NUM: usize>;
32
33impl<AB: AirBuilder, const NUM: usize> SubAir<AB> for IsEqArraySubAir<NUM> {
34    /// `(io, diff_inv_marker)`
35    type AirContext<'a>
36        = (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM])
37    where
38        AB::Expr: 'a,
39        AB::Var: 'a,
40        AB: 'a;
41
42    /// Constrain that out == (x == y) when condition != 0
43    fn eval<'a>(
44        &'a self,
45        builder: &'a mut AB,
46        (io, diff_inv_marker): (IsEqArrayIo<AB::Expr, NUM>, [AB::Var; NUM]),
47    ) where
48        AB::Var: 'a,
49        AB::Expr: 'a,
50    {
51        let mut sum = io.out.clone();
52        // If x == y: then sum == 1 implies out = 1.
53        // If x != y: then out * (x[i] - y[i]) == 0 implies out = 0.
54        //            to get the sum == 1 to be satisfied, we set diff_inv_marker[i] = (x[i] -
55        // y[i])^{-1} at the first index i such that x[i] != y[i].
56        for (x_i, y_i, inv_marker_i) in izip!(io.x, io.y, diff_inv_marker) {
57            sum += (x_i.clone() - y_i.clone()) * inv_marker_i;
58            builder.assert_zero(io.out.clone() * (x_i - y_i));
59        }
60        builder.when(io.condition).assert_one(sum);
61        builder.assert_bool(io.out);
62    }
63}
64
65impl<F: Field, const NUM: usize> TraceSubRowGenerator<F> for IsEqArraySubAir<NUM> {
66    /// (x, y)
67    type TraceContext<'a> = (&'a [F; NUM], &'a [F; NUM]);
68    /// (diff_inv_marker, out)
69    type ColsMut<'a> = (&'a mut [F; NUM], &'a mut F);
70
71    #[inline(always)]
72    fn generate_subrow<'a>(
73        &'a self,
74        (x, y): (&'a [F; NUM], &'a [F; NUM]),
75        (diff_inv_marker, out): (&'a mut [F; NUM], &'a mut F),
76    ) {
77        let mut is_eq = true;
78        for (x_i, y_i, diff_inv_marker_i) in izip!(x, y, diff_inv_marker) {
79            if x_i != y_i && is_eq {
80                is_eq = false;
81                *diff_inv_marker_i = (*x_i - *y_i).inverse();
82            } else {
83                *diff_inv_marker_i = F::ZERO;
84            }
85        }
86        *out = F::from_bool(is_eq);
87    }
88}