use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_rv32im_transpiler::LessThanOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;

#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct LessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],
    pub cmp_result: T,

    pub opcode_slt_flag: T,
    pub opcode_sltu_flag: T,

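    // Most significant limb of b and c respectively, as a field element. Its range is
    // constrained by the AIR below: [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) for SLT (signed)
    // and [0, 2^LIMB_BITS) for SLTU (unsigned).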
    pub b_msb_f: T,
    pub c_msb_f: T,

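    // 1 at the most significant index i such that b[i] != c[i], otherwise all zeros.
    // If such an i exists, diff_val = c[i] - b[i] if cmp_result == 1, else b[i] - c[i]
    // (with b_msb_f and c_msb_f standing in for the top limbs).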
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}

#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct LessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bus: BitwiseOperationLookupBus,
    offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        LessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [cols.opcode_slt_flag, cols.opcode_sltu_flag];

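        // Each opcode flag is boolean and at most one may be set, so their sum
        // is a boolean validity indicator for the row.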
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let b = &cols.b;
        let c = &cols.c;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

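        // Constrain that b_msb_f is either b[NUM_LIMBS - 1] (the non-negative case) or
        // b[NUM_LIMBS - 1] - 2^LIMB_BITS (the negative representative in the field),
        // and likewise for c_msb_f: each difference must be 0 or 2^LIMB_BITS.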
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        let c_diff = c[NUM_LIMBS - 1] - cols.c_msb_f;
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));
        builder
            .assert_zero(c_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - c_diff));

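        // Scan limbs from most significant to least significant. diff is the limb
        // difference, negated when cmp_result == 0. Until a marker is seen
        // (prefix_sum == 0) diff must vanish, so a marked index can only be the most
        // significant index where the limbs differ, and there diff must equal diff_val.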
        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                cols.c_msb_f - cols.b_msb_f
            } else {
                c[i] - b[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_result - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
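        // - If b != c, then prefix_sum must be 1: the marker picks out the first index
        //   where diff != 0, and diff_val is range checked below to be non-zero.
        // - If b == c, then every diff is zero, so prefix_sum must be 0 and the
        //   comparison result is forced to 0.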
        builder.assert_bool(prefix_sum.clone());
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_result);

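        // Range check b_msb_f and c_msb_f. For SLT, adding 2^(LIMB_BITS - 1) shifts the
        // signed range [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) onto [0, 2^LIMB_BITS).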
        self.bus
            .send_range(
                cols.b_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
                cols.c_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
            )
            .eval(builder, is_valid.clone());

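        // When a marker is set, range check diff_val - 1 to ensure diff_val is non-zero.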
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        let expected_opcode = flags
            .iter()
            .zip(LessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);
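        // The written value holds cmp_result in its least significant limb; all other
        // limbs are zero.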
        let mut a: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LessThanCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [u8; NUM_LIMBS],
    pub c: [u8; NUM_LIMBS],
    pub local_opcode: u8,
}

#[derive(Clone, Copy, derive_new::new)]
pub struct LessThanExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(Clone, derive_new::new)]
pub struct LessThanFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub offset: usize,
}

impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for LessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut LessThanCoreRecord<NUM_LIMBS, LIMB_BITS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", LessThanOpcode::from_usize(opcode - self.offset))
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        debug_assert!(LIMB_BITS <= 8);
        let Instruction { opcode, .. } = instruction;

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        core_record.b = rs1;
        core_record.c = rs2;
        core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8;

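        // Only the comparison bit is needed to produce the output; the sign and limb
        // difference information is recomputed later during trace filling.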
        let (cmp_result, _, _, _) = run_less_than::<NUM_LIMBS, LIMB_BITS>(
            core_record.local_opcode == LessThanOpcode::SLT as u8,
            &rs1,
            &rs2,
        );

        let mut output = [0u8; NUM_LIMBS];
        output[0] = cmp_result as u8;

        self.adapter.write(
            state.memory,
            instruction,
            [output].into(),
            &mut adapter_record,
        );

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for LessThanFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
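        // SAFETY: the caller guarantees that row_slice is at least
        // A::WIDTH + LessThanCoreCols::width() elements wide.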
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
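        // SAFETY: core_row contains a valid LessThanCoreRecord written by the executor
        // during preflight execution.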
        let record: &LessThanCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut LessThanCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();

        let is_slt = record.local_opcode == LessThanOpcode::SLT as u8;
        let (cmp_result, diff_idx, b_sign, c_sign) =
            run_less_than::<NUM_LIMBS, LIMB_BITS>(is_slt, &record.b, &record.c);

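        // If the top limb is negative under SLT, represent it in the field as
        // b[NUM_LIMBS - 1] - 2^LIMB_BITS and range check the value shifted up by
        // 2^(LIMB_BITS - 1); otherwise the limb itself is range checked, shifted
        // only when the opcode is signed.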
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_canonical_u16((1u16 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u16),
                record.b[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u8(record.b[NUM_LIMBS - 1]),
                record.b[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)),
            )
        };
        let (c_msb_f, c_msb_range) = if c_sign {
            (
                -F::from_canonical_u16((1u16 << LIMB_BITS) - record.c[NUM_LIMBS - 1] as u16),
                record.c[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u8(record.c[NUM_LIMBS - 1]),
                record.c[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)),
            )
        };

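        // diff_val is the positive difference at the first differing limb, using the
        // field representations of the top limbs; it is zero when b == c.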
        core_row.diff_val = if diff_idx == NUM_LIMBS {
            F::ZERO
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_result {
                c_msb_f - b_msb_f
            } else {
                b_msb_f - c_msb_f
            }
        } else if cmp_result {
            F::from_canonical_u8(record.c[diff_idx] - record.b[diff_idx])
        } else {
            F::from_canonical_u8(record.b[diff_idx] - record.c[diff_idx])
        };

        self.bitwise_lookup_chip
            .request_range(b_msb_range as u32, c_msb_range as u32);

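        // Mark the first differing limb and, mirroring the AIR's non-zero check,
        // range check diff_val - 1 only when b != c.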
        core_row.diff_marker = [F::ZERO; NUM_LIMBS];
        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip
                .request_range(core_row.diff_val.as_canonical_u32() - 1, 0);
            core_row.diff_marker[diff_idx] = F::ONE;
        }

        core_row.c_msb_f = c_msb_f;
        core_row.b_msb_f = b_msb_f;
        core_row.opcode_sltu_flag = F::from_bool(!is_slt);
        core_row.opcode_slt_flag = F::from_bool(is_slt);
        core_row.cmp_result = F::from_bool(cmp_result);
        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
    }
}

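// Returns (cmp_result, diff_idx, x_sign, y_sign), where diff_idx is the most
// significant index at which x and y differ (NUM_LIMBS if x == y).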
#[inline(always)]
pub(super) fn run_less_than<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    is_slt: bool,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> (bool, usize, bool, bool) {
    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt;
    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt;
    for i in (0..NUM_LIMBS).rev() {
        if x[i] != y[i] {
            return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign);
        }
    }
    (false, NUM_LIMBS, x_sign, y_sign)
}
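
// Illustrative sanity checks for `run_less_than`, added here as a minimal sketch
// (not part of the original module). They assume 4 limbs of 8 bits each, i.e.
// RV32-style 32-bit words with little-endian limb order.
#[cfg(test)]
mod run_less_than_sanity {
    use super::run_less_than;

    #[test]
    fn matches_native_signed_and_unsigned_comparison() {
        // -1 (0xFFFFFFFF) vs 1 (0x00000001), limbs in little-endian order.
        let x = [0xFF, 0xFF, 0xFF, 0xFF];
        let y = [0x01, 0x00, 0x00, 0x00];
        // Unsigned (SLTU): 0xFFFFFFFF < 1 is false.
        assert!(!run_less_than::<4, 8>(false, &x, &y).0);
        // Signed (SLT): -1 < 1 is true.
        assert!(run_less_than::<4, 8>(true, &x, &y).0);
        // Equal operands: not less-than, and diff_idx == NUM_LIMBS.
        let (lt, diff_idx, _, _) = run_less_than::<4, 8>(true, &x, &x);
        assert!(!lt);
        assert_eq!(diff_idx, 4);
    }
}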