openvm_algebra_circuit/modular_chip/
is_eq.rs

1use std::{
2    array::{self, from_fn},
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::{
9    arch::*,
10    system::memory::{
11        online::{GuestMemory, TracingMemory},
12        MemoryAuxColsFactory, POINTER_MAX_BITS,
13    },
14};
15use openvm_circuit_primitives::{
16    bigint::utils::big_uint_to_limbs,
17    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
18    is_equal_array::{IsEqArrayIo, IsEqArraySubAir},
19    AlignedBytesBorrow, SubAir, TraceSubRowGenerator,
20};
21use openvm_circuit_primitives_derive::AlignedBorrow;
22use openvm_instructions::{
23    instruction::Instruction,
24    program::DEFAULT_PC_STEP,
25    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
26    LocalOpcode,
27};
28use openvm_rv32_adapters::Rv32IsEqualModAdapterExecutor;
29use openvm_stark_backend::{
30    interaction::InteractionBuilder,
31    p3_air::{AirBuilder, BaseAir},
32    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
33    rap::BaseAirWithPublicValues,
34};
35
36use crate::modular_chip::VmModularIsEqualExecutor;
37// Given two numbers b and c, we want to prove that a) b == c or b != c, depending on
38// result of cmp_result and b) b, c < N for some modulus N that is passed into the AIR
39// at runtime (i.e. when chip is instantiated).
40
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct ModularIsEqualCoreCols<T, const READ_LIMBS: usize> {
    /// 1 if this row holds a real instruction, 0 for padding rows.
    pub is_valid: T,
    /// 1 if this row is a SETUP_ISEQ row (then the AIR forces b[i] = N[i] for all i).
    pub is_setup: T,
    /// Limbs of the first operand read from memory.
    pub b: [T; READ_LIMBS],
    /// Limbs of the second operand read from memory.
    pub c: [T; READ_LIMBS],
    /// Boolean result of the equality comparison b == c.
    pub cmp_result: T,

    // Auxiliary columns for subair EQ comparison between b and c.
    pub eq_marker: [T; READ_LIMBS],

    // Auxiliary columns to ensure both b and c are smaller than modulus N. Let b_diff_idx be
    // an index such that b[b_diff_idx] < N[b_diff_idx] and b[i] = N[i] for all i > b_diff_idx,
    // where larger indices correspond to more significant limbs. Such an index exists iff b < N.
    // Define c_diff_idx analogously. Then let b_lt_diff = N[b_diff_idx] - b[b_diff_idx] and
    // c_lt_diff = N[c_diff_idx] - c[c_diff_idx], where both must be in [0, 2^LIMB_BITS).
    //
    // To constrain the above, we will use lt_marker, which will indicate where b_diff_idx and
    // c_diff_idx are. Set lt_marker[b_diff_idx] = 1, lt_marker[c_diff_idx] = c_lt_mark, and 0
    // everywhere else. If b_diff_idx == c_diff_idx then c_lt_mark = 1, else c_lt_mark = 2. The
    // purpose of c_lt_mark is to handle the edge case where b_diff_idx == c_diff_idx (because
    // we cannot set lt_marker[b_diff_idx] to 1 and 2 at the same time).
    pub lt_marker: [T; READ_LIMBS],
    pub b_lt_diff: T,
    pub c_lt_diff: T,
    pub c_lt_mark: T,
}
69
/// Core AIR for the modular IS_EQ / SETUP_ISEQ instruction: proves that
/// cmp_result == (b == c) and that both operands are reduced modulo N
/// (see the column documentation on `ModularIsEqualCoreCols`).
#[derive(Clone, Debug)]
pub struct ModularIsEqualCoreAir<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    /// Lookup bus used to range-check b_lt_diff - 1 and c_lt_diff - 1.
    pub bus: BitwiseOperationLookupBus,
    /// Sub-AIR constraining cmp_result to be the equality indicator of b and c.
    pub subair: IsEqArraySubAir<READ_LIMBS>,
    /// Limbs of the modulus N (larger index = more significant limb),
    /// zero-padded up to READ_LIMBS.
    pub modulus_limbs: [u32; READ_LIMBS],
    /// Global opcode offset for this chip's opcode class.
    pub offset: usize,
}
81
82impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
83    ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
84{
85    pub fn new(modulus: BigUint, bus: BitwiseOperationLookupBus, offset: usize) -> Self {
86        let mod_vec = big_uint_to_limbs(&modulus, LIMB_BITS);
87        assert!(mod_vec.len() <= READ_LIMBS);
88        let modulus_limbs = array::from_fn(|i| {
89            if i < mod_vec.len() {
90                mod_vec[i] as u32
91            } else {
92                0
93            }
94        });
95        Self {
96            bus,
97            subair: IsEqArraySubAir::<READ_LIMBS>,
98            modulus_limbs,
99            offset,
100        }
101    }
102}
103
impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
{
    /// Trace width of the core AIR: one column per field of `ModularIsEqualCoreCols`.
    fn width(&self) -> usize {
        ModularIsEqualCoreCols::<F, READ_LIMBS>::width()
    }
}
// This AIR exposes no public values; the trait's default methods apply.
impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    BaseAirWithPublicValues<F> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
{
}
115
impl<AB, I, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreAir<AB, I> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; READ_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; WRITE_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Evaluates all core constraints: boolean flags, the b == c equality
    /// sub-AIR, and the b, c < N less-than argument described on
    /// `ModularIsEqualCoreCols`, then returns the adapter context carrying the
    /// reads, the single-limb comparison result, and the expected opcode.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ModularIsEqualCoreCols<_, READ_LIMBS> = local_core.borrow();

        // Flag well-formedness; is_setup implies is_valid.
        builder.assert_bool(cols.is_valid);
        builder.assert_bool(cols.is_setup);
        builder.when(cols.is_setup).assert_one(cols.is_valid);
        builder.assert_bool(cols.cmp_result);

        // Constrain that either b == c or b != c, depending on the value of cmp_result.
        // The equality check is skipped on setup rows (condition is 0 there).
        let eq_subair_io = IsEqArrayIo {
            x: cols.b.map(Into::into),
            y: cols.c.map(Into::into),
            out: cols.cmp_result.into(),
            condition: cols.is_valid - cols.is_setup,
        };
        self.subair.eval(builder, (eq_subair_io, cols.eq_marker));

        // Constrain that auxiliary columns lt_columns and c_lt_mark are as defined above.
        // When c_lt_mark is 1, lt_marker should have exactly one index i where lt_marker[i]
        // is 1, and be 0 elsewhere. When c_lt_mark is 2, lt_marker[i] should have an
        // additional index j such that lt_marker[j] is 2. To constrain this:
        //
        // * When c_lt_mark = 1 the sum of all lt_marker[i] must be 1
        // * When c_lt_mark = 2 the sum of lt_marker[i] * (lt_marker[i] - 1) must be 2.
        //   Additionally, the sum of all lt_marker[i] must be 3.
        //
        // All this doesn't apply when is_setup.
        let lt_marker_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + *x);
        let lt_marker_one_check_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + (*x) * (*x - AB::F::ONE));

        // Constrain that c_lt_mark is either 1 or 2.
        builder
            .when(cols.is_valid - cols.is_setup)
            .assert_bool(cols.c_lt_mark - AB::F::ONE);

        // If c_lt_mark is 1, then lt_marker_sum is 1
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::from_u8(2))
            .assert_one(lt_marker_sum.clone());

        // If c_lt_mark is 2, then lt_marker_sum is 3
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::ONE)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_u8(3));

        // This constraint, along with the constraint (below) that lt_marker[i] is 0, 1, or 2,
        // ensures that lt_marker has exactly one 2.
        builder
            .when_ne(cols.c_lt_mark, AB::F::ONE)
            .assert_eq(lt_marker_one_check_sum, cols.is_valid * AB::F::from_u8(2));

        // Handle the setup row constraints.
        // When is_setup = 1, constrain c_lt_mark = 2 and lt_marker_sum = 2
        // This ensures that lt_marker has exactly one 2 and the remaining entries are 0.
        // Since lt_marker has no 1, we will end up constraining that b[i] = N[i] for all i
        // instead of just for i > b_diff_idx.
        builder
            .when(cols.is_setup)
            .assert_eq(cols.c_lt_mark, AB::F::from_u8(2));
        builder
            .when(cols.is_setup)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_u8(2));

        // Constrain that b, c < N (i.e. modulus).
        let modulus = self.modulus_limbs.map(AB::F::from_u32);
        // prefix_sum accumulates lt_marker from the most significant limb downwards.
        let mut prefix_sum = AB::Expr::ZERO;

        for i in (0..READ_LIMBS).rev() {
            prefix_sum += cols.lt_marker[i].into();
            // lt_marker[i] must be 0, 1, or c_lt_mark.
            builder.assert_zero(
                cols.lt_marker[i]
                    * (cols.lt_marker[i] - AB::F::ONE)
                    * (cols.lt_marker[i] - cols.c_lt_mark),
            );

            // Constrain b < N.
            // First, we constrain b[i] = N[i] for i > b_diff_idx.
            // We do this by constraining that b[i] = N[i] when prefix_sum is not 1 or
            // lt_marker_sum.
            //  - If is_setup = 0, then lt_marker_sum is either 1 or 3. In this case, prefix_sum is
            //    0, 1, 2, or 3. It can be verified by casework that i > b_diff_idx iff prefix_sum
            //    is not 1 or lt_marker_sum.
            //  - If is_setup = 1, then we want to constrain b[i] = N[i] for all i. In this case,
            //    lt_marker_sum is 2 and prefix_sum is 0 or 2. So we constrain b[i] = N[i] when
            //    prefix_sum is not 1, which works.
            builder
                .when_ne(prefix_sum.clone(), AB::F::ONE)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone() - cols.is_setup)
                .assert_eq(cols.b[i], modulus[i]);
            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being 1 indicates b[i] <
            // N[i] (i.e. i == b_diff_idx).
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(cols.lt_marker[i], AB::F::from_u8(2))
                .assert_eq(AB::Expr::from(modulus[i]) - cols.b[i], cols.b_lt_diff);

            // Constrain c < N.
            // First, we constrain c[i] = N[i] for i > c_diff_idx.
            // We do this by constraining that c[i] = N[i] when prefix_sum is not c_lt_mark or
            // lt_marker_sum. It can be verified by casework that i > c_diff_idx iff
            // prefix_sum is not c_lt_mark or lt_marker_sum.
            builder
                .when_ne(prefix_sum.clone(), cols.c_lt_mark)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone())
                .assert_eq(cols.c[i], modulus[i]);
            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being c_lt_mark
            // indicates c[i] < N[i] (i.e. i == c_diff_idx). Since c_lt_mark is 1 or 2,
            // we have {0, 1, 2} \ {0, 3 - c_lt_mark} = {c_lt_mark}.
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(cols.lt_marker[i], AB::Expr::from_u8(3) - cols.c_lt_mark)
                .assert_eq(AB::Expr::from(modulus[i]) - cols.c[i], cols.c_lt_diff);
        }

        // Check that b_lt_diff and c_lt_diff are positive
        // (range-checked to (0, 2^LIMB_BITS) via the lookup bus; skipped on setup rows).
        self.bus
            .send_range(
                cols.b_lt_diff - AB::Expr::ONE,
                cols.c_lt_diff - AB::Expr::ONE,
            )
            .eval(builder, cols.is_valid - cols.is_setup);

        // The expected opcode depends on whether this is a setup row.
        let expected_opcode = AB::Expr::from_usize(self.offset)
            + cols.is_setup
                * AB::Expr::from_usize(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)
            + (AB::Expr::ONE - cols.is_setup)
                * AB::Expr::from_usize(Rv32ModularArithmeticOpcode::IS_EQ as usize);
        // Only limb 0 of the write carries the comparison bit; the rest are zero.
        let mut a: [AB::Expr; WRITE_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid: cols.is_valid.into(),
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}
284
/// Compact per-instruction record written during preflight execution and
/// expanded into full trace columns by `ModularIsEqualFiller::fill_trace_row`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct ModularIsEqualRecord<const READ_LIMBS: usize> {
    /// True when the executed opcode was SETUP_ISEQ.
    pub is_setup: bool,
    /// Raw limbs of operand b as read from memory.
    pub b: [u8; READ_LIMBS],
    /// Raw limbs of operand c as read from memory.
    pub c: [u8; READ_LIMBS],
}
292
/// Preflight executor for the modular IS_EQ / SETUP_ISEQ instructions.
#[derive(derive_new::new, Clone)]
pub struct ModularIsEqualExecutor<
    A,
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    // Adapter performing the register/memory reads and the result write.
    adapter: A,
    /// Global opcode offset for this chip's opcode class.
    pub offset: usize,
    /// Limbs of the modulus N (larger index = more significant limb).
    pub modulus_limbs: [u8; READ_LIMBS],
}
304
/// Trace filler that expands `ModularIsEqualRecord`s into `ModularIsEqualCoreCols`.
#[derive(derive_new::new, Clone)]
pub struct ModularIsEqualFiller<
    A,
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    // Adapter filling the adapter-specific portion of each trace row.
    adapter: A,
    /// Global opcode offset for this chip's opcode class.
    pub offset: usize,
    /// Limbs of the modulus N (larger index = more significant limb).
    pub modulus_limbs: [u8; READ_LIMBS],
    /// Shared lookup chip used to range-check the b_lt_diff/c_lt_diff witness values.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
317
318impl<F, A, RA, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
319    PreflightExecutor<F, RA> for ModularIsEqualExecutor<A, READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
320where
321    F: PrimeField32,
322    A: 'static
323        + AdapterTraceExecutor<
324            F,
325            ReadData: Into<[[u8; READ_LIMBS]; 2]>,
326            WriteData: From<[u8; WRITE_LIMBS]>,
327        >,
328    for<'buf> RA: RecordArena<
329        'buf,
330        EmptyAdapterCoreLayout<F, A>,
331        (
332            A::RecordMut<'buf>,
333            &'buf mut ModularIsEqualRecord<READ_LIMBS>,
334        ),
335    >,
336{
337    fn execute(
338        &self,
339        state: VmStateMut<F, TracingMemory, RA>,
340        instruction: &Instruction<F>,
341    ) -> Result<(), ExecutionError> {
342        let Instruction { opcode, .. } = instruction;
343
344        let local_opcode =
345            Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.offset));
346        matches!(
347            local_opcode,
348            Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ
349        );
350
351        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
352
353        A::start(*state.pc, state.memory, &mut adapter_record);
354        [core_record.b, core_record.c] = self
355            .adapter
356            .read(state.memory, instruction, &mut adapter_record)
357            .into();
358
359        core_record.is_setup = instruction.opcode.local_opcode_idx(self.offset)
360            == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize;
361
362        let mut write_data = [0u8; WRITE_LIMBS];
363        write_data[0] = (core_record.b == core_record.c) as u8;
364
365        self.adapter.write(
366            state.memory,
367            instruction,
368            write_data.into(),
369            &mut adapter_record,
370        );
371
372        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
373
374        Ok(())
375    }
376
377    fn get_opcode_name(&self, opcode: usize) -> String {
378        format!(
379            "{:?}",
380            Rv32ModularArithmeticOpcode::from_usize(opcode - self.offset)
381        )
382    }
383}
384
impl<F, A, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for ModularIsEqualFiller<A, READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Fills one trace row in place: first the adapter columns, then expands the
    /// `ModularIsEqualRecord` left in the row by the preflight executor into the
    /// full `ModularIsEqualCoreCols`, computing the less-than witness columns
    /// and issuing the range checks matching the AIR's
    /// `send_range(b_lt_diff - 1, c_lt_diff - 1)`.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        let (adapter_row, mut core_row) = row_slice.split_at_mut(A::WIDTH);
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY:
        // - row_slice is guaranteed by the caller to have at least A::WIDTH +
        //   ModularIsEqualCoreCols::width() elements
        // - caller ensures core_row contains a valid record written by the executor during trace
        //   generation
        let record: &ModularIsEqualRecord<READ_LIMBS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let cols: &mut ModularIsEqualCoreCols<F, READ_LIMBS> = core_row.borrow_mut();
        // diff_idx is the most significant limb where the operand differs from N
        // (READ_LIMBS when the operand equals N, which is only tolerated for b on
        // setup rows — see the is_setup branches below).
        let (b_cmp, b_diff_idx) =
            run_unsigned_less_than::<READ_LIMBS>(&record.b, &self.modulus_limbs);
        let (c_cmp, c_diff_idx) =
            run_unsigned_less_than::<READ_LIMBS>(&record.c, &self.modulus_limbs);

        // Both operands must be reduced modulo N; b is exempt on setup rows.
        if !record.is_setup {
            assert!(b_cmp, "{:?} >= {:?}", record.b, self.modulus_limbs);
        }
        assert!(c_cmp, "{:?} >= {:?}", record.c, self.modulus_limbs);

        // Writing in reverse order
        // NOTE(review): `record` and `cols` view the same buffer, so the write
        // order is chosen to avoid clobbering record bytes before they are read —
        // confirm against the record layout before reordering these writes.
        cols.c_lt_mark = if b_diff_idx == c_diff_idx {
            F::ONE
        } else {
            F::TWO
        };

        cols.c_lt_diff = F::from_u8(self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx]);
        if !record.is_setup {
            cols.b_lt_diff = F::from_u8(self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx]);
            // Range checks corresponding to the AIR's send_range interaction.
            self.bitwise_lookup_chip.request_range(
                (self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx] - 1) as u32,
                (self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx] - 1) as u32,
            );
        } else {
            cols.b_lt_diff = F::ZERO;
        }

        // lt_marker[i] = 1 at b_diff_idx, c_lt_mark at c_diff_idx, 0 elsewhere
        // (see the column documentation on ModularIsEqualCoreCols).
        cols.lt_marker = from_fn(|i| {
            if i == b_diff_idx {
                F::ONE
            } else if i == c_diff_idx {
                cols.c_lt_mark
            } else {
                F::ZERO
            }
        });

        cols.c = record.c.map(F::from_u8);
        cols.b = record.b.map(F::from_u8);
        // Populate eq_marker and cmp_result via the equality sub-AIR's witness generator.
        let sub_air = IsEqArraySubAir::<READ_LIMBS>;
        sub_air.generate_subrow(
            (&cols.b, &cols.c),
            (&mut cols.eq_marker, &mut cols.cmp_result),
        );

        cols.is_setup = F::from_bool(record.is_setup);
        cols.is_valid = F::ONE;
    }
}
452
impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_LIMBS: usize>
    VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_LIMBS>
{
    /// Wraps a `ModularIsEqualExecutor` built on the RV32 is-equal-mod adapter
    /// (two reads of `TOTAL_LIMBS` bytes each, done in `NUM_LANES` lanes of
    /// `LANE_SIZE` bytes).
    pub fn new(
        adapter: Rv32IsEqualModAdapterExecutor<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
        offset: usize,
        modulus_limbs: [u8; TOTAL_LIMBS],
    ) -> Self {
        Self(ModularIsEqualExecutor::new(adapter, offset, modulus_limbs))
    }
}
464
/// Instruction data precomputed once and reused on every interpreter execution
/// of the instruction (see `pre_compute_impl` / `execute_e12_impl`).
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ModularIsEqualPreCompute<const READ_LIMBS: usize> {
    /// Address of the destination register receiving the comparison result.
    a: u8,
    /// Addresses of the registers holding the pointers to operands b and c.
    rs_addrs: [u8; 2],
    /// Limbs of the modulus N, copied from the executor.
    modulus_limbs: [u8; READ_LIMBS],
}
472
473impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize>
474    VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
475{
476    fn pre_compute_impl<F: PrimeField32>(
477        &self,
478        pc: u32,
479        inst: &Instruction<F>,
480        data: &mut ModularIsEqualPreCompute<TOTAL_READ_SIZE>,
481    ) -> Result<bool, StaticProgramError> {
482        let Instruction {
483            opcode,
484            a,
485            b,
486            c,
487            d,
488            e,
489            ..
490        } = inst;
491
492        let local_opcode =
493            Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.0.offset));
494
495        // Validate instruction format
496        let a = a.as_canonical_u32();
497        let b = b.as_canonical_u32();
498        let c = c.as_canonical_u32();
499        let d = d.as_canonical_u32();
500        let e = e.as_canonical_u32();
501        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
502            return Err(StaticProgramError::InvalidInstruction(pc));
503        }
504
505        if !matches!(
506            local_opcode,
507            Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ
508        ) {
509            return Err(StaticProgramError::InvalidInstruction(pc));
510        }
511
512        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
513        *data = ModularIsEqualPreCompute {
514            a: a as u8,
515            rs_addrs,
516            modulus_limbs: self.0.modulus_limbs,
517        };
518
519        let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ;
520
521        Ok(is_setup)
522    }
523}
524
/// Selects the monomorphized handler for the setup vs. non-setup variant, so
/// the `IS_SETUP` flag is a compile-time constant inside the execution loop.
macro_rules! dispatch {
    ($execute_impl:ident, $is_setup:ident) => {
        Ok(if $is_setup {
            $execute_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true>
        } else {
            $execute_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false>
        })
    };
}
534
impl<F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize>
    InterpreterExecutor<F> for VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
where
    F: PrimeField32,
{
    /// Size in bytes of the precompute buffer for one instruction.
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        std::mem::size_of::<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>()
    }

    /// Decodes `inst` into `data` and returns the E1 execute function,
    /// monomorphized on whether the instruction is SETUP_ISEQ.
    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut ModularIsEqualPreCompute<TOTAL_READ_SIZE> = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_handler, is_setup)
    }

    /// Tail-call-optimized variant of `pre_compute`: same decoding, returns a
    /// `Handler` instead of an `ExecuteFunc`.
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut ModularIsEqualPreCompute<TOTAL_READ_SIZE> = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_handler, is_setup)
    }
}
574
// AOT execution needs no chip-specific overrides; the trait's default methods
// are used as-is.
#[cfg(feature = "aot")]
impl<F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize> AotExecutor<F>
    for VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
where
    F: PrimeField32,
{
}
582
impl<F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize>
    InterpreterMeteredExecutor<F>
    for VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
where
    F: PrimeField32,
{
    /// Size in bytes of the metered (E2) precompute buffer, which prefixes the
    /// E1 data with the chip index.
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        std::mem::size_of::<E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>>()
    }

    /// Decodes `inst` into `data` (recording `chip_idx` for height metering)
    /// and returns the E2 execute function.
    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>> =
            data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;

        dispatch!(execute_e2_handler, is_setup)
    }

    /// Tail-call-optimized variant of `metered_pre_compute`.
    #[cfg(feature = "tco")]
    fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>> =
            data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;

        dispatch!(execute_e2_handler, is_setup)
    }
}
628
// Metered AOT execution likewise uses the trait's default methods.
#[cfg(feature = "aot")]
impl<F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize>
    AotMeteredExecutor<F> for VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
where
    F: PrimeField32,
{
}
/// E1 (unmetered) execution handler: reinterprets the raw precompute buffer as
/// a `ModularIsEqualPreCompute` and delegates to `execute_e12_impl`.
///
/// # Safety
/// `pre_compute` must point to at least
/// `size_of::<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>()` bytes holding a
/// value previously written by `pre_compute_impl`.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
    const IS_SETUP: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Reborrow the raw buffer as the typed precompute struct.
    let pre_compute: &ModularIsEqualPreCompute<TOTAL_READ_SIZE> = std::slice::from_raw_parts(
        pre_compute,
        size_of::<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>(),
    )
    .borrow();

    execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>(
        pre_compute,
        exec_state,
    );
}
660
/// E2 (metered) execution handler: like `execute_e1_impl`, but the buffer is an
/// `E2PreCompute` wrapper carrying the chip index, and one row of height is
/// recorded for this chip before executing.
///
/// # Safety
/// `pre_compute` must point to at least
/// `size_of::<E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>>()` bytes
/// holding a value previously written by `metered_pre_compute`.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
    const IS_SETUP: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>> =
        std::slice::from_raw_parts(
            pre_compute,
            size_of::<E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>>(),
        )
        .borrow();
    // Each executed instruction contributes one trace row to this chip.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>(
        &pre_compute.data,
        exec_state,
    );
}
688
/// Shared E1/E2 implementation of IS_EQ / SETUP_ISEQ: reads the two operand
/// pointers from registers, loads both operands lane by lane from memory,
/// checks (debug builds only) that the operands are reduced modulo N — with b
/// exempt when `IS_SETUP` — writes the equality bit into limb 0 of the
/// destination register, and advances the PC by `DEFAULT_PC_STEP`.
///
/// # Safety
/// The register and memory addresses derived from `pre_compute` must be valid
/// for `exec_state`'s guest memory.
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
    const IS_SETUP: bool,
>(
    pre_compute: &ModularIsEqualPreCompute<TOTAL_READ_SIZE>,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Read register values (the pointers to the two operands).
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    // Read memory values, NUM_LANES reads of LANE_SIZE bytes each.
    let [b, c]: [[u8; TOTAL_READ_SIZE]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + TOTAL_READ_SIZE - 1 < (1 << POINTER_MAX_BITS));
        from_fn::<_, NUM_LANES, _>(|i| {
            exec_state.vm_read::<_, LANE_SIZE>(RV32_MEMORY_AS, address + (i * LANE_SIZE) as u32)
        })
        .concat()
        .try_into()
        .unwrap()
    });

    // Operands must be reduced; on the setup instruction b holds the modulus
    // itself, so only c is checked.
    if !IS_SETUP {
        let (b_cmp, _) = run_unsigned_less_than::<TOTAL_READ_SIZE>(&b, &pre_compute.modulus_limbs);
        debug_assert!(b_cmp, "{:?} >= {:?}", b, pre_compute.modulus_limbs);
    }

    let (c_cmp, _) = run_unsigned_less_than::<TOTAL_READ_SIZE>(&c, &pre_compute.modulus_limbs);
    debug_assert!(c_cmp, "{:?} >= {:?}", c, pre_compute.modulus_limbs);

    // Compute result
    let mut write_data = [0u8; RV32_REGISTER_NUM_LIMBS];
    write_data[0] = (b == c) as u8;

    // Write result to register
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}
735
736// Returns (cmp_result, diff_idx)
737#[inline(always)]
738pub(super) fn run_unsigned_less_than<const NUM_LIMBS: usize>(
739    x: &[u8; NUM_LIMBS],
740    y: &[u8; NUM_LIMBS],
741) -> (bool, usize) {
742    for i in (0..NUM_LIMBS).rev() {
743        if x[i] != y[i] {
744            return (x[i] < y[i], i);
745        }
746    }
747    (false, NUM_LIMBS)
748}