openvm_rv32im_circuit/branch_lt/
core.rs

use std::borrow::{Borrow, BorrowMut};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_rv32im_transpiler::BranchLessThanOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;

#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BranchLessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [T; NUM_LIMBS],
    pub b: [T; NUM_LIMBS],

    // Boolean result of a op b. Should branch if and only if cmp_result = 1.
    pub cmp_result: T,
    pub imm: T,

    pub opcode_blt_flag: T,
    pub opcode_bltu_flag: T,
    pub opcode_bge_flag: T,
    pub opcode_bgeu_flag: T,

    // Most significant limb of a and b respectively as a field element, range
    // checked to be within [-128, 128) if signed and [0, 256) if unsigned
    // (bounds shown for LIMB_BITS = 8).
    pub a_msb_f: T,
    pub b_msb_f: T,

    // 1 if a < b, 0 otherwise.
    pub cmp_lt: T,

    // 1 at the most significant index i such that a[i] != b[i], otherwise 0. If such
    // an i exists, diff_val = b[i] - a[i] when cmp_lt = 1 and a[i] - b[i] otherwise,
    // so diff_val is always positive.
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}
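
// Illustrative column assignment (hypothetical values, not taken from a real trace),
// assuming NUM_LIMBS = 4 and LIMB_BITS = 8, for BLT comparing a = 5 against b = 7
// (little-endian limbs):
//   a = [5, 0, 0, 0], b = [7, 0, 0, 0]
//   opcode_blt_flag = 1, all other flags = 0
//   a_msb_f = b_msb_f = 0 (top limbs are equal and non-negative)
//   cmp_lt = cmp_result = 1 (5 < 7, and BLT branches exactly when a < b)
//   diff_marker = [1, 0, 0, 0] (limb 0 is the most significant differing limb)
//   diff_val = b[0] - a[0] = 2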

#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct BranchLessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bus: BitwiseOperationLookupBus,
    offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        BranchLessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_blt_flag,
            cols.opcode_bltu_flag,
            cols.opcode_bge_flag,
            cols.opcode_bgeu_flag,
        ];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let lt = cols.opcode_blt_flag + cols.opcode_bltu_flag;
        let ge = cols.opcode_bge_flag + cols.opcode_bgeu_flag;
        let signed = cols.opcode_blt_flag + cols.opcode_bge_flag;
        builder.assert_eq(
            cols.cmp_lt,
            cols.cmp_result * lt.clone() + not(cols.cmp_result) * ge.clone(),
        );
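        // Spot check of the relation above (illustrative): for BGE with a >= b the
        // branch is taken, so cmp_result = 1, lt = 0, ge = 1, giving
        // cmp_lt = 1 * 0 + (1 - 1) * 1 = 0, consistent with "a < b" being false.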

        let a = &cols.a;
        let b = &cols.b;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        // Check that a_msb_f and b_msb_f embed the (possibly signed) values of
        // a[NUM_LIMBS - 1] and b[NUM_LIMBS - 1] in the prime field F: each diff below
        // is constrained to be 0 or 2^LIMB_BITS, so the msb column is either the raw
        // limb (non-negative case) or the raw limb minus 2^LIMB_BITS (negative case).
        let a_diff = a[NUM_LIMBS - 1] - cols.a_msb_f;
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        builder
            .assert_zero(a_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - a_diff));
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));
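        // E.g. with LIMB_BITS = 8 and a hypothetical top limb a[NUM_LIMBS - 1] = 0xff
        // interpreted as signed: a_msb_f = 255 - 256 = -1, so a_diff = 256 and
        // a_diff * (256 - a_diff) = 0 as required.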

        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                cols.b_msb_f - cols.a_msb_f
            } else {
                b[i] - a[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_lt - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        // - If a != b, then prefix_sum = 1 and marker[i] must be 1 exactly at the most
        //   significant index i where diff != 0, since diff is constrained to equal the
        //   non-zero diff_val there.
        // - If a == b, then every diff is zero, so no marker[i] can be set (diff would
        //   have to equal the non-zero diff_val); hence prefix_sum = 0, which forces
        //   cmp_lt = 0 below.
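        // Illustrative instance (hypothetical values, NUM_LIMBS = 4, LIMB_BITS = 8):
        // a = [9, 2, 0, 0], b = [5, 3, 0, 0] (little-endian), so a < b and cmp_lt = 1.
        // At i = 3, 2: diff = 0 and prefix_sum = 0, so assert_zero passes with
        // marker[i] = 0. At i = 1 the prover sets marker[1] = 1 and
        // diff_val = b[1] - a[1] = 1. At i = 0, prefix_sum is already 1, so
        // not(prefix_sum) * diff vanishes and marker[0] = 0 imposes nothing, even
        // though b[0] - a[0] = -4 != 0.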

        builder.assert_bool(prefix_sum.clone());
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_lt);

        // Check that a_msb_f and b_msb_f are in [-128, 128) if signed and [0, 256) if
        // unsigned (bounds shown for LIMB_BITS = 8), by range checking the shifted
        // values below to be in [0, 2^LIMB_BITS).
        self.bus
            .send_range(
                cols.a_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
                cols.b_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
            )
            .eval(builder, is_valid.clone());

        // Range check that diff_val - 1 is in [0, 2^LIMB_BITS), i.e. that diff_val is
        // non-zero, whenever some marker bit is set.
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        let expected_opcode = flags
            .iter()
            .zip(BranchLessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);

        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(DEFAULT_PC_STEP);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BranchLessThanCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [u8; NUM_LIMBS],
    pub b: [u8; NUM_LIMBS],
    pub imm: u32,
    pub local_opcode: u8,
}

#[derive(Clone, Copy, derive_new::new)]
pub struct BranchLessThanExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(Clone, derive_new::new)]
pub struct BranchLessThanFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub offset: usize,
}

impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceExecutor<F, ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData = ()>,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut BranchLessThanCoreRecord<NUM_LIMBS, LIMB_BITS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            BranchLessThanOpcode::from_usize(opcode - self.offset)
        )
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let &Instruction { opcode, c: imm, .. } = instruction;

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        core_record.a = rs1;
        core_record.b = rs2;
        core_record.imm = imm.as_canonical_u32();
        core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8;

        if run_cmp::<NUM_LIMBS, LIMB_BITS>(core_record.local_opcode, &rs1, &rs2).0 {
            *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32();
        } else {
            *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
        }

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for BranchLessThanFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // BranchLessThanCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };

        // SAFETY: core_row contains a valid BranchLessThanCoreRecord written by the executor
        // during trace generation
        let record: &BranchLessThanCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        self.adapter.fill_trace_row(mem_helper, adapter_row);
        let core_row: &mut BranchLessThanCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();

        let signed = record.local_opcode == BranchLessThanOpcode::BLT as u8
            || record.local_opcode == BranchLessThanOpcode::BGE as u8;
        let ge_op = record.local_opcode == BranchLessThanOpcode::BGE as u8
            || record.local_opcode == BranchLessThanOpcode::BGEU as u8;

        let (cmp_result, diff_idx, a_sign, b_sign) =
            run_cmp::<NUM_LIMBS, LIMB_BITS>(record.local_opcode, &record.a, &record.b);

        let cmp_lt = cmp_result ^ ge_op;

        // We range check (a_msb_f + 2^(LIMB_BITS - 1)) and (b_msb_f + 2^(LIMB_BITS - 1))
        // if signed, and a_msb_f and b_msb_f as-is if not, mirroring the AIR's
        // send_range interaction.
        let (a_msb_f, a_msb_range) = if a_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - record.a[NUM_LIMBS - 1] as u32),
                record.a[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(record.a[NUM_LIMBS - 1] as u32),
                record.a[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u32),
                record.b[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(record.b[NUM_LIMBS - 1] as u32),
                record.b[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };
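        // Boundary example (hypothetical): for LIMB_BITS = 8, a signed opcode, and a
        // top limb of 0x80, a_sign = true, so a_msb_f = -(256 - 128) = -128 in F and
        // a_msb_range = 128 - 128 = 0, the smallest value accepted by the range check.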

        core_row.diff_val = if diff_idx == NUM_LIMBS {
            F::ZERO
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_lt {
                b_msb_f - a_msb_f
            } else {
                a_msb_f - b_msb_f
            }
        } else if cmp_lt {
            F::from_canonical_u8(record.b[diff_idx] - record.a[diff_idx])
        } else {
            F::from_canonical_u8(record.a[diff_idx] - record.b[diff_idx])
        };

        self.bitwise_lookup_chip
            .request_range(a_msb_range, b_msb_range);

        core_row.diff_marker = [F::ZERO; NUM_LIMBS];

        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip
                .request_range(core_row.diff_val.as_canonical_u32() - 1, 0);
            core_row.diff_marker[diff_idx] = F::ONE;
        }

        core_row.cmp_lt = F::from_bool(cmp_lt);
        core_row.b_msb_f = b_msb_f;
        core_row.a_msb_f = a_msb_f;
        core_row.opcode_bgeu_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BGEU as u8);
        core_row.opcode_bge_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BGE as u8);
        core_row.opcode_bltu_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BLTU as u8);
        core_row.opcode_blt_flag =
            F::from_bool(record.local_opcode == BranchLessThanOpcode::BLT as u8);

        core_row.imm = F::from_canonical_u32(record.imm);
        core_row.cmp_result = F::from_bool(cmp_result);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = record.a.map(F::from_canonical_u8);
    }
}

// Returns (cmp_result, diff_idx, x_sign, y_sign), where diff_idx is the most
// significant index at which x and y differ (NUM_LIMBS if x == y).
#[inline(always)]
pub(super) fn run_cmp<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    local_opcode: u8,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> (bool, usize, bool, bool) {
    let signed = local_opcode == BranchLessThanOpcode::BLT as u8
        || local_opcode == BranchLessThanOpcode::BGE as u8;
    let ge_op = local_opcode == BranchLessThanOpcode::BGE as u8
        || local_opcode == BranchLessThanOpcode::BGEU as u8;
    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
    for i in (0..NUM_LIMBS).rev() {
        if x[i] != y[i] {
            return ((x[i] < y[i]) ^ x_sign ^ y_sign ^ ge_op, i, x_sign, y_sign);
        }
    }
    (ge_op, NUM_LIMBS, x_sign, y_sign)
}
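
// A minimal sanity-check sketch (not part of the original file) exercising run_cmp
// with the RV32 parameters NUM_LIMBS = 4, LIMB_BITS = 8. Limbs are little-endian,
// so x[3] is the most significant limb.
#[cfg(test)]
mod run_cmp_sanity {
    use super::*;

    #[test]
    fn blt_and_bltu_disagree_on_signed_inputs() {
        let x: [u8; 4] = [0xff; 4]; // -1 as i32, u32::MAX as u32
        let y: [u8; 4] = [1, 0, 0, 0]; // 1
        let (blt, idx, x_sign, y_sign) = run_cmp::<4, 8>(BranchLessThanOpcode::BLT as u8, &x, &y);
        assert!(blt); // signed: -1 < 1
        assert_eq!(idx, 3); // top limbs already differ
        assert!(x_sign && !y_sign);
        let (bltu, ..) = run_cmp::<4, 8>(BranchLessThanOpcode::BLTU as u8, &x, &y);
        assert!(!bltu); // unsigned: u32::MAX < 1 is false
    }

    #[test]
    fn equal_inputs_branch_only_on_ge() {
        let x: [u8; 4] = [7, 0, 0, 0];
        let (bge, idx, ..) = run_cmp::<4, 8>(BranchLessThanOpcode::BGE as u8, &x, &x);
        assert!(bge); // 7 >= 7
        assert_eq!(idx, 4); // no differing limb: diff_idx == NUM_LIMBS
        let (blt, ..) = run_cmp::<4, 8>(BranchLessThanOpcode::BLT as u8, &x, &x);
        assert!(!blt);
    }
}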