openvm_circuit_primitives/is_less_than_array/
mod.rs1use itertools::izip;
2use openvm_circuit_primitives_derive::AlignedBorrow;
3use openvm_stark_backend::{
4 interaction::InteractionBuilder,
5 p3_air::AirBuilder,
6 p3_field::{FieldAlgebra, PrimeField32},
7};
8
9use crate::{
10 is_less_than::{IsLtSubAir, LessThanAuxCols},
11 utils::not,
12 var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
13 SubAir, TraceSubRowGenerator,
14};
15
16#[cfg(test)]
17pub mod tests;
18
/// Inputs and output of [`IsLtArraySubAir`]: the two arrays being compared, the comparison
/// result, and the selector that enables the comparison constraints.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct IsLtArrayIo<T, const NUM: usize> {
    /// Left-hand operand. Index 0 is compared first (most significant).
    pub x: [T; NUM],
    /// Right-hand operand. Index 0 is compared first (most significant).
    pub y: [T; NUM],
    /// Result of the lexicographic comparison `x < y`.
    ///
    /// When `count != 0` and `x == y`, it is constrained to `0`. Otherwise its value is
    /// constrained by [`IsLtSubAir`] on the difference at the first differing index.
    pub out: T,
    /// Selector gating the comparison constraints: they are only enforced when `count != 0`.
    ///
    /// NOTE(review): nothing in this module constrains `count` to be boolean; presumably the
    /// caller is responsible for that — confirm.
    pub count: T,
}
32
33#[repr(C)]
34#[derive(AlignedBorrow, Clone, Copy, Debug)]
35pub struct IsLtArrayAuxCols<T, const NUM: usize, const AUX_LEN: usize> {
36 pub diff_marker: [T; NUM],
40 pub diff_inv: T,
41 pub lt_aux: LessThanAuxCols<T, AUX_LEN>,
42}
43
/// Borrowed view of [`IsLtArrayAuxCols`] with the const-generic lengths erased into slices.
///
/// This is the auxiliary half of the `AirContext` consumed by the `SubAir` impls in this module.
#[derive(Clone, Copy, Debug)]
pub struct IsLtArrayAuxColsRef<'a, T> {
    /// The `diff_marker` columns. `eval` asserts this has exactly `NUM` entries.
    pub diff_marker: &'a [T],
    /// The `diff_inv` column.
    pub diff_inv: &'a T,
    /// The `lower_decomp` limbs of the scalar less-than auxiliary columns.
    pub lt_decomp: &'a [T],
}
50
/// Mutable view of [`IsLtArrayAuxCols`] with the const-generic lengths erased into slices.
///
/// Trace generation writes the auxiliary columns through this view.
#[derive(Debug)]
pub struct IsLtArrayAuxColsMut<'a, T> {
    /// The `diff_marker` columns.
    pub diff_marker: &'a mut [T],
    /// The `diff_inv` column.
    pub diff_inv: &'a mut T,
    /// The `lower_decomp` limbs of the scalar less-than auxiliary columns.
    pub lt_decomp: &'a mut [T],
}
57
58impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a IsLtArrayAuxCols<T, NUM, AUX_LEN>>
59 for IsLtArrayAuxColsRef<'a, T>
60{
61 fn from(value: &'a IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
62 Self {
63 diff_marker: &value.diff_marker,
64 diff_inv: &value.diff_inv,
65 lt_decomp: &value.lt_aux.lower_decomp,
66 }
67 }
68}
69
70impl<'a, T, const NUM: usize, const AUX_LEN: usize> From<&'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>>
71 for IsLtArrayAuxColsMut<'a, T>
72{
73 fn from(value: &'a mut IsLtArrayAuxCols<T, NUM, AUX_LEN>) -> Self {
74 Self {
75 diff_marker: &mut value.diff_marker,
76 diff_inv: &mut value.diff_inv,
77 lt_decomp: &mut value.lt_aux.lower_decomp,
78 }
79 }
80}
81
82#[derive(Copy, Clone, Debug)]
91pub struct IsLtArraySubAir<const NUM: usize> {
92 pub lt: IsLtSubAir,
93}
94
95impl<const NUM: usize> IsLtArraySubAir<NUM> {
96 pub fn new(bus: VariableRangeCheckerBus, max_bits: usize) -> Self {
97 Self {
98 lt: IsLtSubAir::new(bus, max_bits),
99 }
100 }
101
102 pub fn when_transition(self) -> IsLtArrayWhenTransitionAir<NUM> {
103 IsLtArrayWhenTransitionAir(self)
104 }
105
106 pub fn max_bits(&self) -> usize {
107 self.lt.max_bits
108 }
109
110 pub fn range_max_bits(&self) -> usize {
111 self.lt.range_max_bits()
112 }
113
114 fn eval_without_range_checks<AB: AirBuilder>(
117 &self,
118 builder: &mut AB,
119 io: IsLtArrayIo<AB::Expr, NUM>,
120 diff_marker: &[AB::Var],
121 diff_inv: AB::Var,
122 lt_decomp: &[AB::Var],
123 ) {
124 assert_eq!(diff_marker.len(), NUM);
125 let mut prefix_sum = AB::Expr::ZERO;
126 let mut diff_val = AB::Expr::ZERO;
127 for (x, y, &marker) in izip!(io.x, io.y, diff_marker) {
128 let diff = y - x;
129 diff_val += diff.clone() * marker.into();
130 prefix_sum += marker.into();
131 builder.assert_bool(marker);
132 builder
133 .when(io.count.clone())
134 .assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
135 builder.when(marker).assert_one(diff * diff_inv);
136 }
137 builder.assert_bool(prefix_sum.clone());
138 builder
147 .when(io.count.clone())
148 .when(not::<AB::Expr>(prefix_sum))
149 .assert_zero(io.out.clone());
150
151 self.lt
152 .eval_without_range_checks(builder, diff_val, io.out, io.count, lt_decomp);
153 }
154}
155
156impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArraySubAir<NUM> {
157 type AirContext<'a>
158 = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
159 where
160 AB::Expr: 'a,
161 AB::Var: 'a,
162 AB: 'a;
163
164 fn eval<'a>(
165 &'a self,
166 builder: &'a mut AB,
167 (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
168 ) where
169 AB::Var: 'a,
170 AB::Expr: 'a,
171 {
172 self.lt
173 .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
174 self.eval_without_range_checks(builder, io, aux.diff_marker, *aux.diff_inv, aux.lt_decomp);
175 }
176}
177
178#[derive(Copy, Clone, Debug)]
182pub struct IsLtArrayWhenTransitionAir<const NUM: usize>(pub IsLtArraySubAir<NUM>);
183
184impl<AB: InteractionBuilder, const NUM: usize> SubAir<AB> for IsLtArrayWhenTransitionAir<NUM> {
185 type AirContext<'a>
186 = (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>)
187 where
188 AB::Expr: 'a,
189 AB::Var: 'a,
190 AB: 'a;
191
192 fn eval<'a>(
193 &'a self,
194 builder: &'a mut AB,
195 (io, aux): (IsLtArrayIo<AB::Expr, NUM>, IsLtArrayAuxColsRef<'a, AB::Var>),
196 ) where
197 AB::Var: 'a,
198 AB::Expr: 'a,
199 {
200 self.0
201 .lt
202 .eval_range_checks(builder, aux.lt_decomp, io.count.clone());
203 self.0.eval_without_range_checks(
204 &mut builder.when_transition(),
205 io,
206 aux.diff_marker,
207 *aux.diff_inv,
208 aux.lt_decomp,
209 );
210 }
211}
212
213impl<F: PrimeField32, const NUM: usize> TraceSubRowGenerator<F> for IsLtArraySubAir<NUM> {
214 type TraceContext<'a> = (&'a VariableRangeCheckerChip, &'a [F], &'a [F]);
216 type ColsMut<'a> = (IsLtArrayAuxColsMut<'a, F>, &'a mut F);
218
219 #[inline(always)]
221 fn generate_subrow<'a>(
222 &'a self,
223 (range_checker, x, y): (&'a VariableRangeCheckerChip, &'a [F], &'a [F]),
224 (aux, out): (IsLtArrayAuxColsMut<'a, F>, &'a mut F),
225 ) {
226 tracing::trace!("IsLtArraySubAir::generate_subrow x={:?}, y={:?}", x, y);
227 let mut is_eq = true;
228 let mut diff_val = F::ZERO;
229 *aux.diff_inv = F::ZERO;
230 for (x_i, y_i, diff_marker) in izip!(x, y, aux.diff_marker.iter_mut()) {
231 if x_i != y_i && is_eq {
232 is_eq = false;
233 *diff_marker = F::ONE;
234 diff_val = *y_i - *x_i;
235 *aux.diff_inv = diff_val.inverse();
236 } else {
237 *diff_marker = F::ZERO;
238 }
239 }
240 let shifted_diff =
242 (diff_val + F::from_canonical_u32((1 << self.max_bits()) - 1)).as_canonical_u32();
243 let lower_u32 = shifted_diff & ((1 << self.max_bits()) - 1);
244 *out = F::from_bool(shifted_diff != lower_u32);
245
246 range_checker.decompose(lower_u32, self.max_bits(), aux.lt_decomp);
248 }
249}