1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7 AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface,
8 VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12 utils::not,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
16use openvm_rv32im_transpiler::BranchLessThanOpcode;
17use openvm_stark_backend::{
18 interaction::InteractionBuilder,
19 p3_air::{AirBuilder, BaseAir},
20 p3_field::{Field, FieldAlgebra, PrimeField32},
21 rap::BaseAirWithPublicValues,
22};
23use serde::{Deserialize, Serialize};
24use serde_big_array::BigArray;
25use strum::IntoEnumIterator;
26
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BranchLessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // Limbs of the first operand (rs1).
    pub a: [T; NUM_LIMBS],
    // Limbs of the second operand (rs2).
    pub b: [T; NUM_LIMBS],

    // Boolean: 1 iff the branch is taken for the selected opcode.
    pub cmp_result: T,
    // Branch immediate; added to the pc when the branch is taken.
    pub imm: T,

    // One-hot opcode selectors; at most one is set, and their sum acts as
    // the row's is_valid flag in the AIR.
    pub opcode_blt_flag: T,
    pub opcode_bltu_flag: T,
    pub opcode_bge_flag: T,
    pub opcode_bgeu_flag: T,

    // Most significant limb of `a` (resp. `b`) as a field element: constrained
    // to equal a[NUM_LIMBS-1] or a[NUM_LIMBS-1] - 2^LIMB_BITS, the latter
    // encoding a negative value under signed comparison.
    pub a_msb_f: T,
    pub b_msb_f: T,

    // Boolean: 1 iff a < b under the opcode's signedness. Equals cmp_result
    // for blt/bltu and its negation for bge/bgeu.
    pub cmp_lt: T,

    // At most one entry set: marks the most significant limb index where
    // a and b differ (all zero when a == b).
    pub diff_marker: [T; NUM_LIMBS],
    // Oriented limb difference at the marked index; range-checked (via
    // diff_val - 1) to lie in [1, 2^LIMB_BITS] when a marker is set.
    pub diff_val: T,
}
55
/// Core AIR for the RV32 branch-less-than instruction family
/// (BLT/BLTU/BGE/BGEU). Uses a bitwise-operation lookup bus for range checks.
#[derive(Copy, Clone, Debug)]
pub struct BranchLessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // Bus to the shared bitwise/range lookup chip.
    pub bus: BitwiseOperationLookupBus,
    // Global opcode index of the first opcode in this family.
    offset: usize,
}
61
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    /// Trace width: one column per field of `BranchLessThanCoreCols`.
    fn width(&self) -> usize {
        BranchLessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// This AIR exposes no public values; the default (empty) implementation suffices.
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
73
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    /// Constrain one row of the branch-less-than core trace: prove that
    /// `cmp_result` is the correct BLT/BLTU/BGE/BGEU comparison of `a` and `b`,
    /// and compute the next pc accordingly.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_blt_flag,
            cols.opcode_bltu_flag,
            cols.opcode_bge_flag,
            cols.opcode_bgeu_flag,
        ];

        // Each opcode flag is boolean and at most one is set; their sum doubles
        // as the row's is_valid indicator.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let lt = cols.opcode_blt_flag + cols.opcode_bltu_flag;
        let ge = cols.opcode_bge_flag + cols.opcode_bgeu_flag;
        let signed = cols.opcode_blt_flag + cols.opcode_bge_flag;
        // cmp_lt encodes "a < b": it equals cmp_result for the lt opcodes and
        // the negation of cmp_result for the ge opcodes.
        builder.assert_eq(
            cols.cmp_lt,
            cols.cmp_result * lt.clone() + not(cols.cmp_result) * ge.clone(),
        );

        let a = &cols.a;
        let b = &cols.b;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        // a_msb_f (resp. b_msb_f) must equal the top limb either as-is (unsigned
        // view) or reduced by 2^LIMB_BITS (negative signed view): the difference
        // is constrained to be 0 or 2^LIMB_BITS.
        let a_diff = a[NUM_LIMBS - 1] - cols.a_msb_f;
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        builder
            .assert_zero(a_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - a_diff));
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));

        // Scan limbs from most to least significant. diff_marker may be set at
        // most once (prefix_sum is boolean), at the most significant limb where
        // a and b differ; all limbs above the marker must be equal. The factor
        // (2*cmp_lt - 1) orients the difference so that diff_val is positive
        // exactly when the claimed comparison direction holds.
        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                // Top limbs are compared via their sign-adjusted field values.
                cols.b_msb_f - cols.a_msb_f
            } else {
                b[i] - a[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_lt - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            // Before the marker is reached (prefix_sum == 0), limbs must agree.
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        builder.assert_bool(prefix_sum.clone());
        // If no limb differs (a == b), then "a < b" must be false.
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_lt);

        // Range-check the most significant limbs to LIMB_BITS bits. For signed
        // opcodes the msb_f values are first shifted by 2^(LIMB_BITS-1) to map
        // [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) into the unsigned range.
        self.bus
            .send_range(
                cols.a_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
                cols.b_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
            )
            .eval(builder, is_valid.clone());

        // When a marker is set (prefix_sum == 1), diff_val - 1 must be in range,
        // i.e. diff_val >= 1: the inequality at the differing limb is strict.
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        // Reconstruct the global opcode index from the one-hot flags.
        let expected_opcode = flags
            .iter()
            .zip(BranchLessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);

        // Branch taken: jump by imm; otherwise fall through by DEFAULT_PC_STEP.
        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(DEFAULT_PC_STEP);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    /// Global opcode index of the first opcode handled by this AIR.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
188
/// Execution record for one branch-less-than instruction; holds everything
/// `generate_trace_row` needs to fill a `BranchLessThanCoreCols` row.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BranchLessThanCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // First operand limbs as read by the adapter.
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    // Second operand limbs as read by the adapter.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    // Whether the branch was taken.
    pub cmp_result: T,
    // Whether a < b under the opcode's signedness.
    pub cmp_lt: T,
    // Branch immediate from the instruction's `c` operand.
    pub imm: T,
    // Sign-adjusted most significant limbs (see the cols struct).
    pub a_msb_f: T,
    pub b_msb_f: T,
    // Oriented difference at the first differing limb (0 when a == b).
    pub diff_val: T,
    // Index of the most significant differing limb; NUM_LIMBS when a == b.
    pub diff_idx: usize,
    // Which of BLT/BLTU/BGE/BGEU was executed.
    pub opcode: BranchLessThanOpcode,
}
205
/// Core chip pairing the branch-less-than AIR with the shared bitwise lookup
/// chip used to satisfy its range-check interactions at runtime.
pub struct BranchLessThanCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub air: BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
210
211impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> BranchLessThanCoreChip<NUM_LIMBS, LIMB_BITS> {
212 pub fn new(
213 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
214 offset: usize,
215 ) -> Self {
216 Self {
217 air: BranchLessThanCoreAir {
218 bus: bitwise_lookup_chip.bus(),
219 offset,
220 },
221 bitwise_lookup_chip,
222 }
223 }
224}
225
impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreChip<F, I> for BranchLessThanCoreChip<NUM_LIMBS, LIMB_BITS>
where
    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
    I::Writes: Default,
{
    type Record = BranchLessThanCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
    type Air = BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>;

    /// Execute one branch-less-than instruction: compare the two operands,
    /// request the range checks that mirror the AIR's lookup interactions,
    /// and produce the runtime context (next pc) plus a trace record.
    #[allow(clippy::type_complexity)]
    fn execute_instruction(
        &self,
        instruction: &Instruction<F>,
        from_pc: u32,
        reads: I::Reads,
    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
        let Instruction { opcode, c: imm, .. } = *instruction;
        let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));

        let data: [[F; NUM_LIMBS]; 2] = reads.into();
        let a = data[0].map(|x| x.as_canonical_u32());
        let b = data[1].map(|y| y.as_canonical_u32());
        // cmp_result is the branch condition; diff_idx is the most significant
        // differing limb index (NUM_LIMBS if a == b).
        let (cmp_result, diff_idx, a_sign, b_sign) =
            run_cmp::<NUM_LIMBS, LIMB_BITS>(blt_opcode, &a, &b);

        let signed = matches!(
            blt_opcode,
            BranchLessThanOpcode::BLT | BranchLessThanOpcode::BGE
        );
        let ge_opcode = matches!(
            blt_opcode,
            BranchLessThanOpcode::BGE | BranchLessThanOpcode::BGEU
        );
        // cmp_lt ("a < b") is cmp_result for lt opcodes, its negation for ge.
        let cmp_lt = cmp_result ^ ge_opcode;

        // a_msb_f is the top limb as a field element, shifted down by
        // 2^LIMB_BITS when negative (written as a field negation here);
        // a_msb_range is the corresponding value sent to the range checker,
        // offset by 2^(LIMB_BITS-1) for signed opcodes — matching the AIR's
        // send_range interaction.
        let (a_msb_f, a_msb_range) = if a_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - a[NUM_LIMBS - 1]),
                a[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(a[NUM_LIMBS - 1]),
                a[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]),
                b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(b[NUM_LIMBS - 1]),
                b[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };
        self.bitwise_lookup_chip
            .request_range(a_msb_range, b_msb_range);

        // Oriented difference at diff_idx; positive by construction because
        // cmp_lt fixes the subtraction direction. The top-limb case goes
        // through the field to account for the sign adjustment.
        let diff_val = if diff_idx == NUM_LIMBS {
            // Operands are equal; the AIR leaves diff_marker all-zero.
            0
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_lt {
                b_msb_f - a_msb_f
            } else {
                a_msb_f - b_msb_f
            }
            .as_canonical_u32()
        } else if cmp_lt {
            b[diff_idx] - a[diff_idx]
        } else {
            a[diff_idx] - b[diff_idx]
        };

        // Mirrors the AIR's second send_range, gated on a marker being set.
        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip.request_range(diff_val - 1, 0);
        }

        let output = AdapterRuntimeContext {
            // Only redirect the pc when the branch is taken; otherwise the
            // adapter advances by the default step.
            to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()),
            writes: Default::default(),
        };
        let record = BranchLessThanCoreRecord {
            opcode: blt_opcode,
            a: data[0],
            b: data[1],
            cmp_result: F::from_bool(cmp_result),
            cmp_lt: F::from_bool(cmp_lt),
            imm,
            a_msb_f,
            b_msb_f,
            diff_val: F::from_canonical_u32(diff_val),
            diff_idx,
        };

        Ok((output, record))
    }

    /// Debug name for a global opcode index handled by this chip.
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            BranchLessThanOpcode::from_usize(opcode - self.air.offset)
        )
    }

    /// Fill one trace row from a record; the one-hot opcode flags and the
    /// diff_marker array are reconstructed here.
    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
        let row_slice: &mut BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> =
            row_slice.borrow_mut();
        row_slice.a = record.a;
        row_slice.b = record.b;
        row_slice.cmp_result = record.cmp_result;
        row_slice.cmp_lt = record.cmp_lt;
        row_slice.imm = record.imm;
        row_slice.a_msb_f = record.a_msb_f;
        row_slice.b_msb_f = record.b_msb_f;
        // diff_idx == NUM_LIMBS (operands equal) sets no marker at all.
        row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx));
        row_slice.diff_val = record.diff_val;
        row_slice.opcode_blt_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLT);
        row_slice.opcode_bltu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLTU);
        row_slice.opcode_bge_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGE);
        row_slice.opcode_bgeu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGEU);
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}
356
357pub(super) fn run_cmp<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
359 local_opcode: BranchLessThanOpcode,
360 x: &[u32; NUM_LIMBS],
361 y: &[u32; NUM_LIMBS],
362) -> (bool, usize, bool, bool) {
363 let signed =
364 local_opcode == BranchLessThanOpcode::BLT || local_opcode == BranchLessThanOpcode::BGE;
365 let ge_op =
366 local_opcode == BranchLessThanOpcode::BGE || local_opcode == BranchLessThanOpcode::BGEU;
367 let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
368 let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
369 for i in (0..NUM_LIMBS).rev() {
370 if x[i] != y[i] {
371 return ((x[i] < y[i]) ^ x_sign ^ y_sign ^ ge_op, i, x_sign, y_sign);
372 }
373 }
374 (ge_op, NUM_LIMBS, x_sign, y_sign)
375}