openvm_algebra_circuit/modular_chip/
is_eq.rs

1use std::{
2    array::{self, from_fn},
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::arch::{
9    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
10    VmCoreAir, VmCoreChip,
11};
12use openvm_circuit_primitives::{
13    bigint::utils::big_uint_to_limbs,
14    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
15    is_equal_array::{IsEqArrayIo, IsEqArraySubAir},
16    SubAir, TraceSubRowGenerator,
17};
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{instruction::Instruction, LocalOpcode};
20use openvm_stark_backend::{
21    interaction::InteractionBuilder,
22    p3_air::{AirBuilder, BaseAir},
23    p3_field::{Field, FieldAlgebra, PrimeField32},
24    rap::BaseAirWithPublicValues,
25};
26use serde::{Deserialize, Serialize};
27use serde_big_array::BigArray;
// Given two numbers b and c, we want to prove that (a) b == c or b != c, depending on the
// value of cmp_result, and (b) b, c < N for some modulus N that is passed into the AIR at
// runtime (i.e. when the chip is instantiated). On SETUP_ISEQ rows, b is instead
// constrained to equal N.
31
32#[repr(C)]
33#[derive(AlignedBorrow)]
34pub struct ModularIsEqualCoreCols<T, const READ_LIMBS: usize> {
35    pub is_valid: T,
36    pub is_setup: T,
37    pub b: [T; READ_LIMBS],
38    pub c: [T; READ_LIMBS],
39    pub cmp_result: T,
40
41    // Auxiliary columns for subair EQ comparison between b and c.
42    pub eq_marker: [T; READ_LIMBS],
43
44    // Auxiliary columns to ensure both b and c are smaller than modulus N. Let b_diff_idx be
45    // an index such that b[b_diff_idx] < N[b_diff_idx] and b[i] = N[i] for all i > b_diff_idx,
46    // where larger indices correspond to more significant limbs. Such an index exists iff b < N.
47    // Define c_diff_idx analogously. Then let b_lt_diff = N[b_diff_idx] - b[b_diff_idx] and
48    // c_lt_diff = N[c_diff_idx] - c[c_diff_idx], where both must be in [0, 2^LIMB_BITS).
49    //
50    // To constrain the above, we will use lt_marker, which will indicate where b_diff_idx and c_diff_idx are.
51    // Set lt_marker[b_diff_idx] = 1, lt_marker[c_diff_idx] = c_lt_mark, and 0 everywhere
52    // else. If b_diff_idx == c_diff_idx then c_lt_mark = 1, else c_lt_mark = 2. The purpose of
53    // c_lt_mark is to handle the edge case where b_diff_idx == c_diff_idx (because we cannot set
54    // lt_marker[b_diff_idx] to 1 and 2 at the same time).
55    pub lt_marker: [T; READ_LIMBS],
56    pub b_lt_diff: T,
57    pub c_lt_diff: T,
58    pub c_lt_mark: T,
59}
60
61#[derive(Clone, Debug)]
62pub struct ModularIsEqualCoreAir<
63    const READ_LIMBS: usize,
64    const WRITE_LIMBS: usize,
65    const LIMB_BITS: usize,
66> {
67    pub bus: BitwiseOperationLookupBus,
68    pub subair: IsEqArraySubAir<READ_LIMBS>,
69    pub modulus_limbs: [u32; READ_LIMBS],
70    pub offset: usize,
71}
72
73impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
74    ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
75{
76    pub fn new(modulus: BigUint, bus: BitwiseOperationLookupBus, offset: usize) -> Self {
77        let mod_vec = big_uint_to_limbs(&modulus, LIMB_BITS);
78        assert!(mod_vec.len() <= READ_LIMBS);
79        let modulus_limbs = array::from_fn(|i| {
80            if i < mod_vec.len() {
81                mod_vec[i] as u32
82            } else {
83                0
84            }
85        });
86        Self {
87            bus,
88            subair: IsEqArraySubAir::<READ_LIMBS>,
89            modulus_limbs,
90            offset,
91        }
92    }
93}
94
95impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
96    for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
97{
98    fn width(&self) -> usize {
99        ModularIsEqualCoreCols::<F, READ_LIMBS>::width()
100    }
101}
102impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
103    BaseAirWithPublicValues<F> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
104{
105}
106
107impl<AB, I, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
108    VmCoreAir<AB, I> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
109where
110    AB: InteractionBuilder,
111    I: VmAdapterInterface<AB::Expr>,
112    I::Reads: From<[[AB::Expr; READ_LIMBS]; 2]>,
113    I::Writes: From<[[AB::Expr; WRITE_LIMBS]; 1]>,
114    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
115{
116    fn eval(
117        &self,
118        builder: &mut AB,
119        local_core: &[AB::Var],
120        _from_pc: AB::Var,
121    ) -> AdapterAirContext<AB::Expr, I> {
122        let cols: &ModularIsEqualCoreCols<_, READ_LIMBS> = local_core.borrow();
123
124        builder.assert_bool(cols.is_valid);
125        builder.assert_bool(cols.is_setup);
126        builder.when(cols.is_setup).assert_one(cols.is_valid);
127        builder.assert_bool(cols.cmp_result);
128
129        // Constrain that either b == c or b != c, depending on the value of cmp_result.
130        let eq_subair_io = IsEqArrayIo {
131            x: cols.b.map(Into::into),
132            y: cols.c.map(Into::into),
133            out: cols.cmp_result.into(),
134            condition: cols.is_valid - cols.is_setup,
135        };
136        self.subair.eval(builder, (eq_subair_io, cols.eq_marker));
137
138        // Constrain that auxiliary columns lt_columns and c_lt_mark are as defined above.
139        // When c_lt_mark is 1, lt_marker should have exactly one index i where lt_marker[i]
140        // is 1, and be 0 elsewhere. When c_lt_mark is 2, lt_marker[i] should have an
141        // additional index j such that lt_marker[j] is 2. To constrain this:
142        //
143        // * When c_lt_mark = 1 the sum of all lt_marker[i] must be 1
144        // * When c_lt_mark = 2 the sum of lt_marker[i] * (lt_marker[i] - 1) must be 2.
145        //   Additionally, the sum of all lt_marker[i] must be 3.
146        //
147        // All this doesn't apply when is_setup.
148        let lt_marker_sum = cols
149            .lt_marker
150            .iter()
151            .fold(AB::Expr::ZERO, |acc, x| acc + *x);
152        let lt_marker_one_check_sum = cols
153            .lt_marker
154            .iter()
155            .fold(AB::Expr::ZERO, |acc, x| acc + (*x) * (*x - AB::F::ONE));
156
157        // Constrain that c_lt_mark is either 1 or 2.
158        builder
159            .when(cols.is_valid - cols.is_setup)
160            .assert_bool(cols.c_lt_mark - AB::F::ONE);
161
162        // If c_lt_mark is 1, then lt_marker_sum is 1
163        builder
164            .when(cols.is_valid - cols.is_setup)
165            .when_ne(cols.c_lt_mark, AB::F::from_canonical_u8(2))
166            .assert_one(lt_marker_sum.clone());
167
168        // If c_lt_mark is 2, then lt_marker_sum is 3
169        builder
170            .when(cols.is_valid - cols.is_setup)
171            .when_ne(cols.c_lt_mark, AB::F::ONE)
172            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(3));
173
174        // This constraint, along with the constraint (below) that lt_marker[i] is 0, 1, or 2,
175        // ensures that lt_marker has exactly one 2.
176        builder.when_ne(cols.c_lt_mark, AB::F::ONE).assert_eq(
177            lt_marker_one_check_sum,
178            cols.is_valid * AB::F::from_canonical_u8(2),
179        );
180
181        // Handle the setup row constraints.
182        // When is_setup = 1, constrain c_lt_mark = 2 and lt_marker_sum = 2
183        // This ensures that lt_marker has exactly one 2 and the remaining entries are 0.
184        // Since lt_marker has no 1, we will end up constraining that b[i] = N[i] for all i
185        // instead of just for i > b_diff_idx.
186        builder
187            .when(cols.is_setup)
188            .assert_eq(cols.c_lt_mark, AB::F::from_canonical_u8(2));
189        builder
190            .when(cols.is_setup)
191            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(2));
192
193        // Constrain that b, c < N (i.e. modulus).
194        let modulus = self.modulus_limbs.map(AB::F::from_canonical_u32);
195        let mut prefix_sum = AB::Expr::ZERO;
196
197        for i in (0..READ_LIMBS).rev() {
198            prefix_sum += cols.lt_marker[i].into();
199            builder.assert_zero(
200                cols.lt_marker[i]
201                    * (cols.lt_marker[i] - AB::F::ONE)
202                    * (cols.lt_marker[i] - cols.c_lt_mark),
203            );
204
205            // Constrain b < N.
206            // First, we constrain b[i] = N[i] for i > b_diff_idx.
207            // We do this by constraining that b[i] = N[i] when prefix_sum is not 1 or lt_marker_sum.
208            //  - If is_setup = 0, then lt_marker_sum is either 1 or 3. In this case, prefix_sum is 0, 1, 2, or 3.
209            //    It can be verified by casework that i > b_diff_idx iff prefix_sum is not 1 or lt_marker_sum.
210            //  - If is_setup = 1, then we want to constrain b[i] = N[i] for all i. In this case, lt_marker_sum is 2
211            //    and prefix_sum is 0 or 2. So we constrain b[i] = N[i] when prefix_sum is not 1, which works.
212            builder
213                .when_ne(prefix_sum.clone(), AB::F::ONE)
214                .when_ne(prefix_sum.clone(), lt_marker_sum.clone() - cols.is_setup)
215                .assert_eq(cols.b[i], modulus[i]);
216            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being 1 indicates b[i] < N[i] (i.e. i == b_diff_idx).
217            builder
218                .when_ne(cols.lt_marker[i], AB::F::ZERO)
219                .when_ne(cols.lt_marker[i], AB::F::from_canonical_u8(2))
220                .assert_eq(AB::Expr::from(modulus[i]) - cols.b[i], cols.b_lt_diff);
221
222            // Constrain c < N.
223            // First, we constrain c[i] = N[i] for i > c_diff_idx.
224            // We do this by constraining that c[i] = N[i] when prefix_sum is not c_lt_mark or lt_marker_sum.
225            // It can be verified by casework that i > c_diff_idx iff prefix_sum is not c_lt_mark or lt_marker_sum.
226            builder
227                .when_ne(prefix_sum.clone(), cols.c_lt_mark)
228                .when_ne(prefix_sum.clone(), lt_marker_sum.clone())
229                .assert_eq(cols.c[i], modulus[i]);
230            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being c_lt_mark indicates c[i] < N[i] (i.e. i == c_diff_idx).
231            // Since c_lt_mark is 1 or 2, we have {0, 1, 2} \ {0, 3 - c_lt_mark} = {c_lt_mark}.
232            builder
233                .when_ne(cols.lt_marker[i], AB::F::ZERO)
234                .when_ne(
235                    cols.lt_marker[i],
236                    AB::Expr::from_canonical_u8(3) - cols.c_lt_mark,
237                )
238                .assert_eq(AB::Expr::from(modulus[i]) - cols.c[i], cols.c_lt_diff);
239        }
240
241        // Check that b_lt_diff and c_lt_diff are positive
242        self.bus
243            .send_range(
244                cols.b_lt_diff - AB::Expr::ONE,
245                cols.c_lt_diff - AB::Expr::ONE,
246            )
247            .eval(builder, cols.is_valid - cols.is_setup);
248
249        let expected_opcode = AB::Expr::from_canonical_usize(self.offset)
250            + cols.is_setup
251                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)
252            + (AB::Expr::ONE - cols.is_setup)
253                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::IS_EQ as usize);
254        let mut a: [AB::Expr; WRITE_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
255        a[0] = cols.cmp_result.into();
256
257        AdapterAirContext {
258            to_pc: None,
259            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
260            writes: [a].into(),
261            instruction: MinimalInstruction {
262                is_valid: cols.is_valid.into(),
263                opcode: expected_opcode,
264            }
265            .into(),
266        }
267    }
268
269    fn start_offset(&self) -> usize {
270        self.offset
271    }
272}
273
274#[repr(C)]
275#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
276pub struct ModularIsEqualCoreRecord<T, const READ_LIMBS: usize> {
277    #[serde(with = "BigArray")]
278    pub b: [T; READ_LIMBS],
279    #[serde(with = "BigArray")]
280    pub c: [T; READ_LIMBS],
281    pub cmp_result: T,
282    #[serde(with = "BigArray")]
283    pub eq_marker: [T; READ_LIMBS],
284    pub b_diff_idx: usize,
285    pub c_diff_idx: usize,
286    pub is_setup: bool,
287}
288
289pub struct ModularIsEqualCoreChip<
290    const READ_LIMBS: usize,
291    const WRITE_LIMBS: usize,
292    const LIMB_BITS: usize,
293> {
294    pub air: ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>,
295    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
296}
297
298impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
299    ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
300{
301    pub fn new(
302        modulus: BigUint,
303        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
304        offset: usize,
305    ) -> Self {
306        Self {
307            air: ModularIsEqualCoreAir::new(modulus, bitwise_lookup_chip.bus(), offset),
308            bitwise_lookup_chip,
309        }
310    }
311}
312
313impl<
314        F: PrimeField32,
315        I: VmAdapterInterface<F>,
316        const READ_LIMBS: usize,
317        const WRITE_LIMBS: usize,
318        const LIMB_BITS: usize,
319    > VmCoreChip<F, I> for ModularIsEqualCoreChip<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
320where
321    I::Reads: Into<[[F; READ_LIMBS]; 2]>,
322    I::Writes: From<[[F; WRITE_LIMBS]; 1]>,
323{
324    type Record = ModularIsEqualCoreRecord<F, READ_LIMBS>;
325    type Air = ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>;
326
327    #[allow(clippy::type_complexity)]
328    fn execute_instruction(
329        &self,
330        instruction: &Instruction<F>,
331        _from_pc: u32,
332        reads: I::Reads,
333    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
334        let data: [[F; READ_LIMBS]; 2] = reads.into();
335        let b = data[0].map(|x| x.as_canonical_u32());
336        let c = data[1].map(|y| y.as_canonical_u32());
337        let (b_cmp, b_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&b, &self.air.modulus_limbs);
338        let (c_cmp, c_diff_idx) = run_unsigned_less_than::<READ_LIMBS>(&c, &self.air.modulus_limbs);
339        let is_setup = instruction.opcode.local_opcode_idx(self.air.offset)
340            == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize;
341
342        if !is_setup {
343            assert!(b_cmp, "{:?} >= {:?}", b, self.air.modulus_limbs);
344        }
345        assert!(c_cmp, "{:?} >= {:?}", c, self.air.modulus_limbs);
346        if !is_setup {
347            self.bitwise_lookup_chip.request_range(
348                self.air.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1,
349                self.air.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1,
350            );
351        }
352
353        let mut eq_marker = [F::ZERO; READ_LIMBS];
354        let mut cmp_result = F::ZERO;
355        self.air
356            .subair
357            .generate_subrow((&data[0], &data[1]), (&mut eq_marker, &mut cmp_result));
358
359        let mut writes = [F::ZERO; WRITE_LIMBS];
360        writes[0] = cmp_result;
361
362        let output = AdapterRuntimeContext::without_pc([writes]);
363        let record = ModularIsEqualCoreRecord {
364            is_setup,
365            b: data[0],
366            c: data[1],
367            cmp_result,
368            eq_marker,
369            b_diff_idx,
370            c_diff_idx,
371        };
372
373        Ok((output, record))
374    }
375
376    fn get_opcode_name(&self, opcode: usize) -> String {
377        format!(
378            "{:?}",
379            Rv32ModularArithmeticOpcode::from_usize(opcode - self.air.offset)
380        )
381    }
382
383    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
384        let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut();
385        row_slice.is_valid = F::ONE;
386        row_slice.is_setup = F::from_bool(record.is_setup);
387        row_slice.b = record.b;
388        row_slice.c = record.c;
389        row_slice.cmp_result = record.cmp_result;
390
391        row_slice.eq_marker = record.eq_marker;
392
393        if !record.is_setup {
394            row_slice.b_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.b_diff_idx])
395                - record.b[record.b_diff_idx];
396        }
397        row_slice.c_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.c_diff_idx])
398            - record.c[record.c_diff_idx];
399        row_slice.c_lt_mark = if record.b_diff_idx == record.c_diff_idx {
400            F::ONE
401        } else {
402            F::from_canonical_u8(2)
403        };
404        row_slice.lt_marker = from_fn(|i| {
405            if i == record.b_diff_idx {
406                F::ONE
407            } else if i == record.c_diff_idx {
408                row_slice.c_lt_mark
409            } else {
410                F::ZERO
411            }
412        });
413    }
414
415    fn air(&self) -> &Self::Air {
416        &self.air
417    }
418}
419
420// Returns (cmp_result, diff_idx)
421pub(super) fn run_unsigned_less_than<const NUM_LIMBS: usize>(
422    x: &[u32; NUM_LIMBS],
423    y: &[u32; NUM_LIMBS],
424) -> (bool, usize) {
425    for i in (0..NUM_LIMBS).rev() {
426        if x[i] != y[i] {
427            return (x[i] < y[i], i);
428        }
429    }
430    (false, NUM_LIMBS)
431}