use std::borrow::{Borrow, BorrowMut};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_rv32im_transpiler::BranchLessThanOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;

#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BranchLessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [T; NUM_LIMBS],
    pub b: [T; NUM_LIMBS],

    // Boolean result of a op b; the branch is taken iff cmp_result = 1.
    pub cmp_result: T,
    pub imm: T,

    pub opcode_blt_flag: T,
    pub opcode_bltu_flag: T,
    pub opcode_bge_flag: T,
    pub opcode_bgeu_flag: T,

    // Most significant limb of a and b as a field element: range checked to be in
    // [-2^(LIMB_BITS - 1), 2^(LIMB_BITS - 1)) if the comparison is signed and in
    // [0, 2^LIMB_BITS) if unsigned.
    pub a_msb_f: T,
    pub b_msb_f: T,

    // 1 iff a < b (accounting for signedness); equals cmp_result for blt/bltu and
    // its negation for bge/bgeu.
    pub cmp_lt: T,

    // diff_marker is 1 at the most significant index i where a[i] != b[i] and 0
    // everywhere else. If such an i exists, diff_val = b[i] - a[i] when cmp_lt = 1
    // and a[i] - b[i] otherwise.
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}

#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct BranchLessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bus: BitwiseOperationLookupBus,
    offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        BranchLessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_blt_flag,
            cols.opcode_bltu_flag,
            cols.opcode_bge_flag,
            cols.opcode_bgeu_flag,
        ];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let lt = cols.opcode_blt_flag + cols.opcode_bltu_flag;
        let ge = cols.opcode_bge_flag + cols.opcode_bgeu_flag;
        let signed = cols.opcode_blt_flag + cols.opcode_bge_flag;
        // cmp_lt = 1 iff a < b: it equals cmp_result for blt/bltu and its negation
        // for bge/bgeu.
        builder.assert_eq(
            cols.cmp_lt,
            cols.cmp_result * lt.clone() + not(cols.cmp_result) * ge.clone(),
        );

        let a = &cols.a;
        let b = &cols.b;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        // Constrain that a_msb_f and b_msb_f equal the most significant limb either
        // as-is or shifted down by 2^LIMB_BITS, i.e. that each is a valid (possibly
        // sign-extended) field representation of that limb.
        let a_diff = a[NUM_LIMBS - 1] - cols.a_msb_f;
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        builder
            .assert_zero(a_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - a_diff));
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));

        // Scan limbs from most to least significant. diff_marker may be 1 only at
        // the first index where the limbs differ: before that point prefix_sum = 0
        // forces diff = 0, and at the marked index diff_val must equal the
        // sign-adjusted difference, which is range checked below to be non-zero.
        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                cols.b_msb_f - cols.a_msb_f
            } else {
                b[i] - a[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_lt - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        // At most one marker may be set. If none is set then a = b, so a < b must
        // be false.
        builder.assert_bool(prefix_sum.clone());
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_lt);

        // Range check a_msb_f and b_msb_f by shifting signed values up by
        // 2^(LIMB_BITS - 1) into the unsigned lookup range [0, 2^LIMB_BITS).
        self.bus
            .send_range(
                cols.a_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
                cols.b_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
            )
            .eval(builder, is_valid.clone());

        // Range check that diff_val - 1 is in [0, 2^LIMB_BITS), which ensures
        // diff_val is non-zero whenever a marker is set.
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        let expected_opcode = flags
            .iter()
            .zip(BranchLessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);

        // Branch to from_pc + imm when the comparison holds, otherwise fall
        // through to the next instruction.
        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(DEFAULT_PC_STEP);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}
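
// The expected_opcode fold in `eval` zips the flag columns with
// `BranchLessThanOpcode::iter()`, so it relies on the enum iterating in the
// order BLT, BLTU, BGE, BGEU with matching discriminants. A small guard test
// for that assumption (an illustrative addition, not part of the original
// chip tests):
#[cfg(test)]
mod opcode_order_guard {
    use super::*;

    #[test]
    fn flag_order_matches_enum_iteration() {
        // Collect discriminants in iteration order and compare against the
        // order the flag columns assume.
        let order: Vec<u8> = BranchLessThanOpcode::iter().map(|op| op as u8).collect();
        assert_eq!(
            order,
            [
                BranchLessThanOpcode::BLT as u8,
                BranchLessThanOpcode::BLTU as u8,
                BranchLessThanOpcode::BGE as u8,
                BranchLessThanOpcode::BGEU as u8,
            ]
        );
    }
}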

#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BranchLessThanCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [u8; NUM_LIMBS],
    pub b: [u8; NUM_LIMBS],
    pub imm: u32,
    pub local_opcode: u8,
}

#[derive(Clone, Copy, derive_new::new)]
pub struct BranchLessThanExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(Clone, derive_new::new)]
pub struct BranchLessThanFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub offset: usize,
}

impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceExecutor<F, ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData = ()>,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut BranchLessThanCoreRecord<NUM_LIMBS, LIMB_BITS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            BranchLessThanOpcode::from_usize(opcode - self.offset)
        )
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let &Instruction { opcode, c: imm, .. } = instruction;

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        core_record.a = rs1;
        core_record.b = rs2;
        core_record.imm = imm.as_canonical_u32();
        core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8;

        // Branch to pc + imm (computed in the field, since imm may encode a
        // negative offset) if the comparison holds; otherwise fall through.
        if run_cmp::<NUM_LIMBS, LIMB_BITS>(core_record.local_opcode, &rs1, &rs2).0 {
            *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32();
        } else {
            *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
        }

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for BranchLessThanFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        // SAFETY: the executor wrote a BranchLessThanCoreRecord at the start of
        // core_row, so reading it here is valid. Because the record and the column
        // struct overlap, the column fields below are written from highest offset
        // to lowest, after every record field they overwrite has been read.
        let record: &BranchLessThanCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        self.adapter.fill_trace_row(mem_helper, adapter_row);
        let core_row: &mut BranchLessThanCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();

        let signed = record.local_opcode == BranchLessThanOpcode::BLT as u8
            || record.local_opcode == BranchLessThanOpcode::BGE as u8;
        let ge_op = record.local_opcode == BranchLessThanOpcode::BGE as u8
            || record.local_opcode == BranchLessThanOpcode::BGEU as u8;

        let (cmp_result, diff_idx, a_sign, b_sign) =
            run_cmp::<NUM_LIMBS, LIMB_BITS>(record.local_opcode, &record.a, &record.b);

        let cmp_lt = cmp_result ^ ge_op;

        // Most significant limbs as field elements (negative when the operand is a
        // negative signed value), together with the shifted values sent to the
        // range checker, which land in [0, 2^LIMB_BITS) in all cases.
        let (a_msb_f, a_msb_range) = if a_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - record.a[NUM_LIMBS - 1] as u32),
                record.a[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(record.a[NUM_LIMBS - 1] as u32),
                record.a[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u32),
                record.b[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(record.b[NUM_LIMBS - 1] as u32),
                record.b[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };

        core_row.diff_val = if diff_idx == NUM_LIMBS {
            F::ZERO
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_lt {
                b_msb_f - a_msb_f
            } else {
                a_msb_f - b_msb_f
            }
        } else if cmp_lt {
            F::from_canonical_u8(record.b[diff_idx] - record.a[diff_idx])
        } else {
            F::from_canonical_u8(record.a[diff_idx] - record.b[diff_idx])
        };

        self.bitwise_lookup_chip
            .request_range(a_msb_range, b_msb_range);

        core_row.diff_marker = [F::ZERO; NUM_LIMBS];

        // diff_idx == NUM_LIMBS is the sentinel for a == b; otherwise mark the
        // first differing limb and range check that diff_val is non-zero.
        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip
                .request_range(core_row.diff_val.as_canonical_u32() - 1, 0);
            core_row.diff_marker[diff_idx] = F::ONE;
        }

        core_row.cmp_lt = F::from_bool(cmp_lt);
        core_row.b_msb_f = b_msb_f;
        core_row.a_msb_f = a_msb_f;
        core_row.opcode_bgeu_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BGEU as u8);
        core_row.opcode_bge_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BGE as u8);
        core_row.opcode_bltu_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BLTU as u8);
        core_row.opcode_blt_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BLT as u8);

        core_row.imm = F::from_canonical_u32(record.imm);
        core_row.cmp_result = F::from_bool(cmp_result);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = record.a.map(F::from_canonical_u8);
    }
}
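
// Illustrative sketch (not part of the original file) of the MSB encoding used
// by `fill_trace_row` above: for a signed comparison, a top limb m with its
// high bit set represents the value m - 2^LIMB_BITS, and the value sent to the
// range checker is that value shifted up by 2^(LIMB_BITS - 1). The hypothetical
// helper below mirrors that arithmetic over plain integers.
#[cfg(test)]
mod msb_encoding_sketch {
    const LIMB_BITS: u32 = 8;

    // Mirror of the filler's (msb_f, msb_range) computation for signed
    // comparisons, using i64 in place of field elements.
    fn signed_msb_encoding(m: u32) -> (i64, u32) {
        if m >> (LIMB_BITS - 1) == 1 {
            // Negative limb: value m - 2^LIMB_BITS, range value m - 2^(LIMB_BITS - 1).
            (m as i64 - (1 << LIMB_BITS), m - (1 << (LIMB_BITS - 1)))
        } else {
            // Non-negative limb: value m, range value m + 2^(LIMB_BITS - 1).
            (m as i64, m + (1 << (LIMB_BITS - 1)))
        }
    }

    #[test]
    fn shifted_value_is_in_range() {
        for m in 0..(1u32 << LIMB_BITS) {
            let (value, range) = signed_msb_encoding(m);
            // The range-checked value is the signed value shifted by 2^(LIMB_BITS - 1)...
            assert_eq!(i64::from(range), value + (1 << (LIMB_BITS - 1)));
            // ...and always fits in LIMB_BITS bits, as the bitwise lookup requires.
            assert!(range < (1 << LIMB_BITS));
        }
    }
}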

/// Returns `(cmp_result, diff_idx, x_sign, y_sign)`, where `diff_idx` is the most
/// significant index at which `x` and `y` differ (or `NUM_LIMBS` if `x == y`).
#[inline(always)]
pub(super) fn run_cmp<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    local_opcode: u8,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> (bool, usize, bool, bool) {
    let signed = local_opcode == BranchLessThanOpcode::BLT as u8
        || local_opcode == BranchLessThanOpcode::BGE as u8;
    let ge_op = local_opcode == BranchLessThanOpcode::BGE as u8
        || local_opcode == BranchLessThanOpcode::BGEU as u8;
    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
    for i in (0..NUM_LIMBS).rev() {
        if x[i] != y[i] {
            return ((x[i] < y[i]) ^ x_sign ^ y_sign ^ ge_op, i, x_sign, y_sign);
        }
    }
    (ge_op, NUM_LIMBS, x_sign, y_sign)
}
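
// A minimal sanity check for `run_cmp` (module and test values are illustrative
// additions, not part of the original chip tests). Operands are little-endian
// limbs: [1, 0, 0, 0] encodes 1, and [0xff; 4] encodes -1 when signed or
// u32::MAX when unsigned.
#[cfg(test)]
mod run_cmp_sanity {
    use super::*;

    const ONE: [u8; 4] = [1, 0, 0, 0];
    const NEG_ONE: [u8; 4] = [0xff; 4];

    #[test]
    fn blt_is_signed() {
        // -1 < 1 under signed comparison; the limbs first differ at the MSB (index 3).
        let (res, idx, x_sign, y_sign) =
            run_cmp::<4, 8>(BranchLessThanOpcode::BLT as u8, &NEG_ONE, &ONE);
        assert!(res);
        assert_eq!(idx, 3);
        assert!(x_sign && !y_sign);
    }

    #[test]
    fn bltu_is_unsigned() {
        // u32::MAX < 1 is false under unsigned comparison, and the sign flags stay unset.
        let (res, _, x_sign, y_sign) =
            run_cmp::<4, 8>(BranchLessThanOpcode::BLTU as u8, &NEG_ONE, &ONE);
        assert!(!res && !x_sign && !y_sign);
    }

    #[test]
    fn equal_operands_return_ge_op_and_sentinel_index() {
        // When all limbs agree, diff_idx is the sentinel NUM_LIMBS and the result
        // is ge_op: BGE branches on equality, BLT does not.
        let (res, idx, _, _) = run_cmp::<4, 8>(BranchLessThanOpcode::BGE as u8, &ONE, &ONE);
        assert!(res);
        assert_eq!(idx, 4);
        let (res, idx, _, _) = run_cmp::<4, 8>(BranchLessThanOpcode::BLT as u8, &ONE, &ONE);
        assert!(!res);
        assert_eq!(idx, 4);
    }
}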