use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::arch::{
    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
    VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, LocalOpcode};
use openvm_rv32im_transpiler::LessThanOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_big_array::BigArray;
use strum::IntoEnumIterator;

#[repr(C)]
#[derive(AlignedBorrow)]
pub struct LessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],
    pub cmp_result: T,

    pub opcode_slt_flag: T,
    pub opcode_sltu_flag: T,

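    // Most significant limb of b and c as a field element, range checked to lie in
    // [-2^(LIMB_BITS - 1), 2^(LIMB_BITS - 1)) for SLT and [0, 2^LIMB_BITS) for SLTU.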
    pub b_msb_f: T,
    pub c_msb_f: T,

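    // diff_marker is 1 at the most significant index i where b[i] != c[i], and 0
    // everywhere else (all zeros if b == c). When such an index exists, diff_val
    // holds c[i] - b[i] if cmp_result == 1, and b[i] - c[i] otherwise.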
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}
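
// Illustrative column assignment (added here as an example, not from the original
// source): with NUM_LIMBS = 4, LIMB_BITS = 8, and SLTU comparing b = 256 = [0, 1, 0, 0]
// against c = 255 = [255, 0, 0, 0] (little-endian limbs), the first difference scanning
// from the top limb is at index 1, so diff_marker = [0, 1, 0, 0], cmp_result = 0
// (b >= c), and diff_val = b[1] - c[1] = 1.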

#[derive(Copy, Clone, Debug)]
pub struct LessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bus: BitwiseOperationLookupBus,
    offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        LessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [cols.opcode_slt_flag, cols.opcode_sltu_flag];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let b = &cols.b;
        let c = &cols.c;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

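        // Constrain that b_msb_f and c_msb_f agree with the top limbs: b_diff must be
        // either 0 (b_msb_f = b[NUM_LIMBS - 1], the unsigned reading) or 2^LIMB_BITS
        // (b_msb_f = b[NUM_LIMBS - 1] - 2^LIMB_BITS, the signed reading), and likewise
        // for c_diff.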
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        let c_diff = c[NUM_LIMBS - 1] - cols.c_msb_f;
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));
        builder
            .assert_zero(c_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - c_diff));

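        // Scan limbs from most significant to least. Before the marked limb
        // (prefix_sum = 0) every diff must vanish; at the marked limb, diff_val must
        // equal diff, which is sign-adjusted by (2 * cmp_result - 1) so that a correct
        // cmp_result makes diff_val positive.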
        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                cols.c_msb_f - cols.b_msb_f
            } else {
                c[i] - b[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_result - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        builder.assert_bool(prefix_sum.clone());
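        // - If b != c, prefix_sum must be 1 and the single set marker sits at the first
        //   differing limb, where diff_val = diff is later range checked to be nonzero.
        // - If b == c, every diff is 0, forcing prefix_sum = 0, and then cmp_result is
        //   constrained to 0 below (b is not less than c).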
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_result);

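        // Range check b_msb_f and c_msb_f. For SLT the values are shifted up by
        // 2^(LIMB_BITS - 1), mapping [-2^(LIMB_BITS - 1), 2^(LIMB_BITS - 1)) onto
        // [0, 2^LIMB_BITS) so a single unsigned lookup serves both opcodes.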
        self.bus
            .send_range(
                cols.b_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
                cols.c_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
            )
            .eval(builder, is_valid.clone());

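        // When a differing limb exists (prefix_sum = 1), range check diff_val - 1 to
        // prove diff_val != 0, so the marked limb really does differ in the direction
        // claimed by cmp_result.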
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        let expected_opcode = flags
            .iter()
            .zip(LessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);
        let mut a: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct LessThanCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
    pub cmp_result: T,
    pub b_msb_f: T,
    pub c_msb_f: T,
    pub diff_val: T,
    pub diff_idx: usize,
    pub opcode: LessThanOpcode,
}

pub struct LessThanCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub air: LessThanCoreAir<NUM_LIMBS, LIMB_BITS>,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}

impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> LessThanCoreChip<NUM_LIMBS, LIMB_BITS> {
    pub fn new(
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
        offset: usize,
    ) -> Self {
        Self {
            air: LessThanCoreAir {
                bus: bitwise_lookup_chip.bus(),
                offset,
            },
            bitwise_lookup_chip,
        }
    }
}
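
// Construction sketch (hypothetical, not from the original file; assumes the shared
// bitwise lookup chip exposes `new(bus)` and that the opcode offset comes from the
// surrounding VM config):
//
//     let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bus);
//     let chip = LessThanCoreChip::<4, 8>::new(bitwise_chip, LessThanOpcode::CLASS_OFFSET);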

impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreChip<F, I> for LessThanCoreChip<NUM_LIMBS, LIMB_BITS>
where
    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
    I::Writes: From<[[F; NUM_LIMBS]; 1]>,
{
    type Record = LessThanCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
    type Air = LessThanCoreAir<NUM_LIMBS, LIMB_BITS>;

    #[allow(clippy::type_complexity)]
    fn execute_instruction(
        &self,
        instruction: &Instruction<F>,
        _from_pc: u32,
        reads: I::Reads,
    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
        let Instruction { opcode, .. } = instruction;
        let less_than_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));

        let data: [[F; NUM_LIMBS]; 2] = reads.into();
        let b = data[0].map(|x| x.as_canonical_u32());
        let c = data[1].map(|y| y.as_canonical_u32());
        let (cmp_result, diff_idx, b_sign, c_sign) =
            run_less_than::<NUM_LIMBS, LIMB_BITS>(less_than_opcode, &b, &c);

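        // Mirror the AIR's lookup: b_msb_range equals b_msb_f + 2^(LIMB_BITS - 1) for
        // SLT and b_msb_f itself for SLTU, while b_msb_f is the field representation
        // of the top limb (negative for signed-negative values); likewise for c.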
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]),
                b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(b[NUM_LIMBS - 1]),
                b[NUM_LIMBS - 1]
                    + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)),
            )
        };
        let (c_msb_f, c_msb_range) = if c_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - c[NUM_LIMBS - 1]),
                c[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(c[NUM_LIMBS - 1]),
                c[NUM_LIMBS - 1]
                    + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)),
            )
        };
        self.bitwise_lookup_chip
            .request_range(b_msb_range, c_msb_range);

        let diff_val = if diff_idx == NUM_LIMBS {
            0
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_result {
                c_msb_f - b_msb_f
            } else {
                b_msb_f - c_msb_f
            }
            .as_canonical_u32()
        } else if cmp_result {
            c[diff_idx] - b[diff_idx]
        } else {
            b[diff_idx] - c[diff_idx]
        };

        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip.request_range(diff_val - 1, 0);
        }

        let mut writes = [0u32; NUM_LIMBS];
        writes[0] = cmp_result as u32;

        let output = AdapterRuntimeContext::without_pc([writes.map(F::from_canonical_u32)]);
        let record = LessThanCoreRecord {
            opcode: less_than_opcode,
            b: data[0],
            c: data[1],
            cmp_result: F::from_bool(cmp_result),
            b_msb_f,
            c_msb_f,
            diff_val: F::from_canonical_u32(diff_val),
            diff_idx,
        };

        Ok((output, record))
    }

    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", LessThanOpcode::from_usize(opcode - self.air.offset))
    }

    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
        let row_slice: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
        row_slice.b = record.b;
        row_slice.c = record.c;
        row_slice.cmp_result = record.cmp_result;
        row_slice.b_msb_f = record.b_msb_f;
        row_slice.c_msb_f = record.c_msb_f;
        row_slice.diff_val = record.diff_val;
        row_slice.opcode_slt_flag = F::from_bool(record.opcode == LessThanOpcode::SLT);
        row_slice.opcode_sltu_flag = F::from_bool(record.opcode == LessThanOpcode::SLTU);
        row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx));
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}

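/// Returns `(cmp_result, diff_idx, x_sign, y_sign)`: whether x < y under `opcode`,
/// the most significant limb index at which x and y differ (NUM_LIMBS if x == y),
/// and the sign bits of x and y (always false for SLTU).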
pub(super) fn run_less_than<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    opcode: LessThanOpcode,
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> (bool, usize, bool, bool) {
    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT;
    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT;
    for i in (0..NUM_LIMBS).rev() {
        if x[i] != y[i] {
            return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign);
        }
    }
    (false, NUM_LIMBS, x_sign, y_sign)
}
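
// Minimal sanity-check sketch for `run_less_than` (added here for illustration, not
// part of the original chip; limb layout and opcode semantics follow the code above).
#[cfg(test)]
mod run_less_than_sanity {
    use super::*;

    #[test]
    fn unsigned_comparison_finds_most_significant_diff() {
        // 255 < 256 with four little-endian 8-bit limbs.
        let x: [u32; 4] = [255, 0, 0, 0];
        let y: [u32; 4] = [0, 1, 0, 0];
        let (lt, diff_idx, x_sign, y_sign) = run_less_than::<4, 8>(LessThanOpcode::SLTU, &x, &y);
        assert!(lt);
        assert_eq!(diff_idx, 1); // limbs 3 and 2 agree; limb 1 is the first difference
        assert!(!x_sign && !y_sign);
    }

    #[test]
    fn signed_comparison_respects_sign_bits() {
        // -1 (all limbs 0xff) < 0 under SLT.
        let x: [u32; 4] = [255, 255, 255, 255];
        let y: [u32; 4] = [0, 0, 0, 0];
        let (lt, diff_idx, x_sign, y_sign) = run_less_than::<4, 8>(LessThanOpcode::SLT, &x, &y);
        assert!(lt && x_sign && !y_sign);
        assert_eq!(diff_idx, 3);

        // Equal inputs: not less-than, diff_idx = NUM_LIMBS.
        let (lt, diff_idx, _, _) = run_less_than::<4, 8>(LessThanOpcode::SLT, &x, &x);
        assert!(!lt);
        assert_eq!(diff_idx, 4);
    }
}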