openvm_algebra_circuit/modular_chip/is_eq.rs

1use std::{
2    array::{self, from_fn},
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::arch::{
9    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
10    VmCoreAir, VmCoreChip,
11};
12use openvm_circuit_primitives::{
13    bigint::utils::big_uint_to_limbs,
14    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
15    is_equal_array::{IsEqArrayIo, IsEqArraySubAir},
16    SubAir, TraceSubRowGenerator,
17};
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{instruction::Instruction, LocalOpcode};
20use openvm_stark_backend::{
21    interaction::InteractionBuilder,
22    p3_air::{AirBuilder, BaseAir},
23    p3_field::{Field, FieldAlgebra, PrimeField32},
24    rap::BaseAirWithPublicValues,
25};
26use serde::{Deserialize, Serialize};
27use serde_big_array::BigArray;
// Given two numbers b and c, we want to prove that (a) b == c or b != c, depending on the
// value of cmp_result, and (b) b, c < N for some modulus N that is passed into the AIR at
// runtime (i.e. when the chip is instantiated). SETUP_ISEQ rows are the exception: they skip
// check (a) and instead constrain b == N.
31
/// Trace columns of the modular `IS_EQ` / `SETUP_ISEQ` core.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct ModularIsEqualCoreCols<T, const READ_LIMBS: usize> {
    // Boolean: 1 on rows that record an instruction (trace generation sets it to 1).
    pub is_valid: T,
    // Boolean: 1 for SETUP_ISEQ, 0 for IS_EQ. The AIR constrains is_setup = 1 => is_valid = 1.
    pub is_setup: T,
    // Operand limbs; larger indices correspond to more significant limbs.
    pub b: [T; READ_LIMBS],
    pub c: [T; READ_LIMBS],
    // Boolean result written to the first output limb. On IS_EQ rows it is tied to whether
    // b == c by the is-equal sub-AIR; on SETUP_ISEQ rows that check is disabled.
    pub cmp_result: T,

    // Auxiliary columns for subair EQ comparison between b and c.
    pub eq_marker: [T; READ_LIMBS],

    // Auxiliary columns to ensure both b and c are smaller than modulus N. Let b_diff_idx be
    // an index such that b[b_diff_idx] < N[b_diff_idx] and b[i] = N[i] for all i > b_diff_idx,
    // where larger indices correspond to more significant limbs. Such an index exists iff b < N.
    // Define c_diff_idx analogously. Then let b_lt_diff = N[b_diff_idx] - b[b_diff_idx] and
    // c_lt_diff = N[c_diff_idx] - c[c_diff_idx]. Both differences must be positive: the AIR
    // range checks b_lt_diff - 1 and c_lt_diff - 1 into [0, 2^LIMB_BITS), so each difference
    // lies in [1, 2^LIMB_BITS].
    //
    // To constrain the above, we will use lt_marker, which will indicate where b_diff_idx and
    // c_diff_idx are. Set lt_marker[b_diff_idx] = 1, lt_marker[c_diff_idx] = c_lt_mark, and 0
    // everywhere else. If b_diff_idx == c_diff_idx then c_lt_mark = 1, else c_lt_mark = 2. The
    // purpose of c_lt_mark is to handle the edge case where b_diff_idx == c_diff_idx (because
    // we cannot set lt_marker[b_diff_idx] to 1 and 2 at the same time).
    //
    // On SETUP_ISEQ rows b must equal N, so there is no b_diff_idx. In that case lt_marker holds
    // a single 2 at c_diff_idx, c_lt_mark = 2, and b_lt_diff is not constrained.
    pub lt_marker: [T; READ_LIMBS],
    pub b_lt_diff: T,
    pub c_lt_diff: T,
    pub c_lt_mark: T,
}
60
/// AIR for the modular `IS_EQ` / `SETUP_ISEQ` core: constrains the equality result of two
/// operands and that the operands are reduced with respect to a fixed modulus.
#[derive(Clone, Debug)]
pub struct ModularIsEqualCoreAir<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    /// Bitwise-lookup bus used to range check `b_lt_diff - 1` and `c_lt_diff - 1`.
    pub bus: BitwiseOperationLookupBus,
    /// Sub-AIR that ties `cmp_result` to whether `b == c`.
    pub subair: IsEqArraySubAir<READ_LIMBS>,
    /// The modulus `N` as `LIMB_BITS`-bit limbs (larger index = more significant limb),
    /// zero-padded up to `READ_LIMBS` limbs.
    pub modulus_limbs: [u32; READ_LIMBS],
    /// Offset added to the local opcode indices (`SETUP_ISEQ`, `IS_EQ`) to form global opcodes.
    pub offset: usize,
}
72
73impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
74    ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
75{
76    pub fn new(modulus: BigUint, bus: BitwiseOperationLookupBus, offset: usize) -> Self {
77        let mod_vec = big_uint_to_limbs(&modulus, LIMB_BITS);
78        assert!(mod_vec.len() <= READ_LIMBS);
79        let modulus_limbs = array::from_fn(|i| {
80            if i < mod_vec.len() {
81                mod_vec[i] as u32
82            } else {
83                0
84            }
85        });
86        Self {
87            bus,
88            subair: IsEqArraySubAir::<READ_LIMBS>,
89            modulus_limbs,
90            offset,
91        }
92    }
93}
94
95impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
96    for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
97{
98    fn width(&self) -> usize {
99        ModularIsEqualCoreCols::<F, READ_LIMBS>::width()
100    }
101}
102impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
103    BaseAirWithPublicValues<F> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
104{
105}
106
/// Constraints for the modular `IS_EQ` / `SETUP_ISEQ` core.
///
/// On `IS_EQ` rows (`is_valid = 1`, `is_setup = 0`) this ties `cmp_result` to whether `b == c`
/// and constrains `b < N` and `c < N`, where `N` is `modulus_limbs`. On `SETUP_ISEQ` rows
/// (`is_setup = 1`) the equality check and the final range check are disabled, and `b` is
/// instead constrained limb-by-limb to equal `N`.
impl<AB, I, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreAir<AB, I> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; READ_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; WRITE_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Evaluates the core constraints on `local_core`, interpreted as [`ModularIsEqualCoreCols`].
    ///
    /// Returns the adapter context: reads `[b, c]`, a single write whose first limb is
    /// `cmp_result` (remaining limbs zero), no `to_pc` override, and the expected global opcode
    /// `offset + SETUP_ISEQ` or `offset + IS_EQ` depending on `is_setup`.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ModularIsEqualCoreCols<_, READ_LIMBS> = local_core.borrow();

        builder.assert_bool(cols.is_valid);
        builder.assert_bool(cols.is_setup);
        // A setup row must also be a valid row.
        builder.when(cols.is_setup).assert_one(cols.is_valid);
        builder.assert_bool(cols.cmp_result);

        // Constrain that either b == c or b != c, depending on the value of cmp_result.
        // The check is only enabled on IS_EQ rows, where is_valid - is_setup = 1.
        let eq_subair_io = IsEqArrayIo {
            x: cols.b.map(Into::into),
            y: cols.c.map(Into::into),
            out: cols.cmp_result.into(),
            condition: cols.is_valid - cols.is_setup,
        };
        self.subair.eval(builder, (eq_subair_io, cols.eq_marker));

        // Constrain that auxiliary columns lt_marker and c_lt_mark are as defined on
        // ModularIsEqualCoreCols. When c_lt_mark is 1, lt_marker should have exactly one index i
        // where lt_marker[i] is 1, and be 0 elsewhere. When c_lt_mark is 2, lt_marker should have
        // an additional index j such that lt_marker[j] is 2. To constrain this:
        //
        // * When c_lt_mark = 1 the sum of all lt_marker[i] must be 1
        // * When c_lt_mark = 2 the sum of lt_marker[i] * (lt_marker[i] - 1) must be 2.
        //   Additionally, the sum of all lt_marker[i] must be 3.
        //
        // All this doesn't apply when is_setup.
        let lt_marker_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + *x);
        // Each entry contributes x * (x - 1): 0 for x in {0, 1} and 2 for x = 2.
        let lt_marker_one_check_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + (*x) * (*x - AB::F::ONE));

        // Constrain that c_lt_mark is either 1 or 2.
        builder
            .when(cols.is_valid - cols.is_setup)
            .assert_bool(cols.c_lt_mark - AB::F::ONE);

        // If c_lt_mark is 1, then lt_marker_sum is 1
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::from_canonical_u8(2))
            .assert_one(lt_marker_sum.clone());

        // If c_lt_mark is 2, then lt_marker_sum is 3
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::ONE)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(3));

        // This constraint, along with the constraint (below) that lt_marker[i] is 0, 1, or 2,
        // ensures that lt_marker has exactly one 2.
        // It is not gated by is_valid - is_setup: it also applies on setup rows (where
        // c_lt_mark = 2), and on rows with is_valid = 0 the right-hand side is 0.
        // NOTE(review): padding rows presumably are all-zero (c_lt_mark = 0, lt_marker = 0),
        // which satisfies this constraint; confirm against the padding logic.
        builder.when_ne(cols.c_lt_mark, AB::F::ONE).assert_eq(
            lt_marker_one_check_sum,
            cols.is_valid * AB::F::from_canonical_u8(2),
        );

        // Handle the setup row constraints.
        // When is_setup = 1, constrain c_lt_mark = 2 and lt_marker_sum = 2
        // This ensures that lt_marker has exactly one 2 and the remaining entries are 0.
        // Since lt_marker has no 1, we will end up constraining that b[i] = N[i] for all i
        // instead of just for i > b_diff_idx.
        builder
            .when(cols.is_setup)
            .assert_eq(cols.c_lt_mark, AB::F::from_canonical_u8(2));
        builder
            .when(cols.is_setup)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(2));

        // Constrain that b, c < N (i.e. modulus).
        let modulus = self.modulus_limbs.map(AB::F::from_canonical_u32);
        // Running sum of lt_marker from the most significant limb down to limb i (inclusive).
        let mut prefix_sum = AB::Expr::ZERO;

        for i in (0..READ_LIMBS).rev() {
            prefix_sum += cols.lt_marker[i].into();
            // lt_marker[i] is 0, 1, or c_lt_mark.
            builder.assert_zero(
                cols.lt_marker[i]
                    * (cols.lt_marker[i] - AB::F::ONE)
                    * (cols.lt_marker[i] - cols.c_lt_mark),
            );

            // Constrain b < N.
            // First, we constrain b[i] = N[i] for i > b_diff_idx.
            // We do this by constraining that b[i] = N[i] when prefix_sum is not 1 or
            // lt_marker_sum.
            //  - If is_setup = 0, then lt_marker_sum is either 1 or 3. In this case, prefix_sum is
            //    0, 1, 2, or 3. It can be verified by casework that i > b_diff_idx iff prefix_sum
            //    is not 1 or lt_marker_sum.
            //  - If is_setup = 1, then we want to constrain b[i] = N[i] for all i. In this case,
            //    lt_marker_sum is 2 and prefix_sum is 0 or 2. So we constrain b[i] = N[i] when
            //    prefix_sum is not 1, which works.
            builder
                .when_ne(prefix_sum.clone(), AB::F::ONE)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone() - cols.is_setup)
                .assert_eq(cols.b[i], modulus[i]);
            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being 1 indicates b[i] <
            // N[i] (i.e. i == b_diff_idx).
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(cols.lt_marker[i], AB::F::from_canonical_u8(2))
                .assert_eq(AB::Expr::from(modulus[i]) - cols.b[i], cols.b_lt_diff);

            // Constrain c < N.
            // First, we constrain c[i] = N[i] for i > c_diff_idx.
            // We do this by constraining that c[i] = N[i] when prefix_sum is not c_lt_mark or
            // lt_marker_sum. It can be verified by casework that i > c_diff_idx iff
            // prefix_sum is not c_lt_mark or lt_marker_sum.
            builder
                .when_ne(prefix_sum.clone(), cols.c_lt_mark)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone())
                .assert_eq(cols.c[i], modulus[i]);
            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being c_lt_mark
            // indicates c[i] < N[i] (i.e. i == c_diff_idx). Since c_lt_mark is 1 or 2,
            // we have {0, 1, 2} \ {0, 3 - c_lt_mark} = {c_lt_mark}.
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(
                    cols.lt_marker[i],
                    AB::Expr::from_canonical_u8(3) - cols.c_lt_mark,
                )
                .assert_eq(AB::Expr::from(modulus[i]) - cols.c[i], cols.c_lt_diff);
        }

        // Check that b_lt_diff and c_lt_diff are positive by range checking b_lt_diff - 1 and
        // c_lt_diff - 1 over the bitwise lookup bus. Only enabled on IS_EQ rows: on setup rows
        // b has no diff index, and c's range check is skipped as well.
        self.bus
            .send_range(
                cols.b_lt_diff - AB::Expr::ONE,
                cols.c_lt_diff - AB::Expr::ONE,
            )
            .eval(builder, cols.is_valid - cols.is_setup);

        // Global opcode: offset + SETUP_ISEQ when is_setup = 1, otherwise offset + IS_EQ.
        let expected_opcode = AB::Expr::from_canonical_usize(self.offset)
            + cols.is_setup
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)
            + (AB::Expr::ONE - cols.is_setup)
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::IS_EQ as usize);
        // Only the first output limb carries cmp_result; the remaining limbs are zero.
        let mut a: [AB::Expr; WRITE_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid: cols.is_valid.into(),
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Offset of this chip's local opcodes within the global opcode space.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
279
/// Per-instruction data produced by `execute_instruction` and consumed by `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ModularIsEqualCoreRecord<T, const READ_LIMBS: usize> {
    // Operand limbs as read through the adapter.
    #[serde(with = "BigArray")]
    pub b: [T; READ_LIMBS],
    #[serde(with = "BigArray")]
    pub c: [T; READ_LIMBS],
    // Equality result for b and c, as filled in by the is-equal sub-AIR's row generator.
    pub cmp_result: T,
    // Auxiliary is-equal columns filled in alongside cmp_result.
    #[serde(with = "BigArray")]
    pub eq_marker: [T; READ_LIMBS],
    // Most significant limb index at which b (resp. c) differs from the modulus, or READ_LIMBS
    // when they are equal (the value returned by run_unsigned_less_than).
    pub b_diff_idx: usize,
    pub c_diff_idx: usize,
    // Whether the executed opcode was SETUP_ISEQ.
    pub is_setup: bool,
}
294
/// Runtime chip for the modular `IS_EQ` / `SETUP_ISEQ` core.
pub struct ModularIsEqualCoreChip<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    /// The AIR (modulus limbs, opcode offset, lookup bus) whose rows this chip generates.
    pub air: ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>,
    /// Shared bitwise lookup chip that receives the range-check requests made during execution.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
303
304impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
305    ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
306{
307    pub fn new(
308        modulus: BigUint,
309        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
310        offset: usize,
311    ) -> Self {
312        Self {
313            air: ModularIsEqualCoreAir::new(modulus, bitwise_lookup_chip.bus(), offset),
314            bitwise_lookup_chip,
315        }
316    }
317}
318
319impl<
320        F: PrimeField32,
321        I: VmAdapterInterface<F>,
322        const READ_LIMBS: usize,
323        const WRITE_LIMBS: usize,
324        const LIMB_BITS: usize,
325    > VmCoreChip<F, I> for ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
326where
327    I::Reads: Into<[[F; READ_LIMBS]; 2]>,
328    I::Writes: From<[[F; WRITE_LIMBS]; 1]>,
329{
330    type Record = ModularIsEqualCoreRecord<F, READ_LIMBS>;
331    type Air = ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>;
332
333    #[allow(clippy::type_complexity)]
334    fn execute_instruction(
335        &self,
336        instruction: &Instruction<F>,
337        _from_pc: u32,
338        reads: I::Reads,
339    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
340        let data: [[F; READ_LIMBS]; 2] = reads.into();
341        let b = data[0].map(|x| x.as_canonical_u32());
342        let c = data[1].map(|y| y.as_canonical_u32());
343        let (b_cmp, b_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&b, &self.air.modulus_limbs);
344        let (c_cmp, c_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&c, &self.air.modulus_limbs);
345        let is_setup = instruction.opcode.local_opcode_idx(self.air.offset)
346            == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize;
347
348        if !is_setup {
349            assert!(b_cmp, "{:?} >= {:?}", b, self.air.modulus_limbs);
350        }
351        assert!(c_cmp, "{:?} >= {:?}", c, self.air.modulus_limbs);
352        if !is_setup {
353            self.bitwise_lookup_chip.request_range(
354                self.air.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1,
355                self.air.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1,
356            );
357        }
358
359        let mut eq_marker = [F::ZERO; READ_LIMBS];
360        let mut cmp_result = F::ZERO;
361        self.air
362            .subair
363            .generate_subrow((&data[0], &data[1]), (&mut eq_marker, &mut cmp_result));
364
365        let mut writes = [F::ZERO; WRITE_LIMBS];
366        writes[0] = cmp_result;
367
368        let output = AdapterRuntimeContext::without_pc([writes]);
369        let record = ModularIsEqualCoreRecord {
370            is_setup,
371            b: data[0],
372            c: data[1],
373            cmp_result,
374            eq_marker,
375            b_diff_idx,
376            c_diff_idx,
377        };
378
379        Ok((output, record))
380    }
381
382    fn get_opcode_name(&self, opcode: usize) -> String {
383        format!(
384            "{:?}",
385            Rv32ModularArithmeticOpcode::from_usize(opcode - self.air.offset)
386        )
387    }
388
389    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
390        let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut();
391        row_slice.is_valid = F::ONE;
392        row_slice.is_setup = F::from_bool(record.is_setup);
393        row_slice.b = record.b;
394        row_slice.c = record.c;
395        row_slice.cmp_result = record.cmp_result;
396
397        row_slice.eq_marker = record.eq_marker;
398
399        if !record.is_setup {
400            row_slice.b_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.b_diff_idx])
401                - record.b[record.b_diff_idx];
402        }
403        row_slice.c_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.c_diff_idx])
404            - record.c[record.c_diff_idx];
405        row_slice.c_lt_mark = if record.b_diff_idx == record.c_diff_idx {
406            F::ONE
407        } else {
408            F::from_canonical_u8(2)
409        };
410        row_slice.lt_marker = from_fn(|i| {
411            if i == record.b_diff_idx {
412                F::ONE
413            } else if i == record.c_diff_idx {
414                row_slice.c_lt_mark
415            } else {
416                F::ZERO
417            }
418        });
419    }
420
421    fn air(&self) -> &Self::Air {
422        &self.air
423    }
424}
425
426// Returns (cmp_result, diff_idx)
427pub(super) fn run_unsigned_less_than<const NUM_LIMBS: usize>(
428    x: &[u32; NUM_LIMBS],
429    y: &[u32; NUM_LIMBS],
430) -> (bool, usize) {
431    for i in (0..NUM_LIMBS).rev() {
432        if x[i] != y[i] {
433            return (x[i] < y[i], i);
434        }
435    }
436    (false, NUM_LIMBS)
437}