openvm_rv32im_circuit/less_than/core.rs

use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_rv32im_transpiler::LessThanOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct LessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],
    pub cmp_result: T,

    pub opcode_slt_flag: T,
    pub opcode_sltu_flag: T,
    // Most significant limb of b and c respectively, as a field element. Range
    // checked to be within [-128, 128) if signed (SLT), [0, 256) if unsigned (SLTU).
    pub b_msb_f: T,
    pub c_msb_f: T,

    // 1 at the most significant index i such that b[i] != c[i], otherwise all zeros.
    // If such an i exists, diff_val = c[i] - b[i] if cmp_result = 1, else b[i] - c[i]
    // (using b_msb_f and c_msb_f in place of the limbs at index NUM_LIMBS - 1).
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}

#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct LessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bus: BitwiseOperationLookupBus,
    offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        LessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [cols.opcode_slt_flag, cols.opcode_sltu_flag];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let b = &cols.b;
        let c = &cols.c;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        let c_diff = c[NUM_LIMBS - 1] - cols.c_msb_f;
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));
        builder
            .assert_zero(c_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - c_diff));
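        // Each quadratic constraint above forces b_diff (resp. c_diff) to be 0 or
        // 2^LIMB_BITS, i.e. b_msb_f is either the raw top limb or the top limb minus
        // 2^LIMB_BITS (its value reinterpreted as a signed integer).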

        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                cols.c_msb_f - cols.b_msb_f
            } else {
                c[i] - b[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_result - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        // - If b != c, then prefix_sum = 1, so marker[i] must be 1 exactly at the most
        //   significant index where the limbs differ, where diff == diff_val is enforced
        //   with diff_val non-zero.
        // - If b == c, then prefix_sum = 0 and cmp_result = 0. prefix_sum cannot be 1
        //   here: all diffs are zero, so diff == diff_val would contradict diff_val
        //   being non-zero.
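        // Worked example (NUM_LIMBS = 4, SLTU): b = [5, 0, 0, 9], c = [7, 0, 0, 9].
        // The most significant differing index is 0, so diff_marker = [1, 0, 0, 0],
        // cmp_result = 1 (since b < c), and diff_val = c[0] - b[0] = 2.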

        builder.assert_bool(prefix_sum.clone());
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_result);

        // Check that b_msb_f and c_msb_f are in [-128, 128) if signed, [0, 256) if
        // unsigned (for LIMB_BITS = 8).
        self.bus
            .send_range(
                cols.b_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
                cols.c_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
            )
            .eval(builder, is_valid.clone());

        // Range check diff_val - 1 (whenever prefix_sum = 1) to ensure diff_val is
        // non-zero.
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        let expected_opcode = flags
            .iter()
            .zip(LessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);
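        // Only the least significant limb of the output carries the 0/1 comparison
        // result; the remaining limbs are zero, matching RV32 SLT/SLTU semantics.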
        let mut a: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LessThanCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [u8; NUM_LIMBS],
    pub c: [u8; NUM_LIMBS],
    pub local_opcode: u8,
}

#[derive(Clone, Copy, derive_new::new)]
pub struct LessThanExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(Clone, derive_new::new)]
pub struct LessThanFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub offset: usize,
}

impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for LessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut LessThanCoreRecord<NUM_LIMBS, LIMB_BITS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", LessThanOpcode::from_usize(opcode - self.offset))
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        debug_assert!(LIMB_BITS <= 8);
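        // Limbs are stored as `u8`, so this implementation only supports LIMB_BITS <= 8.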
        let Instruction { opcode, .. } = instruction;

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        core_record.b = rs1;
        core_record.c = rs2;
        core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8;

        let (cmp_result, _, _, _) = run_less_than::<NUM_LIMBS, LIMB_BITS>(
            core_record.local_opcode == LessThanOpcode::SLT as u8,
            &rs1,
            &rs2,
        );

        let mut output = [0u8; NUM_LIMBS];
        output[0] = cmp_result as u8;

        self.adapter.write(
            state.memory,
            instruction,
            [output].into(),
            &mut adapter_record,
        );

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for LessThanFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // LessThanCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid LessThanCoreRecord written by the executor
        // during trace generation
        let record: &LessThanCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut LessThanCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
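        // Note: `record` and `core_row` view the same underlying buffer, so the record
        // must be fully consumed before the row fields that overlap its bytes (b and c,
        // at the start of the struct) are overwritten; those fields are written last.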

        let is_slt = record.local_opcode == LessThanOpcode::SLT as u8;
        let (cmp_result, diff_idx, b_sign, c_sign) =
            run_less_than::<NUM_LIMBS, LIMB_BITS>(is_slt, &record.b, &record.c);

        // We range check (b_msb_f + 2^(LIMB_BITS - 1)) and (c_msb_f + 2^(LIMB_BITS - 1))
        // if the opcode is SLT (signed), and b_msb_f, c_msb_f directly if SLTU (unsigned).
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_canonical_u16((1u16 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u16),
                record.b[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u8(record.b[NUM_LIMBS - 1]),
                record.b[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)),
            )
        };
        let (c_msb_f, c_msb_range) = if c_sign {
            (
                -F::from_canonical_u16((1u16 << LIMB_BITS) - record.c[NUM_LIMBS - 1] as u16),
                record.c[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_canonical_u8(record.c[NUM_LIMBS - 1]),
                record.c[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)),
            )
        };

        core_row.diff_val = if diff_idx == NUM_LIMBS {
            F::ZERO
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_result {
                c_msb_f - b_msb_f
            } else {
                b_msb_f - c_msb_f
            }
        } else if cmp_result {
            F::from_canonical_u8(record.c[diff_idx] - record.b[diff_idx])
        } else {
            F::from_canonical_u8(record.b[diff_idx] - record.c[diff_idx])
        };

        self.bitwise_lookup_chip
            .request_range(b_msb_range as u32, c_msb_range as u32);

        core_row.diff_marker = [F::ZERO; NUM_LIMBS];
        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip
                .request_range(core_row.diff_val.as_canonical_u32() - 1, 0);
            core_row.diff_marker[diff_idx] = F::ONE;
        }

        core_row.c_msb_f = c_msb_f;
        core_row.b_msb_f = b_msb_f;
        core_row.opcode_sltu_flag = F::from_bool(!is_slt);
        core_row.opcode_slt_flag = F::from_bool(is_slt);
        core_row.cmp_result = F::from_bool(cmp_result);
        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
    }
}

// Returns (cmp_result, diff_idx, x_sign, y_sign)
#[inline(always)]
pub(super) fn run_less_than<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    is_slt: bool,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> (bool, usize, bool, bool) {
    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt;
    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt;
    for i in (0..NUM_LIMBS).rev() {
        if x[i] != y[i] {
            return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign);
        }
    }
    (false, NUM_LIMBS, x_sign, y_sign)
}
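
// Illustrative sketch (not part of the original file): a minimal test of
// `run_less_than` semantics, assuming the RV32 configuration NUM_LIMBS = 4,
// LIMB_BITS = 8 with little-endian limbs.
#[cfg(test)]
mod run_less_than_examples {
    use super::run_less_than;

    #[test]
    fn examples() {
        // Unsigned: 1 < 2; the most significant differing limb is index 0.
        assert_eq!(
            run_less_than::<4, 8>(false, &[1, 0, 0, 0], &[2, 0, 0, 0]),
            (true, 0, false, false)
        );
        // Equal operands: cmp_result = false and diff_idx = NUM_LIMBS.
        assert_eq!(
            run_less_than::<4, 8>(true, &[7, 0, 0, 0], &[7, 0, 0, 0]),
            (false, 4, false, false)
        );
        // Signed: 0xff00_0000 reinterprets as negative, so SLT puts it below 0,
        assert_eq!(
            run_less_than::<4, 8>(true, &[0, 0, 0, 0xff], &[0, 0, 0, 0]),
            (true, 3, true, false)
        );
        // while SLTU compares raw magnitudes and gives the opposite answer.
        assert_eq!(
            run_less_than::<4, 8>(false, &[0, 0, 0, 0xff], &[0, 0, 0, 0]),
            (false, 3, false, false)
        );
    }
}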