// openvm_rv32im_circuit/less_than/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    utils::not,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{instruction::Instruction, LocalOpcode};
16use openvm_rv32im_transpiler::LessThanOpcode;
17use openvm_stark_backend::{
18    interaction::InteractionBuilder,
19    p3_air::{AirBuilder, BaseAir},
20    p3_field::{Field, FieldAlgebra, PrimeField32},
21    rap::BaseAirWithPublicValues,
22};
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use serde_big_array::BigArray;
25use strum::IntoEnumIterator;
26
/// Trace columns for the core of the `SLT` / `SLTU` (set-less-than) instructions.
///
/// Limbs are stored least significant first, so index `NUM_LIMBS - 1` holds the most
/// significant limb. The struct is `#[repr(C)]` and borrowed directly from a row slice, so the
/// field order below is the column order.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct LessThanCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // Limbs of the two operands.
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],
    // Boolean result of `b < c` (signed comparison for SLT, unsigned for SLTU).
    pub cmp_result: T,

    // Opcode selectors: at most one is 1; both are 0 on rows where is_valid is 0.
    pub opcode_slt_flag: T,
    pub opcode_sltu_flag: T,

    // Most significant limb of b and c respectively as a field element: either the raw limb,
    // or the limb minus 2^LIMB_BITS when it encodes a negative SLT operand. Range checked to be
    // within [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) (i.e. [-128, 128) for 8-bit limbs) if signed,
    // and within [0, 2^LIMB_BITS) (i.e. [0, 256)) if unsigned.
    pub b_msb_f: T,
    pub c_msb_f: T,

    // 1 at the most significant index i such that b[i] != c[i], otherwise 0 (all zeros when
    // b == c). If such an i exists, diff_val = c[i] - b[i] when cmp_result is 1 and
    // b[i] - c[i] otherwise, using b_msb_f / c_msb_f in place of the raw top limbs. The AIR
    // forces diff_val to be nonzero whenever a marker is set; trace generation sets it to 0
    // when b == c.
    pub diff_marker: [T; NUM_LIMBS],
    pub diff_val: T,
}
47
/// AIR for the less-than core: constrains one [`LessThanCoreCols`] row per instruction.
#[derive(Copy, Clone, Debug)]
pub struct LessThanCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bitwise-operation lookup bus used for the msb and `diff_val` range checks.
    pub bus: BitwiseOperationLookupBus,
    /// Added to the local `LessThanOpcode` index to form the global opcode.
    offset: usize,
}
53
54impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
55    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
56{
57    fn width(&self) -> usize {
58        LessThanCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
59    }
60}
61impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
62    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
63{
64}
65
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LessThanCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains one row of [`LessThanCoreCols`] and returns the adapter context: reads `b`
    /// and `c`, writes `cmp_result` into limb 0 of the destination (other limbs zero).
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [cols.opcode_slt_flag, cols.opcode_sltu_flag];

        // Each flag is boolean and their sum (is_valid) is boolean, so at most one is set.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());
        builder.assert_bool(cols.cmp_result);

        let b = &cols.b;
        let c = &cols.c;
        let marker = &cols.diff_marker;
        let mut prefix_sum = AB::Expr::ZERO;

        // b_diff * (2^LIMB_BITS - b_diff) = 0 forces b_msb_f to be either the top limb itself
        // or the top limb minus 2^LIMB_BITS (its negative, sign-extended interpretation).
        // Same for c. The range checks below select which one is allowed.
        let b_diff = b[NUM_LIMBS - 1] - cols.b_msb_f;
        let c_diff = c[NUM_LIMBS - 1] - cols.c_msb_f;
        builder
            .assert_zero(b_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - b_diff));
        builder
            .assert_zero(c_diff.clone() * (AB::Expr::from_canonical_u32(1 << LIMB_BITS) - c_diff));

        // Walk limbs from most to least significant. The factor (2 * cmp_result - 1) is +1 or
        // -1, so diff = c - b when cmp_result = 1 and b - c when cmp_result = 0. prefix_sum
        // counts markers seen so far (including limb i); every limb strictly above the marker
        // has prefix_sum = 0 and must therefore have diff = 0.
        for i in (0..NUM_LIMBS).rev() {
            let diff = (if i == NUM_LIMBS - 1 {
                cols.c_msb_f - cols.b_msb_f
            } else {
                c[i] - b[i]
            }) * (AB::Expr::from_canonical_u8(2) * cols.cmp_result - AB::Expr::ONE);
            prefix_sum += marker[i].into();
            builder.assert_bool(marker[i]);
            builder.assert_zero(not::<AB::Expr>(prefix_sum.clone()) * diff.clone());
            builder.when(marker[i]).assert_eq(cols.diff_val, diff);
        }
        // - If b != c, some diff is nonzero, so prefix_sum cannot stay 0 above it; hence the
        //   single marker must sit at the most significant differing limb, where
        //   diff == diff_val. The range check on diff_val - 1 below makes diff_val nonzero and
        //   small, which pins the sign of diff and therefore cmp_result.
        // - If b == c, every diff is 0. Setting a marker would force diff_val = 0, which the
        //   range check on diff_val - 1 rejects, so prefix_sum = 0 and cmp_result = 0.
        // NOTE(review): this argument assumes the limbs of b and c are range checked to
        // [0, 2^LIMB_BITS) elsewhere (not visible in this file) — confirm.

        builder.assert_bool(prefix_sum.clone());
        builder
            .when(not::<AB::Expr>(prefix_sum.clone()))
            .assert_zero(cols.cmp_result);

        // Check that b_msb_f and c_msb_f are in [-2^(LIMB_BITS-1), 2^(LIMB_BITS-1)) (i.e.
        // [-128, 128) for 8-bit limbs) if signed, [0, 2^LIMB_BITS) if unsigned: for SLT both are
        // shifted up by 2^(LIMB_BITS-1) before the range check.
        self.bus
            .send_range(
                cols.b_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
                cols.c_msb_f
                    + AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)) * cols.opcode_slt_flag,
            )
            .eval(builder, is_valid.clone());

        // Range check to ensure diff_val is non-zero: only sent when a marker is set
        // (prefix_sum = 1), requiring diff_val - 1 in [0, 2^LIMB_BITS). The second operand is
        // a trivially in-range zero.
        self.bus
            .send_range(cols.diff_val - AB::Expr::ONE, AB::F::ZERO)
            .eval(builder, prefix_sum);

        // Global opcode = sum(flag * local opcode) + offset.
        // NOTE(review): assumes LessThanOpcode::iter() yields SLT then SLTU, matching `flags`.
        let expected_opcode = flags
            .iter()
            .zip(LessThanOpcode::iter())
            .fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
                acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
            })
            + AB::Expr::from_canonical_usize(self.offset);
        // The destination register receives cmp_result in limb 0 and zeros elsewhere.
        let mut a: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Returns the opcode offset this AIR was constructed with.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
165
/// Per-instruction record produced by `execute_instruction` and consumed by
/// `generate_trace_row` to fill one [`LessThanCoreCols`] row.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct LessThanCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Limbs of the first operand.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    /// Limbs of the second operand.
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
    /// 1 if `b < c` under the executed opcode's comparison, else 0.
    pub cmp_result: T,
    /// Most significant limb of `b` as a field element (negative for a negative SLT operand).
    pub b_msb_f: T,
    /// Most significant limb of `c` as a field element (negative for a negative SLT operand).
    pub c_msb_f: T,
    /// Nonzero difference at `diff_idx` (oriented by `cmp_result`), or 0 when `b == c`.
    pub diff_val: T,
    /// Most significant limb index where `b` and `c` differ, or `NUM_LIMBS` if they are equal.
    pub diff_idx: usize,
    /// The executed less-than opcode.
    pub opcode: LessThanOpcode,
}
181
/// Execution-side core for `SLT` / `SLTU`: computes results, requests range checks and fills
/// trace rows for [`LessThanCoreAir`].
pub struct LessThanCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// AIR whose constraints the generated rows must satisfy.
    pub air: LessThanCoreAir<NUM_LIMBS, LIMB_BITS>,
    /// Shared lookup chip that receives the range-check requests made during execution.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
186
187impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> LessThanCoreChip<NUM_LIMBS, LIMB_BITS> {
188    pub fn new(
189        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
190        offset: usize,
191    ) -> Self {
192        Self {
193            air: LessThanCoreAir {
194                bus: bitwise_lookup_chip.bus(),
195                offset,
196            },
197            bitwise_lookup_chip,
198        }
199    }
200}
201
202impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
203    VmCoreChip<F, I> for LessThanCoreChip<NUM_LIMBS, LIMB_BITS>
204where
205    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
206    I::Writes: From<[[F; NUM_LIMBS]; 1]>,
207{
208    type Record = LessThanCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
209    type Air = LessThanCoreAir<NUM_LIMBS, LIMB_BITS>;
210
211    #[allow(clippy::type_complexity)]
212    fn execute_instruction(
213        &self,
214        instruction: &Instruction<F>,
215        _from_pc: u32,
216        reads: I::Reads,
217    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
218        let Instruction { opcode, .. } = instruction;
219        let less_than_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
220
221        let data: [[F; NUM_LIMBS]; 2] = reads.into();
222        let b = data[0].map(|x| x.as_canonical_u32());
223        let c = data[1].map(|y| y.as_canonical_u32());
224        let (cmp_result, diff_idx, b_sign, c_sign) =
225            run_less_than::<NUM_LIMBS, LIMB_BITS>(less_than_opcode, &b, &c);
226
227        // We range check (b_msb_f + 128) and (c_msb_f + 128) if signed,
228        // b_msb_f and c_msb_f if not
229        let (b_msb_f, b_msb_range) = if b_sign {
230            (
231                -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]),
232                b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
233            )
234        } else {
235            (
236                F::from_canonical_u32(b[NUM_LIMBS - 1]),
237                b[NUM_LIMBS - 1]
238                    + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)),
239            )
240        };
241        let (c_msb_f, c_msb_range) = if c_sign {
242            (
243                -F::from_canonical_u32((1 << LIMB_BITS) - c[NUM_LIMBS - 1]),
244                c[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)),
245            )
246        } else {
247            (
248                F::from_canonical_u32(c[NUM_LIMBS - 1]),
249                c[NUM_LIMBS - 1]
250                    + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)),
251            )
252        };
253        self.bitwise_lookup_chip
254            .request_range(b_msb_range, c_msb_range);
255
256        let diff_val = if diff_idx == NUM_LIMBS {
257            0
258        } else if diff_idx == (NUM_LIMBS - 1) {
259            if cmp_result {
260                c_msb_f - b_msb_f
261            } else {
262                b_msb_f - c_msb_f
263            }
264            .as_canonical_u32()
265        } else if cmp_result {
266            c[diff_idx] - b[diff_idx]
267        } else {
268            b[diff_idx] - c[diff_idx]
269        };
270
271        if diff_idx != NUM_LIMBS {
272            self.bitwise_lookup_chip.request_range(diff_val - 1, 0);
273        }
274
275        let mut writes = [0u32; NUM_LIMBS];
276        writes[0] = cmp_result as u32;
277
278        let output = AdapterRuntimeContext::without_pc([writes.map(F::from_canonical_u32)]);
279        let record = LessThanCoreRecord {
280            opcode: less_than_opcode,
281            b: data[0],
282            c: data[1],
283            cmp_result: F::from_bool(cmp_result),
284            b_msb_f,
285            c_msb_f,
286            diff_val: F::from_canonical_u32(diff_val),
287            diff_idx,
288        };
289
290        Ok((output, record))
291    }
292
293    fn get_opcode_name(&self, opcode: usize) -> String {
294        format!("{:?}", LessThanOpcode::from_usize(opcode - self.air.offset))
295    }
296
297    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
298        let row_slice: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
299        row_slice.b = record.b;
300        row_slice.c = record.c;
301        row_slice.cmp_result = record.cmp_result;
302        row_slice.b_msb_f = record.b_msb_f;
303        row_slice.c_msb_f = record.c_msb_f;
304        row_slice.diff_val = record.diff_val;
305        row_slice.opcode_slt_flag = F::from_bool(record.opcode == LessThanOpcode::SLT);
306        row_slice.opcode_sltu_flag = F::from_bool(record.opcode == LessThanOpcode::SLTU);
307        row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx));
308    }
309
310    fn air(&self) -> &Self::Air {
311        &self.air
312    }
313}
314
315// Returns (cmp_result, diff_idx, x_sign, y_sign)
316pub(super) fn run_less_than<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
317    opcode: LessThanOpcode,
318    x: &[u32; NUM_LIMBS],
319    y: &[u32; NUM_LIMBS],
320) -> (bool, usize, bool, bool) {
321    let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT;
322    let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT;
323    for i in (0..NUM_LIMBS).rev() {
324        if x[i] != y[i] {
325            return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign);
326        }
327    }
328    (false, NUM_LIMBS, x_sign, y_sign)
329}