openvm_circuit_primitives/assert_less_than/
mod.rs

1use derive_new::new;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{
4    interaction::InteractionBuilder,
5    p3_air::AirBuilder,
6    p3_field::{Field, FieldAlgebra},
7};
8
9use crate::{
10    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
11    SubAir, TraceSubRowGenerator,
12};
13
14#[cfg(test)]
15pub mod tests;
16
/// Inputs to the less-than SubAir: the two values being compared plus an
/// enable/multiplicity flag. Usually instantiated with `T = AB::Expr` and
/// supplied by the enclosing AIR as external context.
// AlignedBorrow is intentionally not derived: these values are normally
// expressions built by the caller, not direct columns of an AIR.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct AssertLessThanIo<T> {
    /// Left-hand side of the comparison `x < y`.
    pub x: T,
    /// Right-hand side of the comparison `x < y`.
    pub y: T,
    /// Enable flag and range-check multiplicity.
    ///
    /// Constraints are only applied when `count != 0`, and the range checks
    /// use `count` as their multiplicity, so `count == 0` performs none.
    /// `count` is **assumed** to be boolean; the caller is responsible for
    /// constraining it as such.
    ///
    /// N.B.: range checks could instead always be performed if the aux subrow
    /// were zeroed whenever `count == 0`. That would slightly simplify the
    /// range-check interactions, usually without changing the overall
    /// constraint degree. The downside is that the RangeChecker's
    /// multiplicities would then have to be updated even on dummy padding rows.
    /// For quality of life, the more complex gated constraint is used instead.
    pub count: T,
}

impl<T> AssertLessThanIo<T> {
    /// Builds the IO, converting each argument into `T`.
    pub fn new(x: impl Into<T>, y: impl Into<T>, count: impl Into<T>) -> Self {
        // Convert in declaration order (x, y, count), then assemble.
        let (x, y, count) = (x.into(), y.into(), count.into());
        Self { x, y, count }
    }
}
49
/// Auxiliary columns owned by the less-than SubAir. Typically used with `T = AB::Var`.
///
/// `AUX_LEN` is the number of auxiliary columns. It must equal
/// `max_bits.div_ceil(bus.range_max_bits)` (the SubAir's `decomp_limbs`),
/// because the SubAir asserts that the slice it receives has exactly that length.
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug, new)]
pub struct LessThanAuxCols<T, const AUX_LEN: usize> {
    /// `lower = y - x - 1` decomposed into limbs of `bus.range_max_bits` bits,
    /// least-significant limb first (limb `i` has weight `2^(i * range_max_bits)`).
    /// Note: the final limb may have fewer than `bus.range_max_bits` bits.
    pub lower_decomp: [T; AUX_LEN],
}
60
/// This is intended for use as a **SubAir**, not as a standalone Air.
///
/// This SubAir constrains that `x < y` when `count != 0`, assuming both
/// numbers have at most `max_bits` bits. It decomposes `y - x - 1` into
/// limbs of size `bus.range_max_bits` and interacts with a
/// `VariableRangeCheckerBus` to range check the limbs.
///
/// The SubAir owns the auxiliary columns that store the decomposed limbs.
/// The number of limbs is `max_bits.div_ceil(bus.range_max_bits)`.
///
/// The expected max constraint degree of `eval` is
///     deg(count) + max(1, deg(x), deg(y))
#[derive(Copy, Clone, Debug)]
pub struct AssertLtSubAir {
    /// The bus used to send range-check requests to the range checker chip.
    pub bus: VariableRangeCheckerBus,
    /// The maximum number of bits of the numbers being compared.
    ///
    /// Soundness requirement: `max_bits <= 29`.
    ///     `max_bits > 29` doesn't work. The approach is to check that `y - x - 1` is
    ///     non-negative. Over a field with a prime modulus, this is equivalent to checking
    ///     that `y - x - 1` lies in the range `[0, 2^max_bits - 1]`. However, for
    ///     `max_bits > 29`, if `y` is small enough and `x` is large enough, `y - x - 1` is
    ///     negative but its field representative can still land in that range, because
    ///     the field is not big enough.
    pub max_bits: usize,
    /// `decomp_limbs = max_bits.div_ceil(bus.range_max_bits)` is the
    /// number of limbs that `y - x - 1` will be decomposed into.
    pub decomp_limbs: usize,
}
90
91impl AssertLtSubAir {
92    pub fn new(bus: VariableRangeCheckerBus, max_bits: usize) -> Self {
93        assert!(max_bits <= 29); // see soundness requirement above
94        let decomp_limbs = max_bits.div_ceil(bus.range_max_bits);
95        Self {
96            bus,
97            max_bits,
98            decomp_limbs,
99        }
100    }
101
102    pub fn range_max_bits(&self) -> usize {
103        self.bus.range_max_bits
104    }
105
106    /// FOR INTERNAL USE ONLY.
107    /// This AIR is only sound if interactions are enabled
108    ///
109    /// Constraints between `io` and `aux` are only enforced when `count != 0`.
110    /// This means `aux` can be all zero independent on what `io` is by setting `count = 0`.
111    #[inline(always)]
112    fn eval_without_range_checks<AB: AirBuilder>(
113        &self,
114        builder: &mut AB,
115        io: AssertLessThanIo<AB::Expr>,
116        lower_decomp: &[AB::Var],
117    ) {
118        assert_eq!(lower_decomp.len(), self.decomp_limbs);
119        // this is the desired intermediate value (i.e. y - x - 1)
120        // deg(intermed_val) = deg(io)
121        let intermed_val = io.y - io.x - AB::Expr::ONE;
122
123        // Construct lower from lower_decomp:
124        // - each limb of lower_decomp will be range checked
125        // deg(lower) = 1
126        let lower = lower_decomp
127            .iter()
128            .enumerate()
129            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
130                acc + val * AB::Expr::from_canonical_usize(1 << (i * self.range_max_bits()))
131            });
132
133        // constrain that y - x - 1 is equal to the constructed lower value.
134        // this enforces that the intermediate value is in the range [0, 2^max_bits - 1], which is
135        // equivalent to x < y
136        builder.when(io.count).assert_eq(intermed_val, lower);
137        // the degree of this constraint is expected to be deg(count) + max(deg(intermed_val),
138        // deg(lower)) since we are constraining count * intermed_val == count * lower
139    }
140
141    #[inline(always)]
142    fn eval_range_checks<AB: InteractionBuilder>(
143        &self,
144        builder: &mut AB,
145        lower_decomp: &[AB::Var],
146        count: impl Into<AB::Expr>,
147    ) {
148        let count = count.into();
149        let mut bits_remaining = self.max_bits;
150        // we range check the limbs of the lower_decomp so that we know each element
151        // of lower_decomp has the correct number of bits
152        for limb in lower_decomp {
153            // the last limb might have fewer than `bus.range_max_bits` bits
154            let range_bits = bits_remaining.min(self.range_max_bits());
155            self.bus
156                .range_check(*limb, range_bits)
157                .eval(builder, count.clone());
158            bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
159        }
160    }
161}
162
163impl<AB: InteractionBuilder> SubAir<AB> for AssertLtSubAir {
164    type AirContext<'a>
165        = (AssertLessThanIo<AB::Expr>, &'a [AB::Var])
166    where
167        AB::Expr: 'a,
168        AB::Var: 'a,
169        AB: 'a;
170
171    // constrain that x < y
172    // warning: send for range check must be included for the constraints to be sound
173    fn eval<'a>(
174        &'a self,
175        builder: &'a mut AB,
176        (io, lower_decomp): (AssertLessThanIo<AB::Expr>, &'a [AB::Var]),
177    ) where
178        AB::Var: 'a,
179        AB::Expr: 'a,
180    {
181        // Note: every AIR that uses this sub-AIR must include the range checks for soundness
182        self.eval_range_checks(builder, lower_decomp, io.count.clone());
183        self.eval_without_range_checks(builder, io, lower_decomp);
184    }
185}
186
187impl<F: Field> TraceSubRowGenerator<F> for AssertLtSubAir {
188    /// (range_checker, x, y)
189    // x, y are u32 because memory records are storing u32 and there would be needless conversions.
190    // It also prevents a F: PrimeField32 trait bound.
191    type TraceContext<'a> = (&'a VariableRangeCheckerChip, u32, u32);
192    /// lower_decomp
193    type ColsMut<'a> = &'a mut [F];
194
195    /// Should only be used when `io.count != 0` i.e. only on non-padding rows.
196    #[inline(always)]
197    fn generate_subrow<'a>(
198        &'a self,
199        (range_checker, x, y): (&'a VariableRangeCheckerChip, u32, u32),
200        lower_decomp: &'a mut [F],
201    ) {
202        debug_assert!(x < y, "assert {x} < {y} failed");
203        debug_assert_eq!(lower_decomp.len(), self.decomp_limbs);
204        debug_assert!(
205            x < (1 << self.max_bits),
206            "{x} has more than {} bits",
207            self.max_bits
208        );
209        debug_assert!(
210            y < (1 << self.max_bits),
211            "{y} has more than {} bits",
212            self.max_bits
213        );
214
215        // Note: if x < y then y - x - 1 should already have <= max_bits bits
216        range_checker.decompose(y - x - 1, self.max_bits, lower_decomp);
217    }
218}