// openvm_circuit_primitives/is_less_than_array/mod.rs

1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{
4    interaction::InteractionBuilder,
5    p3_air::AirBuilder,
6    p3_field::{FieldAlgebra, PrimeField32},
7};
8
9use crate::{
10    is_less_than::{IsLtSubAir, LessThanAuxCols},
11    utils::not,
12    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
13    SubAir, TraceSubRowGenerator,
14};
15
16#[cfg(test)]
17pub mod tests;
18
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct IsLtArrayIo<T, const NUM: usize> {
    /// Left operand of the comparison; index 0 is compared first (lexicographic order).
    pub x: [T; NUM],
    /// Right operand of the comparison; index 0 is compared first (lexicographic order).
    pub y: [T; NUM],
    /// The boolean output, constrained to equal `x < y` when `count != 0`. The less than
    /// comparison is done lexicographically: the first index where `x[i] != y[i]` decides the
    /// result, and `out` is forced to 0 when `x == y`.
    pub out: T,
    /// Constraints only hold when `count != 0`. When `count == 0`, setting all trace values
    /// to zero still passes the constraints.
    /// `count` is **assumed** to be boolean and must be constrained as such by the caller.
    pub count: T,
}
32
33#[repr(C)]
34#[derive(AlignedBorrow, Clone, Copy, Debug)]
35pub struct IsLtArrayAuxCols<T, const NUM: usize, const AUX_LEN: usize> {
36    // `diff_marker` is filled with 0 except at the lowest index i such that
37    // `x[i] != y[i]`. If such an `i` exists then it is constrained that `diff_inv = inv(y[i] - x[i])`.
38    pub diff_marker: [T; NUM],
39    pub diff_inv: T,
40    pub lt_aux: LessThanAuxCols<T, AUX_LEN>,
41}
42
/// Read-only, length-erased view of the [IsLtArrayAuxCols] columns.
///
/// This is the auxiliary half of the `SubAir` evaluation context. `lt_decomp` points at the
/// `lower_decomp` limbs nested inside `lt_aux`.
#[derive(Clone, Copy, Debug)]
pub struct IsLtArrayAuxColsRef<'v, T> {
    /// One entry per array element; selects the first differing index.
    pub diff_marker: &'v [T],
    /// Inverse witness for the difference at the selected index.
    pub diff_inv: &'v T,
    /// Limb decomposition consumed by the inner less-than comparison.
    pub lt_decomp: &'v [T],
}
49
/// Mutable, length-erased view of the [IsLtArrayAuxCols] columns, filled in by trace
/// generation.
///
/// `lt_decomp` points at the `lower_decomp` limbs nested inside `lt_aux`.
#[derive(Debug)]
pub struct IsLtArrayAuxColsMut<'v, T> {
    /// One entry per array element; selects the first differing index.
    pub diff_marker: &'v mut [T],
    /// Inverse witness for the difference at the selected index.
    pub diff_inv: &'v mut T,
    /// Limb decomposition consumed by the inner less-than comparison.
    pub lt_decomp: &'v mut [T],
}
56
57impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a IsLtArrayAuxCols<T, NUM, AUX_LEN>>
58    for IsLtArrayAuxColsRef<'a, T>
59{
60    fn from(value: &'a IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
61        Self {
62            diff_marker: &value.diff_marker,
63            diff_inv: &value.diff_inv,
64            lt_decomp: &value.lt_aux.lower_decomp,
65        }
66    }
67}
68
69impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>>
70    for IsLtArrayAuxColsMut<'a, T>
71{
72    fn from(value: &'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
73        Self {
74            diff_marker: &mut value.diff_marker,
75            diff_inv: &mut value.diff_inv,
76            lt_decomp: &mut value.lt_aux.lower_decomp,
77        }
78    }
79}
80
81/// This SubAir constrains the boolean equal to 1 iff `x < y` (lexicographic comparison) assuming
82/// that all elements of both arrays `x, y` each have at most `max_bits` bits.
83///
84/// The constraints will constrain a selector for the first index where `x[i] != y[i]` and then
85/// use [IsLtSubAir] on `x[i], y[i]`.
86///
87/// The expected max constraint degree of `eval` is
88///     deg(count) + max(1, deg(x), deg(y))
89#[derive(Copy, Clone, Debug)]
90pub struct IsLtArraySubAir<const NUM: usize> {
91    pub lt: IsLtSubAir,
92}
93
94impl<const NUM: usize> IsLtArraySubAir<NUM> {
95    pub fn new(bus: VariableRangeCheckerBus, max_bits: usize) -> Self {
96        Self {
97            lt: IsLtSubAir::new(bus, max_bits),
98        }
99    }
100
101    pub fn when_transition(self) -> IsLtArrayWhenTransitionAir<NUM> {
102        IsLtArrayWhenTransitionAir(self)
103    }
104
105    pub fn max_bits(&self) -> usize {
106        self.lt.max_bits
107    }
108
109    pub fn range_max_bits(&self) -> usize {
110        self.lt.range_max_bits()
111    }
112
113    /// Constrain that `out` is boolean equal to `x < y` (lexicographic comparison)
114    /// **without** doing range checks on `lt_decomp`.
115    fn eval_without_range_checks<AB: AirBuilder>(
116        &self,
117        builder: &mut AB,
118        io: IsLtArrayIo<AB::Expr, NUM>,
119        diff_marker: &[AB::Var],
120        diff_inv: AB::Var,
121        lt_decomp: &[AB::Var],
122    ) {
123        assert_eq!(diff_marker.len(), NUM);
124        let mut prefix_sum = AB::Expr::ZERO;
125        let mut diff_val = AB::Expr::ZERO;
126        for (x, y, &marker) in izip!(io.x, io.y, diff_marker) {
127            let diff = y - x;
128            diff_val += diff.clone() * marker.into();
129            prefix_sum += marker.into();
130            builder.assert_bool(marker);
131            builder
132                .when(io.count.clone())
133                .assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
134            builder.when(marker).assert_one(diff * diff_inv);
135        }
136        builder.assert_bool(prefix_sum.clone());
137        // When condition != 0,
138        // - If `x != y`, then `prefix_sum = 1` so marker[i] must be nonzero iff
139        //   i is the first index where `x[i] != y[i]`. Constrains that
140        //   `diff_inv * (y[i] - x[i]) = 1` (`diff_val` is non-zero).
141        // - If `x == y`, then `prefix_sum = 0` and `out == 0` (below)
142        //     - `prefix_sum` cannot be 1 because all diff are zero and it would be impossible to find an inverse.
143
144        builder
145            .when(io.count.clone())
146            .when(not::<AB::Expr>(prefix_sum))
147            .assert_zero(io.out.clone());
148
149        self.lt
150            .eval_without_range_checks(builder, diff_val, io.out, io.count, lt_decomp);
151    }
152}
153
154impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArraySubAir<NUM> {
155    type AirContext<'a>
156        = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
157    where
158        AB::Expr: 'a,
159        AB::Var: 'a,
160        AB: 'a;
161
162    fn eval<'a>(
163        &'a self,
164        builder: &'a mut AB,
165        (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
166    ) where
167        AB::Var: 'a,
168        AB::Expr: 'a,
169    {
170        self.lt
171            .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
172        self.eval_without_range_checks(builder, io, aux.diff_marker, *aux.diff_inv, aux.lt_decomp);
173    }
174}
175
176/// The same subair as [IsLtArraySubAir] except that non-range check
177/// constraints are not imposed on the last row.
178/// Intended use case is for asserting less than between entries in adjacent rows.
179#[derive(Copy, Clone, Debug)]
180pub struct IsLtArrayWhenTransitionAir<const NUM: usize>(pub IsLtArraySubAir<NUM>);
181
182impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArrayWhenTransitionAir<NUM> {
183    type AirContext<'a>
184        = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
185    where
186        AB::Expr: 'a,
187        AB::Var: 'a,
188        AB: 'a;
189
190    fn eval<'a>(
191        &'a self,
192        builder: &'a mut AB,
193        (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
194    ) where
195        AB::Var: 'a,
196        AB::Expr: 'a,
197    {
198        self.0
199            .lt
200            .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
201        self.0.eval_without_range_checks(
202            &mut builder.when_transition(),
203            io,
204            aux.diff_marker,
205            *aux.diff_inv,
206            aux.lt_decomp,
207        );
208    }
209}
210
211impl<F: PrimeField32, const NUM: usize> TraceSubRowGenerator<F> for IsLtArraySubAir<NUM> {
212    /// `(range_checker, x, y)`
213    type TraceContext<'a> = (&'a VariableRangeCheckerChip, &'a [F], &'a [F]);
214    /// `(aux, out)`
215    type ColsMut<'a> = (IsLtArrayAuxColsMut<'a, F>, &'a mut F);
216
217    /// Only use this when `count != 0`.
218    #[inline(always)]
219    fn generate_subrow<'a>(
220        &'a self,
221        (range_checker, x, y): (&'a VariableRangeCheckerChip, &'a [F], &'a [F]),
222        (aux, out): (IsLtArrayAuxColsMut<'a, F>, &'a mut F),
223    ) {
224        tracing::trace!("IsLtArraySubAir::generate_subrow x={:?}, y={:?}", x, y);
225        let mut is_eq = true;
226        let mut diff_val = F::ZERO;
227        *aux.diff_inv = F::ZERO;
228        for (x_i, y_i, diff_marker) in izip!(x, y, aux.diff_marker.iter_mut()) {
229            if x_i != y_i && is_eq {
230                is_eq = false;
231                *diff_marker = F::ONE;
232                diff_val = *y_i - *x_i;
233                *aux.diff_inv = diff_val.inverse();
234            } else {
235                *diff_marker = F::ZERO;
236            }
237        }
238        // diff_val can be "negative" but shifted_diff is in [0, 2^{max_bits+1})
239        let shifted_diff =
240            (diff_val + F::from_canonical_u32((1 << self.max_bits()) - 1)).as_canonical_u32();
241        let lower_u32 = shifted_diff & ((1 << self.max_bits()) - 1);
242        *out = F::from_bool(shifted_diff != lower_u32);
243
244        // decompose lower_u32 into limbs and range check
245        range_checker.decompose(lower_u32, self.max_bits(), aux.lt_decomp);
246    }
247}