// openvm_circuit_primitives/is_less_than_array/mod.rs

1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{
4    interaction::InteractionBuilder,
5    p3_air::AirBuilder,
6    p3_field::{FieldAlgebra, PrimeField32},
7};
8
9use crate::{
10    is_less_than::{IsLtSubAir, LessThanAuxCols},
11    utils::not,
12    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
13    SubAir, TraceSubRowGenerator,
14};
15
16#[cfg(test)]
17pub mod tests;
18
/// Inputs and output of [IsLtArraySubAir].
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct IsLtArrayIo<T, const NUM: usize> {
    /// Left-hand array of the comparison `x < y`.
    pub x: [T; NUM],
    /// Right-hand array of the comparison `x < y`.
    pub y: [T; NUM],
    /// The boolean output, constrained to equal `x < y` when `count != 0`. The less-than
    /// comparison is done lexicographically.
    pub out: T,
    /// Constraints only hold when `count != 0`. When `count == 0`, setting all trace values
    /// to zero still passes the constraints. (The booleanity of `diff_marker` and the
    /// `diff_inv` check at a marked index are enforced regardless of `count`, which an all-zero
    /// row satisfies.)
    /// `count` is **assumed** to be boolean and must be constrained as such by the caller.
    pub count: T,
}
32
33#[repr(C)]
34#[derive(AlignedBorrow, Clone, Copy, Debug)]
35pub struct IsLtArrayAuxCols<T, const NUM: usize, const AUX_LEN: usize> {
36    // `diff_marker` is filled with 0 except at the lowest index i such that
37    // `x[i] != y[i]`. If such an `i` exists then it is constrained that `diff_inv = inv(y[i] -
38    // x[i])`.
39    pub diff_marker: [T; NUM],
40    pub diff_inv: T,
41    pub lt_aux: LessThanAuxCols<T, AUX_LEN>,
42}
43
/// Read-only view of [IsLtArrayAuxCols] with the const-generic array lengths erased.
///
/// Each column is exposed as a slice (or a single reference), which is the form consumed by the
/// `SubAir::eval` implementations.
#[derive(Debug, Clone, Copy)]
pub struct IsLtArrayAuxColsRef<'a, T> {
    /// One marker per array element (see [IsLtArrayAuxCols]).
    pub diff_marker: &'a [T],
    /// Inverse of `y[i] - x[i]` at the marked index `i`.
    pub diff_inv: &'a T,
    /// Limb decomposition used by the inner less-than check.
    pub lt_decomp: &'a [T],
}
50
/// Mutable view of [IsLtArrayAuxCols] with the const-generic array lengths erased.
///
/// Used as the output columns written by trace generation.
#[derive(Debug)]
pub struct IsLtArrayAuxColsMut<'a, T> {
    /// Per-element markers, written by trace generation.
    pub diff_marker: &'a mut [T],
    /// Slot for the inverse of the selected difference.
    pub diff_inv: &'a mut T,
    /// Slot for the limb decomposition of the inner less-than check.
    pub lt_decomp: &'a mut [T],
}
57
58impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a IsLtArrayAuxCols<T, NUM, AUX_LEN>>
59    for IsLtArrayAuxColsRef<'a, T>
60{
61    fn from(value: &'a IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
62        Self {
63            diff_marker: &value.diff_marker,
64            diff_inv: &value.diff_inv,
65            lt_decomp: &value.lt_aux.lower_decomp,
66        }
67    }
68}
69
70impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>>
71    for IsLtArrayAuxColsMut<'a, T>
72{
73    fn from(value: &'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
74        Self {
75            diff_marker: &mut value.diff_marker,
76            diff_inv: &mut value.diff_inv,
77            lt_decomp: &mut value.lt_aux.lower_decomp,
78        }
79    }
80}
81
82/// This SubAir constrains the boolean equal to 1 iff `x < y` (lexicographic comparison) assuming
83/// that all elements of both arrays `x, y` each have at most `max_bits` bits.
84///
85/// The constraints will constrain a selector for the first index where `x[i] != y[i]` and then
86/// use [IsLtSubAir] on `x[i], y[i]`.
87///
88/// The expected max constraint degree of `eval` is
89///     deg(count) + max(1, deg(x), deg(y))
90#[derive(Copy, Clone, Debug)]
91pub struct IsLtArraySubAir<const NUM: usize> {
92    pub lt: IsLtSubAir,
93}
94
95impl<const NUM: usize> IsLtArraySubAir<NUM> {
96    pub fn new(bus: VariableRangeCheckerBus, max_bits: usize) -> Self {
97        Self {
98            lt: IsLtSubAir::new(bus, max_bits),
99        }
100    }
101
102    pub fn when_transition(self) -> IsLtArrayWhenTransitionAir<NUM> {
103        IsLtArrayWhenTransitionAir(self)
104    }
105
106    pub fn max_bits(&self) -> usize {
107        self.lt.max_bits
108    }
109
110    pub fn range_max_bits(&self) -> usize {
111        self.lt.range_max_bits()
112    }
113
114    /// Constrain that `out` is boolean equal to `x < y` (lexicographic comparison)
115    /// **without** doing range checks on `lt_decomp`.
116    fn eval_without_range_checks<AB: AirBuilder>(
117        &self,
118        builder: &mut AB,
119        io: IsLtArrayIo<AB::Expr, NUM>,
120        diff_marker: &[AB::Var],
121        diff_inv: AB::Var,
122        lt_decomp: &[AB::Var],
123    ) {
124        assert_eq!(diff_marker.len(), NUM);
125        let mut prefix_sum = AB::Expr::ZERO;
126        let mut diff_val = AB::Expr::ZERO;
127        for (x, y, &marker) in izip!(io.x, io.y, diff_marker) {
128            let diff = y - x;
129            diff_val += diff.clone() * marker.into();
130            prefix_sum += marker.into();
131            builder.assert_bool(marker);
132            builder
133                .when(io.count.clone())
134                .assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
135            builder.when(marker).assert_one(diff * diff_inv);
136        }
137        builder.assert_bool(prefix_sum.clone());
138        // When condition != 0,
139        // - If `x != y`, then `prefix_sum = 1` so marker[i] must be nonzero iff i is the first
140        //   index where `x[i] != y[i]`. Constrains that `diff_inv * (y[i] - x[i]) = 1` (`diff_val`
141        //   is non-zero).
142        // - If `x == y`, then `prefix_sum = 0` and `out == 0` (below)
143        //     - `prefix_sum` cannot be 1 because all diff are zero and it would be impossible to
144        //       find an inverse.
145
146        builder
147            .when(io.count.clone())
148            .when(not::<AB::Expr>(prefix_sum))
149            .assert_zero(io.out.clone());
150
151        self.lt
152            .eval_without_range_checks(builder, diff_val, io.out, io.count, lt_decomp);
153    }
154}
155
156impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArraySubAir<NUM> {
157    type AirContext<'a>
158        = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
159    where
160        AB::Expr: 'a,
161        AB::Var: 'a,
162        AB: 'a;
163
164    fn eval<'a>(
165        &'a self,
166        builder: &'a mut AB,
167        (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
168    ) where
169        AB::Var: 'a,
170        AB::Expr: 'a,
171    {
172        self.lt
173            .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
174        self.eval_without_range_checks(builder, io, aux.diff_marker, *aux.diff_inv, aux.lt_decomp);
175    }
176}
177
178/// The same subair as [IsLtArraySubAir] except that non-range check
179/// constraints are not imposed on the last row.
180/// Intended use case is for asserting less than between entries in adjacent rows.
181#[derive(Copy, Clone, Debug)]
182pub struct IsLtArrayWhenTransitionAir<const NUM: usize>(pub IsLtArraySubAir<NUM>);
183
184impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArrayWhenTransitionAir<NUM> {
185    type AirContext<'a>
186        = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
187    where
188        AB::Expr: 'a,
189        AB::Var: 'a,
190        AB: 'a;
191
192    fn eval<'a>(
193        &'a self,
194        builder: &'a mut AB,
195        (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
196    ) where
197        AB::Var: 'a,
198        AB::Expr: 'a,
199    {
200        self.0
201            .lt
202            .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
203        self.0.eval_without_range_checks(
204            &mut builder.when_transition(),
205            io,
206            aux.diff_marker,
207            *aux.diff_inv,
208            aux.lt_decomp,
209        );
210    }
211}
212
213impl<F: PrimeField32, const NUM: usize> TraceSubRowGenerator<F> for IsLtArraySubAir<NUM> {
214    /// `(range_checker, x, y)`
215    type TraceContext<'a> = (&'a VariableRangeCheckerChip, &'a [F], &'a [F]);
216    /// `(aux, out)`
217    type ColsMut<'a> = (IsLtArrayAuxColsMut<'a, F>, &'a mut F);
218
219    /// Only use this when `count != 0`.
220    #[inline(always)]
221    fn generate_subrow<'a>(
222        &'a self,
223        (range_checker, x, y): (&'a VariableRangeCheckerChip, &'a [F], &'a [F]),
224        (aux, out): (IsLtArrayAuxColsMut<'a, F>, &'a mut F),
225    ) {
226        tracing::trace!("IsLtArraySubAir::generate_subrow x={:?}, y={:?}", x, y);
227        let mut is_eq = true;
228        let mut diff_val = F::ZERO;
229        *aux.diff_inv = F::ZERO;
230        for (x_i, y_i, diff_marker) in izip!(x, y, aux.diff_marker.iter_mut()) {
231            if x_i != y_i && is_eq {
232                is_eq = false;
233                *diff_marker = F::ONE;
234                diff_val = *y_i - *x_i;
235                *aux.diff_inv = diff_val.inverse();
236            } else {
237                *diff_marker = F::ZERO;
238            }
239        }
240        // diff_val can be "negative" but shifted_diff is in [0, 2^{max_bits+1})
241        let shifted_diff =
242            (diff_val + F::from_canonical_u32((1 << self.max_bits()) - 1)).as_canonical_u32();
243        let lower_u32 = shifted_diff & ((1 << self.max_bits()) - 1);
244        *out = F::from_bool(shifted_diff != lower_u32);
245
246        // decompose lower_u32 into limbs and range check
247        range_checker.decompose(lower_u32, self.max_bits(), aux.lt_decomp);
248    }
249}