openvm_circuit_primitives/assert_less_than/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
use derive_new::new;
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::AirBuilder,
    p3_field::{AbstractField, Field},
};

use crate::{
    var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
    SubAir, TraceSubRowGenerator,
};

#[cfg(test)]
pub mod tests;

/// Input/output of the less-than SubAir: the two values being compared and the
/// multiplicity of the comparison.
///
/// The IO is typically provided with `T = AB::Expr` as external context.
// This does not derive AlignedBorrow because it is usually **not** going to be
// direct columns in an AIR.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct AssertLessThanIo<T> {
    /// Left-hand side of the comparison `x < y`.
    pub x: T,
    /// Right-hand side of the comparison `x < y`.
    pub y: T,
    /// Constraints are only applied when `count != 0`.
    /// Range checks are done with multiplicity `count`, so if `count == 0`
    /// no range checks are done.
    /// In practice `count` is always boolean, although this is not enforced
    /// by the subair.
    ///
    /// N.B.: in fact range checks could always be done, if the aux
    /// subrow values are set to 0 when `count == 0`. This would slightly
    /// simplify the range check interactions, although it usually doesn't change
    /// the overall constraint degree. However, it leads to the annoyance that
    /// you must update the RangeChecker's multiplicities even on dummy padding
    /// rows. To improve quality of life, we currently use this more complex
    /// constraint.
    pub count: T,
}

impl<T> AssertLessThanIo<T> {
    /// Builds the IO from anything convertible into `T`.
    pub fn new(x: impl Into<T>, y: impl Into<T>, count: impl Into<T>) -> Self {
        Self {
            x: x.into(),
            y: y.into(),
            count: count.into(),
        }
    }
}

/// Auxiliary columns owned by the less-than SubAir; typically instantiated
/// with `T = AB::Var`.
///
/// `AUX_LEN` is the number of auxiliary columns, and is expected to equal
/// `max_bits.div_ceil(bus.range_max_bits)`.
#[repr(C)]
#[derive(AlignedBorrow, Clone, Copy, Debug, new)]
pub struct LessThanAuxCols<T, const AUX_LEN: usize> {
    /// Limbs of `y - x - 1`, least significant first, each `bus.range_max_bits`
    /// bits wide. The final limb may be narrower than `bus.range_max_bits`.
    pub lower_decomp: [T; AUX_LEN],
}

/// This is intended for use as a **SubAir**, not as a standalone Air.
///
/// This SubAir constrains that `x < y` when `count != 0`, assuming
/// both numbers have at most `max_bits` bits.
/// The SubAir decomposes `y - x - 1` into limbs of
/// size `bus.range_max_bits`, and interacts with a
/// `VariableRangeCheckerBus` to range check the decompositions.
///
/// The SubAir will own auxiliary columns to store the decomposed limbs.
/// The number of limbs is `max_bits.div_ceil(bus.range_max_bits)`.
///
/// The expected max constraint degree of `eval` is
///     deg(count) + max(1, deg(x), deg(y))
#[derive(Copy, Clone, Debug)]
pub struct AssertLtSubAir {
    /// The bus for sends to the range checker chip.
    pub bus: VariableRangeCheckerBus,
    /// The maximum number of bits for the numbers to compare.
    ///
    /// Soundness requirement: `max_bits <= 29`.
    ///
    /// The approach is to check that `y - x - 1` is non-negative by showing it lies in
    /// `[0, 2^max_bits - 1]`. Over a prime field of modulus `p`, a "negative" value must
    /// not wrap around into that range: when `x >= y` (both below `2^max_bits`),
    /// `y - x - 1` equals `p - k` for some `k` in `[1, 2^max_bits]`. Soundness therefore
    /// needs `p - 2^max_bits >= 2^max_bits`, i.e. `2^(max_bits + 1) <= p`. For a 31-bit
    /// prime (`2^30 < p < 2^31`, e.g. BabyBear) this allows `max_bits <= 29` but not 30.
    pub max_bits: usize,
    /// `decomp_limbs = max_bits.div_ceil(bus.range_max_bits)` is the
    /// number of limbs that `y - x - 1` will be decomposed into.
    pub decomp_limbs: usize,
}

impl AssertLtSubAir {
    /// Creates the SubAir for comparing numbers with at most `max_bits` bits.
    ///
    /// `y - x - 1` will be decomposed into `max_bits.div_ceil(bus.range_max_bits)` limbs.
    ///
    /// # Panics
    ///
    /// Panics if `max_bits > 29`, because the comparison is unsound in that case
    /// (see the soundness requirement on the `max_bits` field), and if
    /// `bus.range_max_bits == 0`.
    pub fn new(bus: VariableRangeCheckerBus, max_bits: usize) -> Self {
        // Enforce the documented soundness requirement at construction time rather than
        // silently building a SubAir that could accept some `x >= y`.
        assert!(
            max_bits <= 29,
            "AssertLtSubAir: max_bits = {max_bits} exceeds the soundness bound of 29"
        );
        let decomp_limbs = max_bits.div_ceil(bus.range_max_bits);
        Self {
            bus,
            max_bits,
            decomp_limbs,
        }
    }

    /// Wraps this SubAir so that its non-range-check constraints are only applied on
    /// transition rows (every row except the last).
    pub fn when_transition(self) -> AssertLtWhenTransitionAir {
        AssertLtWhenTransitionAir(self)
    }

    /// Maximum bit width of each limb of the decomposition (the final limb may use fewer).
    pub fn range_max_bits(&self) -> usize {
        self.bus.range_max_bits
    }

    /// FOR INTERNAL USE ONLY.
    /// This AIR is only sound if interactions are enabled, i.e. if the range checks
    /// from [Self::eval_range_checks] are also applied to the same `lower_decomp`.
    ///
    /// Constrains `y - x - 1 == sum_i lower_decomp[i] * 2^(i * range_max_bits)`.
    ///
    /// Constraints between `io` and `aux` are only enforced when `count != 0`.
    /// This means `aux` can be all zero independent of what `io` is by setting `count = 0`.
    ///
    /// # Panics
    ///
    /// Panics if `lower_decomp.len() != self.decomp_limbs`.
    #[inline(always)]
    fn eval_without_range_checks<AB: AirBuilder>(
        &self,
        builder: &mut AB,
        io: AssertLessThanIo<AB::Expr>,
        lower_decomp: &[AB::Var],
    ) {
        assert_eq!(lower_decomp.len(), self.decomp_limbs);
        // this is the desired intermediate value (i.e. y - x - 1)
        // deg(intermed_val) = deg(io)
        let intermed_val = io.y - io.x - AB::Expr::ONE;

        // Construct lower from lower_decomp (least significant limb first):
        // - each limb of lower_decomp will be range checked
        // deg(lower) = 1
        let limb_bits = self.range_max_bits();
        let lower = lower_decomp
            .iter()
            .enumerate()
            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
                acc + val * AB::Expr::from_canonical_usize(1 << (i * limb_bits))
            });

        // constrain that y - x - 1 is equal to the constructed lower value.
        // this enforces that the intermediate value is in the range [0, 2^max_bits - 1], which is equivalent to x < y
        builder.when(io.count).assert_eq(intermed_val, lower);
        // the degree of this constraint is expected to be deg(count) + max(deg(intermed_val), deg(lower))
        // since we are constraining count * intermed_val == count * lower
    }

    /// Range checks each limb of `lower_decomp` with multiplicity `count`.
    ///
    /// Each limb is checked to `range_max_bits` bits, capped by the bits still remaining
    /// out of `max_bits`. Together the limbs therefore span at most `max_bits` bits.
    #[inline(always)]
    fn eval_range_checks<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        lower_decomp: &[AB::Var],
        count: impl Into<AB::Expr>,
    ) {
        let count = count.into();
        let mut bits_remaining = self.max_bits;
        // we range check the limbs of the lower_decomp so that we know each element
        // of lower_decomp has the correct number of bits
        for limb in lower_decomp {
            // the last limb might have fewer than `bus.range_max_bits` bits
            let range_bits = bits_remaining.min(self.range_max_bits());
            self.bus
                .range_check(*limb, range_bits)
                .eval(builder, count.clone());
            bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
        }
    }
}

impl<AB: InteractionBuilder> SubAir<AB> for AssertLtSubAir {
    type AirContext<'a>
        = (AssertLessThanIo<AB::Expr>, &'a [AB::Var])
    where
        AB::Expr: 'a,
        AB::Var: 'a,
        AB: 'a;

    /// Constrains `x < y` (whenever `count != 0`) using the limbs in `lower_decomp`.
    ///
    /// This emits both the range-check interactions and the decomposition constraint.
    /// The decomposition constraint alone is not sound without the range checks.
    fn eval<'a>(
        &'a self,
        builder: &'a mut AB,
        (io, lower_decomp): (AssertLessThanIo<AB::Expr>, &'a [AB::Var]),
    ) where
        AB::Var: 'a,
        AB::Expr: 'a,
    {
        // The range checks come first. Each limb must be proven to fit in its bit budget,
        // otherwise the equality below could be satisfied by out-of-range limbs.
        let multiplicity = io.count.clone();
        self.eval_range_checks(builder, lower_decomp, multiplicity);
        self.eval_without_range_checks(builder, io, lower_decomp);
    }
}

/// Wrapper around [AssertLtSubAir] that applies the non-range-check constraints
/// only on transition rows, i.e. every row except the last.
///
/// This is meant for asserting less-than between entries that live in
/// adjacent rows.
#[derive(Debug, Clone, Copy)]
pub struct AssertLtWhenTransitionAir(pub AssertLtSubAir);

impl<AB: InteractionBuilder> SubAir<AB> for AssertLtWhenTransitionAir {
    type AirContext<'a>
        = (AssertLessThanIo<AB::Expr>, &'a [AB::Var])
    where
        AB::Expr: 'a,
        AB::Var: 'a,
        AB: 'a;

    /// Imposes the non-interaction constraints on all rows except the last. This is
    /// intended for use when the comparators `x, y` are on adjacent rows.
    ///
    /// The range-check interactions, however, are enabled _on every row_.
    /// `eval_range_checks` range checks `lower_decomp` with multiplicity `count` on every
    /// row, even though the decomposition is not constrained on the last row.
    /// Trace generation must therefore fill the last row's `lower_decomp` with values
    /// that are in range (e.g., zeros), or set `count` to 0 there.
    fn eval<'a>(
        &'a self,
        builder: &'a mut AB,
        (io, lower_decomp): (AssertLessThanIo<AB::Expr>, &'a [AB::Var]),
    ) where
        AB::Var: 'a,
        AB::Expr: 'a,
    {
        // Range checks apply on every row.
        self.0
            .eval_range_checks(builder, lower_decomp, io.count.clone());
        // The decomposition equality only applies on transition rows.
        self.0
            .eval_without_range_checks(&mut builder.when_transition(), io, lower_decomp);
    }
}

impl<F: Field> TraceSubRowGenerator<F> for AssertLtSubAir {
    /// `(range_checker, x, y)`.
    // `x` and `y` are `u32` because memory records store `u32`; this avoids needless
    // conversions and also avoids requiring an `F: PrimeField32` bound.
    type TraceContext<'a> = (&'a VariableRangeCheckerChip, u32, u32);
    /// The `lower_decomp` auxiliary columns.
    type ColsMut<'a> = &'a mut [F];

    /// Writes the limb decomposition of `y - x - 1` into `lower_decomp`.
    ///
    /// Should only be used when `io.count != 0`. Requires `x < y` and both values below
    /// `2^max_bits`. These preconditions are only checked in debug builds.
    #[inline(always)]
    fn generate_subrow<'a>(
        &'a self,
        (checker, x, y): (&'a VariableRangeCheckerChip, u32, u32),
        lower_decomp: &'a mut [F],
    ) {
        let max_bits = self.max_bits;

        debug_assert!(x < y, "assert {x} < {y} failed");
        debug_assert_eq!(lower_decomp.len(), self.decomp_limbs);
        debug_assert!(x < (1 << max_bits), "{x} has more than {} bits", max_bits);
        debug_assert!(y < (1 << max_bits), "{y} has more than {} bits", max_bits);

        // With x < y < 2^max_bits, the gap y - x - 1 already fits in max_bits bits.
        let gap = y - x - 1;
        checker.decompose(gap, max_bits, lower_decomp);
    }
}