openvm_rv32im_circuit/less_than/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    utils::not,
13    AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
17use openvm_rv32im_transpiler::LessThanOpcode;
18use openvm_stark_backend::{
19    interaction::InteractionBuilder,
20    p3_air::{AirBuilder, BaseAir},
21    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
22    rap::BaseAirWithPublicValues,
23};
24use strum::IntoEnumIterator;
25
/// Trace columns for the core of the SLT/SLTU (set-less-than) chip.
///
/// `b` and `c` are the two operands split into `NUM_LIMBS` limbs, least significant limb at
/// index 0 and most significant limb at index `NUM_LIMBS - 1`. `cmp_result` is the boolean
/// result of `b < c`, which becomes the lowest limb of the written value.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct LessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],
    pub cmp_result: T,

    // Boolean opcode selectors. At most one is set; both are 0 when the row is not valid.
    pub opcode_slt_flag: T,
    pub opcode_sltu_flag: T,

    // Most significant limb of b and c respectively as a field element. For SLT this is the
    // signed value of the limb, range checked to [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)), i.e.
    // [-128, 128) for 8-bit limbs. For SLTU it is the raw limb, range checked to
    // [0, 2^LIMB_BITS), i.e. [0, 256) for 8-bit limbs.
    pub b_msb_f: T,
    pub c_msb_f: T,

    // One-hot marker: 1 at the most significant index i such that b[i] != c[i], all zeros if
    // b == c. If such an i exists, diff_val = c[i] - b[i] when cmp_result is 1 and
    // b[i] - c[i] otherwise. At i == NUM_LIMBS - 1, c_msb_f / b_msb_f are used in place of
    // the raw limbs. diff_val is range checked to be non-zero whenever a marker is set.
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}
46
/// AIR for the core of the SLT/SLTU chip. The constraints live in its `VmCoreAir` impl.
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct LessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bitwise-operation lookup bus used for the range checks on `b_msb_f`, `c_msb_f` and
    /// `diff_val`.
    pub bus: BitwiseOperationLookupBus,
    /// Global opcode offset. The local `LessThanOpcode` index is added to it to form the
    /// expected opcode, and it is reported by `start_offset`.
    offset: usize,
}
52
53impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
54    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
55{
56    fn width(&self) -> usize {
57        LessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
58    }
59}
60impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
61    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
62{
63}
64
/// Constraints for the SLT/SLTU core.
///
/// Reads `b` and `c`, writes `[cmp_result, 0, ..., 0]`, and constrains `cmp_result` to equal
/// `b < c` (signed for SLT, unsigned for SLTU). It does this by marking the most significant
/// differing limb with `diff_marker` and range checking the (sign-adjusted) difference
/// `diff_val` at that limb to be non-zero.
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [cols.opcode_slt_flag, cols.opcode_sltu_flag];

        // Each opcode flag is boolean, and so is their sum: at most one flag is set, and
        // is_valid is 1 exactly when one is.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let b = &cols.b;
        let c = &cols.c;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        // b_msb_f (resp. c_msb_f) must equal the top limb of b (resp. c) or that limb minus
        // 2^LIMB_BITS, i.e. its unsigned or two's-complement signed interpretation. Which one is
        // pinned down by the range check on the msb values further below.
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        let c_diff = c[NUM_LIMBS - 1] - cols.c_msb_f;
        builder.assert_zero(b_diff.clone() * (AB::Expr::from_u32(1 << LIMB_BITS) - b_diff));
        builder.assert_zero(c_diff.clone() * (AB::Expr::from_u32(1 << LIMB_BITS) - c_diff));

        // Walk the limbs from most to least significant. diff is (c - b) at this limb, negated
        // when cmp_result = 0, so at the marked limb it must equal the positive diff_val.
        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                cols.c_msb_f - cols.b_msb_f
            } else {
                c[i] - b[i]
            }) * (AB::Expr::from_u8(2) * cols.cmp_result - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            // Until the marker has been seen, every limb difference must be zero.
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        // - If b != c, prefix_sum must be 1. Otherwise the non-zero diff at the top differing
        //   limb violates the constraint above. The marker must sit exactly on the most
        //   significant index where diff != 0. If it sat lower, that index would violate the
        //   constraint above; if higher, diff_val would be 0. There diff == diff_val, where
        //   diff_val is range checked to be non-zero.
        // - If b == c, every diff is zero, so setting a marker would force diff_val = 0, which
        //   the non-zero range check rejects. Hence prefix_sum = 0, and cmp_result is forced
        //   to 0 below.

        builder.assert_bool(prefix_sum.clone());
        // No differing limb means b == c, so b < c is false.
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_result);

        // Check that b_msb_f and c_msb_f are in [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) if signed
        // (SLT) and in [0, 2^LIMB_BITS) if unsigned (SLTU). With 8-bit limbs that is [-128, 128)
        // and [0, 256). The 2^(LIMB_BITS-1) shift is applied only when the SLT flag is set.
        self.bus
            .send_range(
                cols.b_msb_f + AB::Expr::from_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
                cols.c_msb_f + AB::Expr::from_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
            )
            .eval(builder, is_valid.clone());

        // Range check diff_val - 1 to ensure diff_val is non-zero. This is sent only when a
        // marker is set (prefix_sum = 1).
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        // Global opcode = offset + local index of the flag that is set.
        // NOTE(review): assumes LessThanOpcode::iter() yields SLT then SLTU, matching the order
        // of `flags` above.
        let expected_opcode = flags
            .iter()
            .zip(LessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_u8(opcode as u8)
            })
            + AB::Expr::from_usize(self.offset);
        // Written value: cmp_result in the lowest limb, zeros elsewhere.
        let mut a: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            // The core does not set the next pc itself.
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Global opcode offset that this core's local opcodes are relative to.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
160
/// Per-instruction data saved by `LessThanExecutor::execute` and read back by
/// `LessThanFiller::fill_trace_row` to rebuild the core columns.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LessThanCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// First operand (rs1) as limbs, least significant limb first.
    pub b: [u8; NUM_LIMBS],
    /// Second operand (rs2) as limbs, least significant limb first.
    pub c: [u8; NUM_LIMBS],
    /// Opcode index relative to the executor's `offset` (compared against
    /// `LessThanOpcode::SLT as u8`).
    pub local_opcode: u8,
}
168
/// Preflight executor for SLT/SLTU. It performs the comparison and records its inputs for
/// trace generation.
#[derive(Clone, Copy, derive_new::new)]
pub struct LessThanExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Adapter that performs the operand reads and the result write for the instruction.
    adapter: A,
    /// Global opcode offset used to convert instruction opcodes to local `LessThanOpcode`
    /// indices.
    pub offset: usize,
}
174
/// Trace filler for the SLT/SLTU rows whose records were produced by `LessThanExecutor`.
#[derive(Clone, derive_new::new)]
pub struct LessThanFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Adapter that fills the adapter portion of each row.
    adapter: A,
    /// Receives the range-check requests that match the AIR's lookups on the bitwise bus.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    /// Global opcode offset.
    /// NOTE(review): not read by `fill_trace_row`, which works from the recorded local
    /// opcode; presumably kept for symmetry with the executor/AIR.
    pub offset: usize,
}
181
182impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
183    for LessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
184where
185    F: PrimeField32,
186    A: 'static
187        + AdapterTraceExecutor<
188            F,
189            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
190            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
191        >,
192    for<'buf> RA: RecordArena<
193        'buf,
194        EmptyAdapterCoreLayout<F, A>,
195        (
196            A::RecordMut<'buf>,
197            &'buf mut LessThanCoreRecord<NUM_LIMBS, LIMB_BITS>,
198        ),
199    >,
200{
201    fn get_opcode_name(&self, opcode: usize) -> String {
202        format!("{:?}", LessThanOpcode::from_usize(opcode - self.offset))
203    }
204
205    fn execute(
206        &self,
207        state: VmStateMut<F, TracingMemory, RA>,
208        instruction: &Instruction<F>,
209    ) -> Result<(), ExecutionError> {
210        debug_assert!(LIMB_BITS <= 8);
211        let Instruction { opcode, .. } = instruction;
212
213        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
214        A::start(*state.pc, state.memory, &mut adapter_record);
215
216        let [rs1, rs2] = self
217            .adapter
218            .read(state.memory, instruction, &mut adapter_record)
219            .into();
220
221        core_record.b = rs1;
222        core_record.c = rs2;
223        core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8;
224
225        let (cmp_result, _, _, _) = run_less_than::<NUM_LIMBS, LIMB_BITS>(
226            core_record.local_opcode == LessThanOpcode::SLT as u8,
227            &rs1,
228            &rs2,
229        );
230
231        let mut output = [0u8; NUM_LIMBS];
232        output[0] = cmp_result as u8;
233
234        self.adapter.write(
235            state.memory,
236            instruction,
237            [output].into(),
238            &mut adapter_record,
239        );
240
241        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
242
243        Ok(())
244    }
245}
246
/// Fills one SLT/SLTU trace row from the record the executor left in it.
///
/// The adapter fills its own columns (the first `A::WIDTH` entries). The core columns are
/// then rebuilt from the `LessThanCoreRecord` stored in the core slice. The same range checks
/// the AIR sends on the bitwise bus are requested from the bitwise lookup chip.
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for LessThanFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // LessThanCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid LessThanCoreRecord written by the executor
        // during trace generation
        // NOTE(review): `record` is read out of the same memory that `core_row` is reinterpreted
        // as just below, so the record and the column struct alias. The column writes at the
        // end of this function appear to be ordered so that every record field is read before
        // the columns overlapping it are overwritten. `b`, the first column of
        // LessThanCoreCols, is written last. Do not reorder them without re-checking the
        // overlap.
        let record: &LessThanCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut LessThanCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();

        let is_slt = record.local_opcode == LessThanOpcode::SLT as u8;
        let (cmp_result, diff_idx, b_sign, c_sign) =
            run_less_than::<NUM_LIMBS, LIMB_BITS>(is_slt, &record.b, &record.c);

        // We range check (b_msb_f + 2^(LIMB_BITS-1)) and (c_msb_f + 2^(LIMB_BITS-1)) if signed,
        // b_msb_f and c_msb_f if not.
        // Negative top limb (sign bit set under SLT): msb_f = limb - 2^LIMB_BITS as a field
        // element, and the range-checked value is limb - 2^(LIMB_BITS-1).
        // Otherwise: msb_f = limb, and the range-checked value is limb + 2^(LIMB_BITS-1)
        // for SLT or limb for SLTU.
        let (b_msb_f, b_msb_range) = if b_sign {
            (
                -F::from_u16((1u16 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u16),
                record.b[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_u8(record.b[NUM_LIMBS - 1]),
                record.b[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)),
            )
        };
        let (c_msb_f, c_msb_range) = if c_sign {
            (
                -F::from_u16((1u16 << LIMB_BITS) - record.c[NUM_LIMBS - 1] as u16),
                record.c[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)),
            )
        } else {
            (
                F::from_u8(record.c[NUM_LIMBS - 1]),
                record.c[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)),
            )
        };

        // Positive difference at the most significant differing limb: c - b when b < c, and
        // b - c otherwise. This matches the AIR's diff * (2 * cmp_result - 1). The msb
        // (signed for SLT) values are used at the top limb. The value is 0 when b == c, in
        // which case diff_idx == NUM_LIMBS.
        core_row.diff_val = if diff_idx == NUM_LIMBS {
            F::ZERO
        } else if diff_idx == (NUM_LIMBS - 1) {
            if cmp_result {
                c_msb_f - b_msb_f
            } else {
                b_msb_f - c_msb_f
            }
        } else if cmp_result {
            F::from_u8(record.c[diff_idx] - record.b[diff_idx])
        } else {
            F::from_u8(record.b[diff_idx] - record.c[diff_idx])
        };

        self.bitwise_lookup_chip
            .request_range(b_msb_range as u32, c_msb_range as u32);

        // One-hot marker at diff_idx, plus the diff_val - 1 range check the AIR sends only
        // when a marker is set.
        core_row.diff_marker = [F::ZERO; NUM_LIMBS];
        if diff_idx != NUM_LIMBS {
            self.bitwise_lookup_chip
                .request_range(core_row.diff_val.as_canonical_u32() - 1, 0);
            core_row.diff_marker[diff_idx] = F::ONE;
        }

        // Remaining columns are written back-to-front (see the aliasing note above).
        core_row.c_msb_f = c_msb_f;
        core_row.b_msb_f = b_msb_f;
        core_row.opcode_sltu_flag = F::from_bool(!is_slt);
        core_row.opcode_slt_flag = F::from_bool(is_slt);
        core_row.cmp_result = F::from_bool(cmp_result);
        core_row.c = record.c.map(F::from_u8);
        core_row.b = record.b.map(F::from_u8);
    }
}
327
328// Returns (cmp_result, diff_idx, x_sign, y_sign)
329#[inline(always)]
330pub(super) fn run_less_than<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
331    is_slt: bool,
332    x: &[u8; NUM_LIMBS],
333    y: &[u8; NUM_LIMBS],
334) -> (bool, usize, bool, bool) {
335    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt;
336    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt;
337    for i in (0..NUM_LIMBS).rev() {
338        if x[i] != y[i] {
339            return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign);
340        }
341    }
342    (false, NUM_LIMBS, x_sign, y_sign)
343}