// openvm_rv32im_circuit/branch_lt/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    utils::not,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
16use openvm_rv32im_transpiler::BranchLessThanOpcode;
17use openvm_stark_backend::{
18    interaction::InteractionBuilder,
19    p3_air::{AirBuilder, BaseAir},
20    p3_field::{Field, FieldAlgebra, PrimeField32},
21    rap::BaseAirWithPublicValues,
22};
23use serde::{Deserialize, Serialize};
24use serde_big_array::BigArray;
25use strum::IntoEnumIterator;
26
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BranchLessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // First operand; a[NUM_LIMBS - 1] is the most significant limb.
    pub a: [T; NUM_LIMBS],
    // Second operand; b[NUM_LIMBS - 1] is the most significant limb.
    pub b: [T; NUM_LIMBS],

    // Boolean result of a op b. Should branch if and only if cmp_result = 1.
    pub cmp_result: T,
    // Branch immediate: pc offset applied when the branch is taken.
    pub imm: T,

    // One-hot opcode selectors; at most one is set, all zero on padding rows.
    pub opcode_blt_flag: T,
    pub opcode_bltu_flag: T,
    pub opcode_bge_flag: T,
    pub opcode_bgeu_flag: T,

    // Most significant limb of a and b respectively as a field element, will be range
    // checked to be within [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) if signed and
    // [0, 2^LIMB_BITS) if unsigned.
    pub a_msb_f: T,
    pub b_msb_f: T,

    // 1 if a < b, 0 otherwise.
    pub cmp_lt: T,

    // 1 at the most significant index i such that a[i] != b[i], otherwise 0. If such
    // an i exists, diff_val = b[i] - a[i] when cmp_lt = 1 and a[i] - b[i] otherwise,
    // so diff_val is always a positive (non-zero) limb difference.
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}
55
/// Core AIR for the branch-less-than instruction class (BLT, BLTU, BGE, BGEU).
#[derive(Copy, Clone, Debug)]
pub struct BranchLessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bus to the bitwise-operation lookup, used here for range checks.
    pub bus: BitwiseOperationLookupBus,
    /// Base of this opcode class in the global opcode space.
    offset: usize,
}
61
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    /// Number of core trace columns, derived from the column struct layout.
    fn width(&self) -> usize {
        BranchLessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// This AIR exposes no public values (default trait methods apply).
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
73
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: Default,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    /// Constrains the comparison selected by the opcode flags and the branch target:
    /// `to_pc = from_pc + imm` when `cmp_result = 1`, else `from_pc + DEFAULT_PC_STEP`.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_blt_flag,
            cols.opcode_bltu_flag,
            cols.opcode_bge_flag,
            cols.opcode_bgeu_flag,
        ];

        // Each flag is boolean; since their sum (is_valid) is also constrained boolean,
        // at most one flag is set. is_valid = 0 corresponds to a padding row.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let lt = cols.opcode_blt_flag + cols.opcode_bltu_flag;
        let ge = cols.opcode_bge_flag + cols.opcode_bgeu_flag;
        let signed = cols.opcode_blt_flag + cols.opcode_bge_flag;
        // cmp_lt holds "a < b": it equals cmp_result for BLT/BLTU and its negation for
        // BGE/BGEU.
        builder.assert_eq(
            cols.cmp_lt,
            cols.cmp_result * lt.clone() + not(cols.cmp_result) * ge.clone(),
        );

        let a = &cols.a;
        let b = &cols.b;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        // Check if a_msb_f and b_msb_f are signed values of a[NUM_LIMBS - 1] and
        // b[NUM_LIMBS - 1] in prime field F: the difference limb - msb_f must be either 0
        // (non-negative interpretation) or 2^LIMB_BITS (msb_f = limb - 2^LIMB_BITS, i.e.
        // the limb read as a negative signed value).
        let a_diff = a[NUM_LIMBS - 1] - cols.a_msb_f;
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        builder
            .assert_zero(a_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - a_diff));
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));

        // Scan limbs from most to least significant; the factor (2 * cmp_lt - 1) orients
        // diff so it must be positive at the first differing limb.
        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                // Top limbs are compared via their (possibly sign-adjusted) field values.
                cols.b_msb_f - cols.a_msb_f
            } else {
                b[i] - a[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_lt - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            // Above the marked index (prefix_sum = 0) all limb differences must vanish.
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        // - If x != y, then prefix_sum = 1 so marker[i] must be 1 iff i is the first index
        //   where diff != 0. Constrains that diff == diff_val where diff_val is non-zero.
        // - If x == y, then prefix_sum = 0 and cmp_lt = 0.
        //   Here, prefix_sum cannot be 1 because all diff are zero, making diff == diff_val
        //   fail (diff_val is range checked below to be non-zero whenever prefix_sum = 1).

        builder.assert_bool(prefix_sum.clone());
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_lt);

        // Check if a_msb_f and b_msb_f are in [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) if
        // signed, [0, 2^LIMB_BITS) if unsigned, by shifting signed values into the
        // unsigned lookup range.
        self.bus
            .send_range(
                cols.a_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
                cols.b_msb_f + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * signed.clone(),
            )
            .eval(builder, is_valid.clone());

        // Range check to ensure diff_val is non-zero: diff_val - 1 must itself pass the
        // range check, which rules out diff_val = 0. Only sent when a marker is set.
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        // Reconstruct the global opcode from the one-hot flags; relies on `flags` being
        // in the same order as BranchLessThanOpcode::iter().
        let expected_opcode = flags
            .iter()
            .zip(BranchLessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);

        // Branch taken: jump by imm; otherwise fall through to the next instruction.
        let to_pc = from_pc
            + cols.cmp_result * cols.imm
            + not(cols.cmp_result) * AB::Expr::from_canonical_u32(DEFAULT_PC_STEP);

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [cols.a.map(Into::into), cols.b.map(Into::into)].into(),
            writes: Default::default(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: cols.imm.into(),
            }
            .into(),
        }
    }

    /// Base offset of the BranchLessThanOpcode class in the global opcode space.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
188
/// Per-instruction execution record, consumed by `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BranchLessThanCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    // 1 iff the branch was taken.
    pub cmp_result: T,
    // 1 iff a < b under the opcode's signedness.
    pub cmp_lt: T,
    pub imm: T,
    // Most significant limbs as field elements (negative when signed and sign bit set).
    pub a_msb_f: T,
    pub b_msb_f: T,
    // Positive limb difference at diff_idx; 0 when a == b.
    pub diff_val: T,
    // Index of the most significant differing limb, or NUM_LIMBS when a == b.
    pub diff_idx: usize,
    pub opcode: BranchLessThanOpcode,
}
205
/// Core chip for BLT/BLTU/BGE/BGEU: pairs the AIR with the shared bitwise-operation
/// lookup chip that serves its range-check interactions.
pub struct BranchLessThanCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub air: BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
210
211impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> BranchLessThanCoreChip<NUM_LIMBS, LIMB_BITS> {
212    pub fn new(
213        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
214        offset: usize,
215    ) -> Self {
216        Self {
217            air: BranchLessThanCoreAir {
218                bus: bitwise_lookup_chip.bus(),
219                offset,
220            },
221            bitwise_lookup_chip,
222        }
223    }
224}
225
impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreChip<F, I> for BranchLessThanCoreChip<NUM_LIMBS, LIMB_BITS>
where
    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
    I::Writes: Default,
{
    type Record = BranchLessThanCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
    type Air = BranchLessThanCoreAir<NUM_LIMBS, LIMB_BITS>;

    /// Runs one BLT/BLTU/BGE/BGEU instruction: compares the two operands, issues the same
    /// range-check lookups the AIR constrains, and returns the optional branch target
    /// together with the record used later for trace generation.
    #[allow(clippy::type_complexity)]
    fn execute_instruction(
        &self,
        instruction: &Instruction<F>,
        from_pc: u32,
        reads: I::Reads,
    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
        let Instruction { opcode, c: imm, .. } = *instruction;
        let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));

        let data: [[F; NUM_LIMBS]; 2] = reads.into();
        let a = data[0].map(|x| x.as_canonical_u32());
        let b = data[1].map(|y| y.as_canonical_u32());
        let (cmp_result, diff_idx, a_sign, b_sign) =
            run_cmp::<NUM_LIMBS, LIMB_BITS>(blt_opcode, &a, &b);

        let signed = matches!(
            blt_opcode,
            BranchLessThanOpcode::BLT | BranchLessThanOpcode::BGE
        );
        let ge_opcode = matches!(
            blt_opcode,
            BranchLessThanOpcode::BGE | BranchLessThanOpcode::BGEU
        );
        // cmp_lt holds "a < b"; for the >= opcodes cmp_result is its negation.
        let cmp_lt = cmp_result ^ ge_opcode;

        // We range check (a_msb_f + 2^(LIMB_BITS-1)) and (b_msb_f + 2^(LIMB_BITS-1)) if
        // signed, a_msb_f and b_msb_f if not. A negative top limb is represented in the
        // field as -(2^LIMB_BITS - limb) = limb - 2^LIMB_BITS, matching the AIR's msb_f
        // decomposition.
        let (a_msb_f, a_msb_range) = if a_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - a[NUM_LIMBS - 1]),
                a[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(a[NUM_LIMBS - 1]),
                a[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]),
                b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u32(b[NUM_LIMBS - 1]),
                b[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)),
            )
        };
        self.bitwise_lookup_chip
            .request_range(a_msb_range, b_msb_range);

        // diff_val is the positive limb difference at the first differing index, oriented
        // by cmp_lt; 0 when the operands are equal (diff_idx == NUM_LIMBS). The top limb
        // is differenced via the field representatives a_msb_f/b_msb_f.
        let diff_val = if diff_idx == NUM_LIMBS {
            0
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_lt {
                b_msb_f - a_msb_f
            } else {
                a_msb_f - b_msb_f
            }
            .as_canonical_u32()
        } else if cmp_lt {
            b[diff_idx] - a[diff_idx]
        } else {
            a[diff_idx] - b[diff_idx]
        };

        // Mirrors the AIR's non-zero check on diff_val (skipped when a == b).
        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip.request_range(diff_val - 1, 0);
        }

        let output = AdapterRuntimeContext {
            // Branch taken: pc jumps by imm (field addition, reduced to canonical u32);
            // otherwise None lets the adapter advance pc by the default step.
            to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()),
            writes: Default::default(),
        };
        let record = BranchLessThanCoreRecord {
            opcode: blt_opcode,
            a: data[0],
            b: data[1],
            cmp_result: F::from_bool(cmp_result),
            cmp_lt: F::from_bool(cmp_lt),
            imm,
            a_msb_f,
            b_msb_f,
            diff_val: F::from_canonical_u32(diff_val),
            diff_idx,
        };

        Ok((output, record))
    }

    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            BranchLessThanOpcode::from_usize(opcode - self.air.offset)
        )
    }

    /// Fills one row of core trace columns from an execution record.
    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
        let row_slice: &mut BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> =
            row_slice.borrow_mut();
        row_slice.a = record.a;
        row_slice.b = record.b;
        row_slice.cmp_result = record.cmp_result;
        row_slice.cmp_lt = record.cmp_lt;
        row_slice.imm = record.imm;
        row_slice.a_msb_f = record.a_msb_f;
        row_slice.b_msb_f = record.b_msb_f;
        // One-hot at diff_idx; all zeros when diff_idx == NUM_LIMBS (equal operands).
        row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx));
        row_slice.diff_val = record.diff_val;
        row_slice.opcode_blt_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLT);
        row_slice.opcode_bltu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLTU);
        row_slice.opcode_bge_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGE);
        row_slice.opcode_bgeu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGEU);
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}
356
357// Returns (cmp_result, diff_idx, x_sign, y_sign)
358pub(super) fn run_cmp<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
359    local_opcode: BranchLessThanOpcode,
360    x: &[u32; NUM_LIMBS],
361    y: &[u32; NUM_LIMBS],
362) -> (bool, usize, bool, bool) {
363    let signed =
364        local_opcode == BranchLessThanOpcode::BLT || local_opcode == BranchLessThanOpcode::BGE;
365    let ge_op =
366        local_opcode == BranchLessThanOpcode::BGE || local_opcode == BranchLessThanOpcode::BGEU;
367    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
368    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed;
369    for i in (0..NUM_LIMBS).rev() {
370        if x[i] != y[i] {
371            return ((x[i] < y[i]) ^ x_sign ^ y_sign ^ ge_op, i, x_sign, y_sign);
372        }
373    }
374    (ge_op, NUM_LIMBS, x_sign, y_sign)
375}