openvm_circuit_primitives/is_less_than_array/
mod.rs
1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{
4 interaction::InteractionBuilder,
5 p3_air::AirBuilder,
6 p3_field::{FieldAlgebra, PrimeField32},
7};
8
9use crate::{
10 is_less_than::{IsLtSubAir, LessThanAuxCols},
11 utils::not,
12 var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
13 SubAir, TraceSubRowGenerator,
14};
15
16#[cfg(test)]
17pub mod tests;
18
/// Input/output values for comparing two arrays of `NUM` elements.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct IsLtArrayIo<T, const NUM: usize> {
    /// First array of the comparison.
    pub x: [T; NUM],
    /// Second array of the comparison; element-wise differences are taken as `y[i] - x[i]`.
    pub y: [T; NUM],
    /// Result of the comparison. It is forced to zero when `count` is nonzero and no
    /// difference marker is set, and is otherwise constrained by the inner less-than sub-AIR.
    pub out: T,
    /// Gating value: the constraints written with `builder.when(count)` are only enforced
    /// when this is nonzero. It is also forwarded to the inner less-than sub-AIR.
    pub count: T,
}
32
33#[repr(C)]
34#[derive(AlignedBorrow, Clone, Copy, Debug)]
35pub struct IsLtArrayAuxCols<T, const NUM: usize, const AUX_LEN: usize> {
36 pub diff_marker: [T; NUM],
39 pub diff_inv: T,
40 pub lt_aux: LessThanAuxCols<T, AUX_LEN>,
41}
42
/// Borrowed, slice-based view of [`IsLtArrayAuxCols`] that does not carry the const
/// generic array lengths.
#[derive(Debug, Clone, Copy)]
pub struct IsLtArrayAuxColsRef<'a, T> {
    /// View of `IsLtArrayAuxCols::diff_marker`.
    pub diff_marker: &'a [T],
    /// View of `IsLtArrayAuxCols::diff_inv`.
    pub diff_inv: &'a T,
    /// View of `IsLtArrayAuxCols::lt_aux.lower_decomp`.
    pub lt_decomp: &'a [T],
}
49
/// Mutable, slice-based view of [`IsLtArrayAuxCols`], used as the write target during
/// trace generation.
#[derive(Debug)]
pub struct IsLtArrayAuxColsMut<'a, T> {
    /// Mutable view of `IsLtArrayAuxCols::diff_marker`.
    pub diff_marker: &'a mut [T],
    /// Mutable view of `IsLtArrayAuxCols::diff_inv`.
    pub diff_inv: &'a mut T,
    /// Mutable view of `IsLtArrayAuxCols::lt_aux.lower_decomp`.
    pub lt_decomp: &'a mut [T],
}
56
57impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a IsLtArrayAuxCols<T, NUM, AUX_LEN>>
58 for IsLtArrayAuxColsRef<'a, T>
59{
60 fn from(value: &'a IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
61 Self {
62 diff_marker: &value.diff_marker,
63 diff_inv: &value.diff_inv,
64 lt_decomp: &value.lt_aux.lower_decomp,
65 }
66 }
67}
68
69impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>>
70 for IsLtArrayAuxColsMut<'a, T>
71{
72 fn from(value: &'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
73 Self {
74 diff_marker: &mut value.diff_marker,
75 diff_inv: &mut value.diff_inv,
76 lt_decomp: &mut value.lt_aux.lower_decomp,
77 }
78 }
79}
80
81#[derive(Copy, Clone, Debug)]
90pub struct IsLtArraySubAir<const NUM: usize> {
91 pub lt: IsLtSubAir,
92}
93
94impl<const NUM: usize> IsLtArraySubAir<NUM> {
95 pub fn new(bus: VariableRangeCheckerBus, max_bits: usize) -> Self {
96 Self {
97 lt: IsLtSubAir::new(bus, max_bits),
98 }
99 }
100
101 pub fn when_transition(self) -> IsLtArrayWhenTransitionAir<NUM> {
102 IsLtArrayWhenTransitionAir(self)
103 }
104
105 pub fn max_bits(&self) -> usize {
106 self.lt.max_bits
107 }
108
109 pub fn range_max_bits(&self) -> usize {
110 self.lt.range_max_bits()
111 }
112
113 fn eval_without_range_checks<AB: AirBuilder>(
116 &self,
117 builder: &mut AB,
118 io: IsLtArrayIo<AB::Expr, NUM>,
119 diff_marker: &[AB::Var],
120 diff_inv: AB::Var,
121 lt_decomp: &[AB::Var],
122 ) {
123 assert_eq!(diff_marker.len(), NUM);
124 let mut prefix_sum = AB::Expr::ZERO;
125 let mut diff_val = AB::Expr::ZERO;
126 for (x, y, &marker) in izip!(io.x, io.y, diff_marker) {
127 let diff = y - x;
128 diff_val += diff.clone() * marker.into();
129 prefix_sum += marker.into();
130 builder.assert_bool(marker);
131 builder
132 .when(io.count.clone())
133 .assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
134 builder.when(marker).assert_one(diff * diff_inv);
135 }
136 builder.assert_bool(prefix_sum.clone());
137 builder
145 .when(io.count.clone())
146 .when(not::<AB::Expr>(prefix_sum))
147 .assert_zero(io.out.clone());
148
149 self.lt
150 .eval_without_range_checks(builder, diff_val, io.out, io.count, lt_decomp);
151 }
152}
153
154impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArraySubAir<NUM> {
155 type AirContext<'a>
156 = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
157 where
158 AB::Expr: 'a,
159 AB::Var: 'a,
160 AB: 'a;
161
162 fn eval<'a>(
163 &'a self,
164 builder: &'a mut AB,
165 (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
166 ) where
167 AB::Var: 'a,
168 AB::Expr: 'a,
169 {
170 self.lt
171 .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
172 self.eval_without_range_checks(builder, io, aux.diff_marker, *aux.diff_inv, aux.lt_decomp);
173 }
174}
175
176#[derive(Copy, Clone, Debug)]
180pub struct IsLtArrayWhenTransitionAir<const NUM: usize>(pub IsLtArraySubAir<NUM>);
181
182impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArrayWhenTransitionAir<NUM> {
183 type AirContext<'a>
184 = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
185 where
186 AB::Expr: 'a,
187 AB::Var: 'a,
188 AB: 'a;
189
190 fn eval<'a>(
191 &'a self,
192 builder: &'a mut AB,
193 (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
194 ) where
195 AB::Var: 'a,
196 AB::Expr: 'a,
197 {
198 self.0
199 .lt
200 .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
201 self.0.eval_without_range_checks(
202 &mut builder.when_transition(),
203 io,
204 aux.diff_marker,
205 *aux.diff_inv,
206 aux.lt_decomp,
207 );
208 }
209}
210
211impl<F: PrimeField32, const NUM: usize> TraceSubRowGenerator<F> for IsLtArraySubAir<NUM> {
212 type TraceContext<'a> = (&'a VariableRangeCheckerChip, &'a [F], &'a [F]);
214 type ColsMut<'a> = (IsLtArrayAuxColsMut<'a, F>, &'a mut F);
216
217 #[inline(always)]
219 fn generate_subrow<'a>(
220 &'a self,
221 (range_checker, x, y): (&'a VariableRangeCheckerChip, &'a [F], &'a [F]),
222 (aux, out): (IsLtArrayAuxColsMut<'a, F>, &'a mut F),
223 ) {
224 tracing::trace!("IsLtArraySubAir::generate_subrow x={:?}, y={:?}", x, y);
225 let mut is_eq = true;
226 let mut diff_val = F::ZERO;
227 *aux.diff_inv = F::ZERO;
228 for (x_i, y_i, diff_marker) in izip!(x, y, aux.diff_marker.iter_mut()) {
229 if x_i != y_i && is_eq {
230 is_eq = false;
231 *diff_marker = F::ONE;
232 diff_val = *y_i - *x_i;
233 *aux.diff_inv = diff_val.inverse();
234 } else {
235 *diff_marker = F::ZERO;
236 }
237 }
238 let shifted_diff =
240 (diff_val + F::from_canonical_u32((1 << self.max_bits()) - 1)).as_canonical_u32();
241 let lower_u32 = shifted_diff & ((1 << self.max_bits()) - 1);
242 *out = F::from_bool(shifted_diff != lower_u32);
243
244 range_checker.decompose(lower_u32, self.max_bits(), aux.lt_decomp);
246 }
247}