1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4 arch::*,
5 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
6};
7use openvm_circuit_primitives::{
8 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
9 utils::not,
10 AlignedBytesBorrow,
11};
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
14use openvm_rv32im_transpiler::BranchLessThanOpcode;
15use openvm_stark_backend::{
16 interaction::InteractionBuilder,
17 p3_air::{AirBuilder, BaseAir},
18 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
19 rap::BaseAirWithPublicValues,
20};
21use strum::IntoEnumIterator;
22
/// Core trace columns for one row of the branch-less-than chip.
///
/// The layout is `#[repr(C)]` and is borrowed directly from a flat row slice, so the
/// field order below is part of the trace format and must not be rearranged.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BranchLessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Limbs of the first operand; index `NUM_LIMBS - 1` is treated as the most
    /// significant limb by the comparison logic.
    pub a: [T; NUM_LIMBS],
    /// Limbs of the second operand, same limb order as `a`.
    pub b: [T; NUM_LIMBS],

    /// Boolean outcome of the opcode's comparison: when 1 the next pc is
    /// `from_pc + imm`, otherwise `from_pc + DEFAULT_PC_STEP`.
    pub cmp_result: T,
    /// Immediate of the instruction, added to the pc when `cmp_result` is 1.
    pub imm: T,

    /// One flag per supported opcode. Each is constrained boolean and their sum
    /// (the row's validity) is constrained boolean as well.
    pub opcode_blt_flag: T,
    pub opcode_bltu_flag: T,
    pub opcode_bge_flag: T,
    pub opcode_bgeu_flag: T,

    /// Field representation of the most significant limb of `a`. Constrained to differ
    /// from `a[NUM_LIMBS - 1]` by either 0 or `2^LIMB_BITS`; the trace filler uses the
    /// latter (i.e. a negative field value) when the operand is negative under a signed
    /// opcode.
    pub a_msb_f: T,
    /// Same as `a_msb_f`, for operand `b`.
    pub b_msb_f: T,

    /// 1 when `a < b` under the opcode's signedness. Tied to `cmp_result`: equal to it
    /// for the BLT/BLTU flags and to its negation for the BGE/BGEU flags.
    pub cmp_lt: T,

    /// Boolean markers with a boolean sum: at most one entry is 1, and it flags the most
    /// significant limb index at which `a` and `b` differ. All zero when the operands
    /// are equal.
    pub diff_marker: [T; NUM_LIMBS],
    /// At the marked limb: `b - a` when `cmp_lt` is 1 and `a - b` otherwise (using the
    /// `*_msb_f` values for the top limb). The filler writes 0 when no limb is marked.
    pub diff_val: T,
}
51
/// AIR for the core (adapter-independent) part of the branch-less-than chip.
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct BranchLessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bus on which the `send_range` lookups of `eval` are issued.
    pub bus: BitwiseOperationLookupBus,
    /// Value added to the local opcode index to form the opcode reported to the
    /// adapter; also returned by `start_offset`.
    offset: usize,
}
57
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    /// Number of core columns, taken from the `BranchLessThanCoreCols` layout so the two
    /// cannot drift apart.
    fn width(&self) -> usize {
        BranchLessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    // Intentionally empty: this AIR overrides nothing and relies on the trait's
    // provided items.
}
69
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    /// Adds the core constraints for one row and returns what the adapter needs.
    ///
    /// `local_core` is reinterpreted as [`BranchLessThanCoreCols`]. The returned context
    /// carries the next pc, the two operand reads (`a`, `b`), no writes, and the
    /// processed instruction (validity, expected opcode, immediate).
    ///
    /// NOTE(review): the statements below register constraints and bus interactions in a
    /// fixed sequence; reordering them or reshaping the expressions presumably changes
    /// the resulting AIR, so treat the order as significant.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        // Order matters: it is zipped with `BranchLessThanOpcode::iter()` further down.
        let flags = [
            cols.opcode_blt_flag,
            cols.opcode_bltu_flag,
            cols.opcode_bge_flag,
            cols.opcode_bgeu_flag,
        ];

        // Each flag is boolean and so is their sum, hence at most one flag is set.
        // The sum doubles as the row's validity indicator.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let lt = cols.opcode_blt_flag + cols.opcode_bltu_flag;
        let ge = cols.opcode_bge_flag + cols.opcode_bgeu_flag;
        let signed = cols.opcode_blt_flag + cols.opcode_bge_flag;
        // cmp_lt equals cmp_result for the "less than" opcodes and its negation for the
        // "greater or equal" opcodes (and is 0 on rows with no flag set).
        builder.assert_eq(
            cols.cmp_lt,
            cols.cmp_result * lt.clone() + not(cols.cmp_result) * ge.clone(),
        );

        let a = &cols.a;
        let b = &cols.b;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        // The top limb and its field representation differ by 0 or by 2^LIMB_BITS,
        // i.e. `*_msb_f` is the limb itself or the limb minus 2^LIMB_BITS.
        let a_diff = a[NUM_LIMBS - 1] - cols.a_msb_f;
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        builder.assert_zero(a_diff.clone() * (AB::Expr::from_u32(1 << LIMB_BITS) - a_diff));
        builder.assert_zero(b_diff.clone() * (AB::Expr::from_u32(1 << LIMB_BITS) - b_diff));

        // Walk limbs from most to least significant.
        for i in (0..NUM_LIMBS).rev() {
            // (b - a) when cmp_lt = 1, (a - b) when cmp_lt = 0; the top limb uses the
            // `*_msb_f` representations instead of the raw limbs.
            let diff = (if i == NUM_LIMBS - 1 {
                cols.b_msb_f - cols.a_msb_f
            } else {
                b[i] - a[i]
            }) * (AB::Expr::from_u8(2) * cols.cmp_lt - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            // While no marker has been seen at or above limb i, the limbs must be equal.
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            // At the marked limb the signed difference is recorded in diff_val.
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        // At most one marker overall; with no marker (all limbs equal) cmp_lt must be 0.
        builder.assert_bool(prefix_sum.clone());
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_lt);

        // Lookup on the top-limb representations, offset by 2^(LIMB_BITS - 1) for signed
        // opcodes, enabled on valid rows.
        // NOTE(review): presumably `send_range` range-checks both arguments, which would
        // bound `*_msb_f` to the signed or unsigned limb range — confirm against the bus.
        self.bus
            .send_range(
                cols.a_msb_f + AB::Expr::from_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
                cols.b_msb_f + AB::Expr::from_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
            )
            .eval(builder, is_valid.clone());

        // Lookup on `diff_val - 1`, enabled only when a marker is set.
        // NOTE(review): presumably this forces diff_val to be non-zero (and small) at the
        // marked limb — confirm against the bus semantics.
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        // Selected opcode's enum discriminant plus the chip's global offset.
        let expected_opcode = flags
            .iter()
            .zip(BranchLessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_u8(opcode as u8)
            })
            + AB::Expr::from_usize(self.offset);

        // Jump by `imm` when the comparison holds, otherwise step to the next instruction.
        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_u32(DEFAULT_PC_STEP);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    /// Returns the opcode offset this AIR was constructed with.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
183
/// Compact per-instruction record written by the executor and later expanded into a
/// [`BranchLessThanCoreCols`] row by the trace filler.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BranchLessThanCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Limbs of the first operand as read by the adapter.
    pub a: [u8; NUM_LIMBS],
    /// Limbs of the second operand as read by the adapter.
    pub b: [u8; NUM_LIMBS],
    /// Canonical `u32` form of the instruction's `c` operand (the branch immediate).
    pub imm: u32,
    /// Opcode index relative to the chip's offset, truncated to `u8`.
    pub local_opcode: u8,
}
192
/// Preflight executor for the branch-less-than opcodes, generic over the adapter `A`
/// that performs the operand reads.
#[derive(Clone, Copy, derive_new::new)]
pub struct BranchLessThanExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Adapter used to read the two operands and record its own trace data.
    adapter: A,
    /// Global opcode offset subtracted to obtain the local opcode index.
    pub offset: usize,
}
198
/// Trace filler that expands executor records into full core rows.
#[derive(Clone, derive_new::new)]
pub struct BranchLessThanFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Adapter filler responsible for the adapter part of each row.
    adapter: A,
    /// Lookup chip that receives the `request_range` calls made while filling rows.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    /// Global opcode offset of this chip.
    pub offset: usize,
}
205
206impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
207 for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
208where
209 F: PrimeField32,
210 A: 'static + AdapterTraceExecutor<F, ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData = ()>,
211 for<'buf> RA: RecordArena<
212 'buf,
213 EmptyAdapterCoreLayout<F, A>,
214 (
215 A::RecordMut<'buf>,
216 &'buf mut BranchLessThanCoreRecord<NUM_LIMBS, LIMB_BITS>,
217 ),
218 >,
219{
220 fn get_opcode_name(&self, opcode: usize) -> String {
221 format!(
222 "{:?}",
223 BranchLessThanOpcode::from_usize(opcode - self.offset)
224 )
225 }
226
227 fn execute(
228 &self,
229 state: VmStateMut<F, TracingMemory, RA>,
230 instruction: &Instruction<F>,
231 ) -> Result<(), ExecutionError> {
232 let &Instruction { opcode, c: imm, .. } = instruction;
233
234 let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
235
236 A::start(*state.pc, state.memory, &mut adapter_record);
237
238 let [rs1, rs2] = self
239 .adapter
240 .read(state.memory, instruction, &mut adapter_record)
241 .into();
242
243 core_record.a = rs1;
244 core_record.b = rs2;
245 core_record.imm = imm.as_canonical_u32();
246 core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8;
247
248 if run_cmp::<NUM_LIMBS, LIMB_BITS>(core_record.local_opcode, &rs1, &rs2).0 {
249 *state.pc = (F::from_u32(*state.pc) + imm).as_canonical_u32();
250 } else {
251 *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
252 }
253
254 Ok(())
255 }
256}
257
258impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
259 for BranchLessThanFiller<A, NUM_LIMBS, LIMB_BITS>
260where
261 F: PrimeField32,
262 A: 'static + AdapterTraceFiller<F>,
263{
264 fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
265 let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
268
269 let record: &BranchLessThanCoreRecord<NUM_LIMBS, LIMB_BITS> =
272 unsafe { get_record_from_slice(&mut core_row, ()) };
273
274 self.adapter.fill_trace_row(mem_helper, adapter_row);
275 let core_row: &mut BranchLessThanCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
276
277 let signed = record.local_opcode == BranchLessThanOpcode::BLT as u8
278 || record.local_opcode == BranchLessThanOpcode::BGE as u8;
279 let ge_op = record.local_opcode == BranchLessThanOpcode::BGE as u8
280 || record.local_opcode == BranchLessThanOpcode::BGEU as u8;
281
282 let (cmp_result, diff_idx, a_sign, b_sign) =
283 run_cmp::<NUM_LIMBS, LIMB_BITS>(record.local_opcode, &record.a, &record.b);
284
285 let cmp_lt = cmp_result ^ ge_op;
286
287 let (a_msb_f, a_msb_range) = if a_sign {
290 (
291 -F::from_u32((1 << LIMB_BITS) - record.a[NUM_LIMBS - 1] as u32),
292 record.a[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)),
293 )
294 } else {
295 (
296 F::from_u32(record.a[NUM_LIMBS - 1] as u32),
297 record.a[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)),
298 )
299 };
300 let (b_msb_f, b_msb_range) = if b_sign {
301 (
302 -F::from_u32((1 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u32),
303 record.b[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)),
304 )
305 } else {
306 (
307 F::from_u32(record.b[NUM_LIMBS - 1] as u32),
308 record.b[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)),
309 )
310 };
311
312 core_row.diff_val = if diff_idx == NUM_LIMBS {
313 F::ZERO
314 } else if diff_idx == (NUM_LIMBS - 1) {
315 if cmp_lt {
316 b_msb_f - a_msb_f
317 } else {
318 a_msb_f - b_msb_f
319 }
320 } else if cmp_lt {
321 F::from_u8(record.b[diff_idx] - record.a[diff_idx])
322 } else {
323 F::from_u8(record.a[diff_idx] - record.b[diff_idx])
324 };
325
326 self.bitwise_lookup_chip
327 .request_range(a_msb_range, b_msb_range);
328
329 core_row.diff_marker = [F::ZERO; NUM_LIMBS];
330
331 if diff_idx != NUM_LIMBS {
332 self.bitwise_lookup_chip
333 .request_range(core_row.diff_val.as_canonical_u32() - 1, 0);
334 core_row.diff_marker[diff_idx] = F::ONE;
335 }
336
337 core_row.cmp_lt = F::from_bool(cmp_lt);
338 core_row.b_msb_f = b_msb_f;
339 core_row.a_msb_f = a_msb_f;
340 core_row.opcode_bgeu_flag =
341 F::from_bool(record.local_opcode == BranchLessThanOpcode::BGEU as u8);
342 core_row.opcode_bge_flag =
343 F::from_bool(record.local_opcode == BranchLessThanOpcode::BGE as u8);
344 core_row.opcode_bltu_flag =
345 F::from_bool(record.local_opcode == BranchLessThanOpcode::BLTU as u8);
346 core_row.opcode_blt_flag =
347 F::from_bool(record.local_opcode == BranchLessThanOpcode::BLT as u8);
348
349 core_row.imm = F::from_u32(record.imm);
350 core_row.cmp_result = F::from_bool(cmp_result);
351 core_row.b = record.b.map(F::from_u8);
352 core_row.a = record.a.map(F::from_u8);
353 }
354}
355
356#[inline(always)]
358pub(super) fn run_cmp<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
359 local_opcode: u8,
360 x: &[u8; NUM_LIMBS],
361 y: &[u8; NUM_LIMBS],
362) -> (bool, usize, bool, bool) {
363 let signed = local_opcode == BranchLessThanOpcode::BLT as u8
364 || local_opcode == BranchLessThanOpcode::BGE as u8;
365 let ge_op = local_opcode == BranchLessThanOpcode::BGE as u8
366 || local_opcode == BranchLessThanOpcode::BGEU as u8;
367 let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
368 let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
369 for i in (0..NUM_LIMBS).rev() {
370 if x[i] != y[i] {
371 return ((x[i] < y[i]) ^ x_sign ^ y_sign ^ ge_op, i, x_sign, y_sign);
372 }
373 }
374 (ge_op, NUM_LIMBS, x_sign, y_sign)
375}