openvm_circuit_primitives/assert_less_than/
mod.rs

1use derive_new::new;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{
4    interaction::InteractionBuilder,
5    p3_air::AirBuilder,
6    p3_field::{Field, FieldAlgebra},
7};
8
9use crate::{
10    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
11    SubAir, TraceSubRowGenerator,
12};
13
14#[cfg(test)]
15pub mod tests;
16
/// Inputs of [`AssertLtSubAir`]: the two operands and the enabling multiplicity.
///
/// The IO is typically provided with `T = AB::Expr` as external context.
// `AlignedBorrow` is deliberately not derived: the fields are usually caller-supplied
// expressions rather than direct columns of an AIR.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct AssertLessThanIo<T> {
    /// Left-hand operand; the SubAir asserts `x < y`.
    pub x: T,
    /// Right-hand operand.
    pub y: T,
    /// Enabling flag and range-check multiplicity.
    ///
    /// Constraints are only applied when `count != 0`, and range checks are sent
    /// with multiplicity `count` (so none are sent when `count == 0`).
    /// `count` is **assumed** to be boolean; the caller must constrain it as such.
    ///
    /// N.B.: range checks could instead be sent unconditionally if the aux subrow
    /// were zeroed whenever `count == 0`. That would slightly simplify the range
    /// check interactions (usually without changing the overall constraint degree),
    /// but it would force callers to update the RangeChecker's multiplicities even
    /// on dummy padding rows. For quality of life we use the gated constraint instead.
    pub count: T,
}

impl<T> AssertLessThanIo<T> {
    /// Builds the IO, converting each argument into `T`.
    pub fn new(x: impl Into<T>, y: impl Into<T>, count: impl Into<T>) -> Self {
        let (x, y, count) = (x.into(), y.into(), count.into());
        Self { x, y, count }
    }
}
49
/// Auxiliary columns owned by the [`AssertLtSubAir`]. Typically used with `T = AB::Var`.
///
/// `AUX_LEN` is the number of aux columns and should equal
/// `max_bits.div_ceil(bus.range_max_bits)` (the SubAir's `decomp_limbs`).
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug, new)]
pub struct LessThanAuxCols<T, const AUX_LEN: usize> {
    // `lower_decomp` holds `y - x - 1` (on rows where `count != 0`) split into
    // limbs of `bus.range_max_bits` bits, least significant limb first.
    // Note: the final limb may have fewer than `bus.range_max_bits` bits.
    pub lower_decomp: [T; AUX_LEN],
}
60
61/// This is intended for use as a **SubAir**, not as a standalone Air.
62///
63/// This SubAir constrains that `x < y` when `count != 0`, assuming
64/// the two numbers both have a max number of bits, given by `max_bits`.
65/// The SubAir decomposes `y - x - 1` into limbs of
66/// size `bus.range_max_bits`, and interacts with a
67/// `VariableRangeCheckerBus` to range check the decompositions.
68///
69/// The SubAir will own auxiliary columns to store the decomposed limbs.
70/// The number of limbs is `max_bits.div_ceil(bus.range_max_bits)`.
71///
72/// The expected max constraint degree of `eval` is
73///     deg(count) + max(1, deg(x), deg(y))
74#[derive(Copy, Clone, Debug)]
75pub struct AssertLtSubAir {
76    /// The bus for sends to range chip
77    pub bus: VariableRangeCheckerBus,
78    /// The maximum number of bits for the numbers to compare
79    /// Soundness requirement: max_bits <= 29
80    ///     max_bits > 29 doesn't work: the approach is to check that y-x-1 is non-negative.
81    ///     For a field with prime modular, this is equivalent to checking that y-x-1 is in
82    ///     the range [0, 2^max_bits - 1]. However, for max_bits > 29, if y is small enough
83    ///     and x is large enough, then y-x-1 is negative but can still be in the range due
84    ///     to the field size not being big enough.
85    pub max_bits: usize,
86    /// `decomp_limbs = max_bits.div_ceil(bus.range_max_bits)` is the
87    /// number of limbs that `y - x - 1` will be decomposed into.
88    pub decomp_limbs: usize,
89}
90
91impl AssertLtSubAir {
92    pub fn new(bus: VariableRangeCheckerBus, max_bits: usize) -> Self {
93        assert!(max_bits <= 29); // see soundness requirement above
94        let decomp_limbs = max_bits.div_ceil(bus.range_max_bits);
95        Self {
96            bus,
97            max_bits,
98            decomp_limbs,
99        }
100    }
101
102    pub fn range_max_bits(&self) -> usize {
103        self.bus.range_max_bits
104    }
105
106    /// FOR INTERNAL USE ONLY.
107    /// This AIR is only sound if interactions are enabled
108    ///
109    /// Constraints between `io` and `aux` are only enforced when `count != 0`.
110    /// This means `aux` can be all zero independent on what `io` is by setting `count = 0`.
111    #[inline(always)]
112    fn eval_without_range_checks<AB: AirBuilder>(
113        &self,
114        builder: &mut AB,
115        io: AssertLessThanIo<AB::Expr>,
116        lower_decomp: &[AB::Var],
117    ) {
118        assert_eq!(lower_decomp.len(), self.decomp_limbs);
119        // this is the desired intermediate value (i.e. y - x - 1)
120        // deg(intermed_val) = deg(io)
121        let intermed_val = io.y - io.x - AB::Expr::ONE;
122
123        // Construct lower from lower_decomp:
124        // - each limb of lower_decomp will be range checked
125        // deg(lower) = 1
126        let lower = lower_decomp
127            .iter()
128            .enumerate()
129            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
130                acc + val * AB::Expr::from_canonical_usize(1 << (i * self.range_max_bits()))
131            });
132
133        // constrain that y - x - 1 is equal to the constructed lower value.
134        // this enforces that the intermediate value is in the range [0, 2^max_bits - 1], which is equivalent to x < y
135        builder.when(io.count).assert_eq(intermed_val, lower);
136        // the degree of this constraint is expected to be deg(count) + max(deg(intermed_val), deg(lower))
137        // since we are constraining count * intermed_val == count * lower
138    }
139
140    #[inline(always)]
141    fn eval_range_checks<AB: InteractionBuilder>(
142        &self,
143        builder: &mut AB,
144        lower_decomp: &[AB::Var],
145        count: impl Into<AB::Expr>,
146    ) {
147        let count = count.into();
148        let mut bits_remaining = self.max_bits;
149        // we range check the limbs of the lower_decomp so that we know each element
150        // of lower_decomp has the correct number of bits
151        for limb in lower_decomp {
152            // the last limb might have fewer than `bus.range_max_bits` bits
153            let range_bits = bits_remaining.min(self.range_max_bits());
154            self.bus
155                .range_check(*limb, range_bits)
156                .eval(builder, count.clone());
157            bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
158        }
159    }
160}
161
162impl<AB: InteractionBuilder> SubAir<AB> for AssertLtSubAir {
163    type AirContext<'a>
164        = (AssertLessThanIo<AB::Expr>, &'a [AB::Var])
165    where
166        AB::Expr: 'a,
167        AB::Var: 'a,
168        AB: 'a;
169
170    // constrain that x < y
171    // warning: send for range check must be included for the constraints to be sound
172    fn eval<'a>(
173        &'a self,
174        builder: &'a mut AB,
175        (io, lower_decomp): (AssertLessThanIo<AB::Expr>, &'a [AB::Var]),
176    ) where
177        AB::Var: 'a,
178        AB::Expr: 'a,
179    {
180        // Note: every AIR that uses this sub-AIR must include the range checks for soundness
181        self.eval_range_checks(builder, lower_decomp, io.count.clone());
182        self.eval_without_range_checks(builder, io, lower_decomp);
183    }
184}
185
186impl<F: Field> TraceSubRowGenerator<F> for AssertLtSubAir {
187    /// (range_checker, x, y)
188    // x, y are u32 because memory records are storing u32 and there would be needless conversions. It also prevents a F: PrimeField32 trait bound.
189    type TraceContext<'a> = (&'a VariableRangeCheckerChip, u32, u32);
190    /// lower_decomp
191    type ColsMut<'a> = &'a mut [F];
192
193    /// Should only be used when `io.count != 0` i.e. only on non-padding rows.
194    #[inline(always)]
195    fn generate_subrow<'a>(
196        &'a self,
197        (range_checker, x, y): (&'a VariableRangeCheckerChip, u32, u32),
198        lower_decomp: &'a mut [F],
199    ) {
200        debug_assert!(x < y, "assert {x} < {y} failed");
201        debug_assert_eq!(lower_decomp.len(), self.decomp_limbs);
202        debug_assert!(
203            x < (1 << self.max_bits),
204            "{x} has more than {} bits",
205            self.max_bits
206        );
207        debug_assert!(
208            y < (1 << self.max_bits),
209            "{y} has more than {} bits",
210            self.max_bits
211        );
212
213        // Note: if x < y then y - x - 1 should already have <= max_bits bits
214        range_checker.decompose(y - x - 1, self.max_bits, lower_decomp);
215    }
216}