openvm_algebra_circuit/modular_chip/
is_eq.rs

1use std::{
2    array::{self, from_fn},
3    borrow::{Borrow, BorrowMut},
4};
5
6use num_bigint::BigUint;
7use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
8use openvm_circuit::{
9    arch::*,
10    system::memory::{
11        online::{GuestMemory, TracingMemory},
12        MemoryAuxColsFactory, POINTER_MAX_BITS,
13    },
14};
15use openvm_circuit_primitives::{
16    bigint::utils::big_uint_to_limbs,
17    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
18    is_equal_array::{IsEqArrayIo, IsEqArraySubAir},
19    AlignedBytesBorrow, SubAir, TraceSubRowGenerator,
20};
21use openvm_circuit_primitives_derive::AlignedBorrow;
22use openvm_instructions::{
23    instruction::Instruction,
24    program::DEFAULT_PC_STEP,
25    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
26    LocalOpcode,
27};
28use openvm_rv32_adapters::Rv32IsEqualModAdapterExecutor;
29use openvm_stark_backend::{
30    interaction::InteractionBuilder,
31    p3_air::{AirBuilder, BaseAir},
32    p3_field::{Field, FieldAlgebra, PrimeField32},
33    rap::BaseAirWithPublicValues,
34};
35
36use crate::modular_chip::VmModularIsEqualExecutor;
// Given two numbers b and c, we want to prove that (a) b == c or b != c, matching the
// boolean value of cmp_result, and (b) b, c < N for some modulus N that is supplied to
// the AIR at construction time (i.e. when the chip is instantiated).
40
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct ModularIsEqualCoreCols<T, const READ_LIMBS: usize> {
    pub is_valid: T,
    // Set only on the setup row; the AIR constrains is_setup => is_valid.
    pub is_setup: T,
    // Operands, READ_LIMBS limbs each; larger indices are more significant limbs.
    pub b: [T; READ_LIMBS],
    pub c: [T; READ_LIMBS],
    // Boolean output of the equality comparison b == c.
    pub cmp_result: T,

    // Auxiliary columns for subair EQ comparison between b and c.
    pub eq_marker: [T; READ_LIMBS],

    // Auxiliary columns to ensure both b and c are smaller than modulus N. Let b_diff_idx be
    // an index such that b[b_diff_idx] < N[b_diff_idx] and b[i] = N[i] for all i > b_diff_idx,
    // where larger indices correspond to more significant limbs. Such an index exists iff b < N.
    // Define c_diff_idx analogously. Then let b_lt_diff = N[b_diff_idx] - b[b_diff_idx] and
    // c_lt_diff = N[c_diff_idx] - c[c_diff_idx], where both must be in [0, 2^LIMB_BITS).
    //
    // To constrain the above, we will use lt_marker, which will indicate where b_diff_idx and
    // c_diff_idx are. Set lt_marker[b_diff_idx] = 1, lt_marker[c_diff_idx] = c_lt_mark, and 0
    // everywhere else. If b_diff_idx == c_diff_idx then c_lt_mark = 1, else c_lt_mark = 2. The
    // purpose of c_lt_mark is to handle the edge case where b_diff_idx == c_diff_idx (because
    // we cannot set lt_marker[b_diff_idx] to 1 and 2 at the same time).
    pub lt_marker: [T; READ_LIMBS],
    // b_lt_diff - 1 and c_lt_diff - 1 are sent over the bitwise lookup bus (see eval),
    // proving both diffs are in [1, 2^LIMB_BITS]; skipped on the setup row.
    pub b_lt_diff: T,
    pub c_lt_diff: T,
    pub c_lt_mark: T,
}
69
/// Core AIR for the modular IS_EQ / SETUP_ISEQ instruction: proves that
/// `cmp_result` matches b == c and that both operands are reduced
/// (strictly below the fixed modulus, except b == N on the setup row).
#[derive(Clone, Debug)]
pub struct ModularIsEqualCoreAir<
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    // Lookup bus used to range check b_lt_diff / c_lt_diff.
    pub bus: BitwiseOperationLookupBus,
    // SubAir enforcing the b == c comparison against cmp_result.
    pub subair: IsEqArraySubAir<READ_LIMBS>,
    // Modulus N as READ_LIMBS limbs of LIMB_BITS bits (little-endian).
    pub modulus_limbs: [u32; READ_LIMBS],
    // Global opcode offset for this instruction class.
    pub offset: usize,
}
81
82impl<const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
83    ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
84{
85    pub fn new(modulus: BigUint, bus: BitwiseOperationLookupBus, offset: usize) -> Self {
86        let mod_vec = big_uint_to_limbs(&modulus, LIMB_BITS);
87        assert!(mod_vec.len() <= READ_LIMBS);
88        let modulus_limbs = array::from_fn(|i| {
89            if i < mod_vec.len() {
90                mod_vec[i] as u32
91            } else {
92                0
93            }
94        });
95        Self {
96            bus,
97            subair: IsEqArraySubAir::<READ_LIMBS>,
98            modulus_limbs,
99            offset,
100        }
101    }
102}
103
impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
{
    // The trace width is exactly the flattened ModularIsEqualCoreCols struct.
    fn width(&self) -> usize {
        ModularIsEqualCoreCols::<F, READ_LIMBS>::width()
    }
}
// This AIR exposes no public values; the trait's default methods suffice.
impl<F: Field, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    BaseAirWithPublicValues<F> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
{
}
115
impl<AB, I, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreAir<AB, I> for ModularIsEqualCoreAir<READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; READ_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; WRITE_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Evaluates all core constraints for one row: flag booleanity, the
    /// conditional equality comparison of b and c, and the limb-wise proof
    /// that b, c < N (relaxed on the setup row, where lt_marker carries no 1
    /// and b is instead forced to equal N limb-by-limb).
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &ModularIsEqualCoreCols<_, READ_LIMBS> = local_core.borrow();

        // Flag booleanity; is_setup implies is_valid.
        builder.assert_bool(cols.is_valid);
        builder.assert_bool(cols.is_setup);
        builder.when(cols.is_setup).assert_one(cols.is_valid);
        builder.assert_bool(cols.cmp_result);

        // Constrain that either b == c or b != c, depending on the value of cmp_result.
        // The comparison is gated by (is_valid - is_setup), so it is skipped on setup rows.
        let eq_subair_io = IsEqArrayIo {
            x: cols.b.map(Into::into),
            y: cols.c.map(Into::into),
            out: cols.cmp_result.into(),
            condition: cols.is_valid - cols.is_setup,
        };
        self.subair.eval(builder, (eq_subair_io, cols.eq_marker));

        // Constrain that auxiliary columns lt_columns and c_lt_mark are as defined above.
        // When c_lt_mark is 1, lt_marker should have exactly one index i where lt_marker[i]
        // is 1, and be 0 elsewhere. When c_lt_mark is 2, lt_marker[i] should have an
        // additional index j such that lt_marker[j] is 2. To constrain this:
        //
        // * When c_lt_mark = 1 the sum of all lt_marker[i] must be 1
        // * When c_lt_mark = 2 the sum of lt_marker[i] * (lt_marker[i] - 1) must be 2.
        //   Additionally, the sum of all lt_marker[i] must be 3.
        //
        // All this doesn't apply when is_setup.
        let lt_marker_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + *x);
        // Sum of x * (x - 1): vanishes on 0/1 entries, contributes 2 for each entry equal to 2.
        let lt_marker_one_check_sum = cols
            .lt_marker
            .iter()
            .fold(AB::Expr::ZERO, |acc, x| acc + (*x) * (*x - AB::F::ONE));

        // Constrain that c_lt_mark is either 1 or 2.
        builder
            .when(cols.is_valid - cols.is_setup)
            .assert_bool(cols.c_lt_mark - AB::F::ONE);

        // If c_lt_mark is 1, then lt_marker_sum is 1
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::from_canonical_u8(2))
            .assert_one(lt_marker_sum.clone());

        // If c_lt_mark is 2, then lt_marker_sum is 3
        builder
            .when(cols.is_valid - cols.is_setup)
            .when_ne(cols.c_lt_mark, AB::F::ONE)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(3));

        // This constraint, along with the constraint (below) that lt_marker[i] is 0, 1, or 2,
        // ensures that lt_marker has exactly one 2. (RHS is multiplied by is_valid so the
        // constraint is trivially satisfied on padding rows.)
        builder.when_ne(cols.c_lt_mark, AB::F::ONE).assert_eq(
            lt_marker_one_check_sum,
            cols.is_valid * AB::F::from_canonical_u8(2),
        );

        // Handle the setup row constraints.
        // When is_setup = 1, constrain c_lt_mark = 2 and lt_marker_sum = 2
        // This ensures that lt_marker has exactly one 2 and the remaining entries are 0.
        // Since lt_marker has no 1, we will end up constraining that b[i] = N[i] for all i
        // instead of just for i > b_diff_idx.
        builder
            .when(cols.is_setup)
            .assert_eq(cols.c_lt_mark, AB::F::from_canonical_u8(2));
        builder
            .when(cols.is_setup)
            .assert_eq(lt_marker_sum.clone(), AB::F::from_canonical_u8(2));

        // Constrain that b, c < N (i.e. modulus).
        let modulus = self.modulus_limbs.map(AB::F::from_canonical_u32);
        // Running sum of lt_marker entries from the most significant limb downwards.
        let mut prefix_sum = AB::Expr::ZERO;

        for i in (0..READ_LIMBS).rev() {
            prefix_sum += cols.lt_marker[i].into();
            // lt_marker[i] must be one of {0, 1, c_lt_mark}.
            builder.assert_zero(
                cols.lt_marker[i]
                    * (cols.lt_marker[i] - AB::F::ONE)
                    * (cols.lt_marker[i] - cols.c_lt_mark),
            );

            // Constrain b < N.
            // First, we constrain b[i] = N[i] for i > b_diff_idx.
            // We do this by constraining that b[i] = N[i] when prefix_sum is not 1 or
            // lt_marker_sum.
            //  - If is_setup = 0, then lt_marker_sum is either 1 or 3. In this case, prefix_sum is
            //    0, 1, 2, or 3. It can be verified by casework that i > b_diff_idx iff prefix_sum
            //    is not 1 or lt_marker_sum.
            //  - If is_setup = 1, then we want to constrain b[i] = N[i] for all i. In this case,
            //    lt_marker_sum is 2 and prefix_sum is 0 or 2. So we constrain b[i] = N[i] when
            //    prefix_sum is not 1, which works.
            builder
                .when_ne(prefix_sum.clone(), AB::F::ONE)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone() - cols.is_setup)
                .assert_eq(cols.b[i], modulus[i]);
            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being 1 indicates b[i] <
            // N[i] (i.e. i == b_diff_idx).
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(cols.lt_marker[i], AB::F::from_canonical_u8(2))
                .assert_eq(AB::Expr::from(modulus[i]) - cols.b[i], cols.b_lt_diff);

            // Constrain c < N.
            // First, we constrain c[i] = N[i] for i > c_diff_idx.
            // We do this by constraining that c[i] = N[i] when prefix_sum is not c_lt_mark or
            // lt_marker_sum. It can be verified by casework that i > c_diff_idx iff
            // prefix_sum is not c_lt_mark or lt_marker_sum.
            builder
                .when_ne(prefix_sum.clone(), cols.c_lt_mark)
                .when_ne(prefix_sum.clone(), lt_marker_sum.clone())
                .assert_eq(cols.c[i], modulus[i]);
            // Note that lt_marker[i] is either 0, 1, or 2 and lt_marker[i] being c_lt_mark
            // indicates c[i] < N[i] (i.e. i == c_diff_idx). Since c_lt_mark is 1 or 2,
            // we have {0, 1, 2} \ {0, 3 - c_lt_mark} = {c_lt_mark}.
            builder
                .when_ne(cols.lt_marker[i], AB::F::ZERO)
                .when_ne(
                    cols.lt_marker[i],
                    AB::Expr::from_canonical_u8(3) - cols.c_lt_mark,
                )
                .assert_eq(AB::Expr::from(modulus[i]) - cols.c[i], cols.c_lt_diff);
        }

        // Check that b_lt_diff and c_lt_diff are positive (range check on diff - 1),
        // skipped on setup rows via the (is_valid - is_setup) multiplicity.
        self.bus
            .send_range(
                cols.b_lt_diff - AB::Expr::ONE,
                cols.c_lt_diff - AB::Expr::ONE,
            )
            .eval(builder, cols.is_valid - cols.is_setup);

        // Opcode is SETUP_ISEQ on setup rows and IS_EQ otherwise.
        let expected_opcode = AB::Expr::from_canonical_usize(self.offset)
            + cols.is_setup
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)
            + (AB::Expr::ONE - cols.is_setup)
                * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::IS_EQ as usize);
        // Only limb 0 of the write carries the comparison result; the rest are zero.
        let mut a: [AB::Expr; WRITE_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        a[0] = cols.cmp_result.into();

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [a].into(),
            instruction: MinimalInstruction {
                is_valid: cols.is_valid.into(),
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}
288
/// Per-instruction execution record written by the preflight executor and
/// later expanded into trace columns by `ModularIsEqualFiller`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct ModularIsEqualRecord<const READ_LIMBS: usize> {
    // True iff the executed opcode was SETUP_ISEQ.
    pub is_setup: bool,
    // Raw limbs of the two operands as read through the adapter.
    pub b: [u8; READ_LIMBS],
    pub c: [u8; READ_LIMBS],
}
296
/// Preflight executor for modular IS_EQ / SETUP_ISEQ, generic over the memory
/// adapter `A`.
#[derive(derive_new::new, Clone)]
pub struct ModularIsEqualExecutor<
    A,
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    adapter: A,
    // Global opcode offset of this instruction class.
    pub offset: usize,
    // Modulus N as READ_LIMBS byte limbs (little-endian).
    pub modulus_limbs: [u8; READ_LIMBS],
}
308
/// Trace filler that expands `ModularIsEqualRecord`s into full core rows.
#[derive(derive_new::new, Clone)]
pub struct ModularIsEqualFiller<
    A,
    const READ_LIMBS: usize,
    const WRITE_LIMBS: usize,
    const LIMB_BITS: usize,
> {
    adapter: A,
    // Global opcode offset of this instruction class.
    pub offset: usize,
    // Modulus N as READ_LIMBS byte limbs (little-endian).
    pub modulus_limbs: [u8; READ_LIMBS],
    // Lookup chip that receives the b_lt_diff / c_lt_diff range-check requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
321
322impl<F, A, RA, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize>
323    PreflightExecutor<F, RA> for ModularIsEqualExecutor<A, READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
324where
325    F: PrimeField32,
326    A: 'static
327        + AdapterTraceExecutor<
328            F,
329            ReadData: Into<[[u8; READ_LIMBS]; 2]>,
330            WriteData: From<[u8; WRITE_LIMBS]>,
331        >,
332    for<'buf> RA: RecordArena<
333        'buf,
334        EmptyAdapterCoreLayout<F, A>,
335        (
336            A::RecordMut<'buf>,
337            &'buf mut ModularIsEqualRecord<READ_LIMBS>,
338        ),
339    >,
340{
341    fn execute(
342        &self,
343        state: VmStateMut<F, TracingMemory, RA>,
344        instruction: &Instruction<F>,
345    ) -> Result<(), ExecutionError> {
346        let Instruction { opcode, .. } = instruction;
347
348        let local_opcode =
349            Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.offset));
350        matches!(
351            local_opcode,
352            Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ
353        );
354
355        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
356
357        A::start(*state.pc, state.memory, &mut adapter_record);
358        [core_record.b, core_record.c] = self
359            .adapter
360            .read(state.memory, instruction, &mut adapter_record)
361            .into();
362
363        core_record.is_setup = instruction.opcode.local_opcode_idx(self.offset)
364            == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize;
365
366        let mut write_data = [0u8; WRITE_LIMBS];
367        write_data[0] = (core_record.b == core_record.c) as u8;
368
369        self.adapter.write(
370            state.memory,
371            instruction,
372            write_data.into(),
373            &mut adapter_record,
374        );
375
376        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
377
378        Ok(())
379    }
380
381    fn get_opcode_name(&self, opcode: usize) -> String {
382        format!(
383            "{:?}",
384            Rv32ModularArithmeticOpcode::from_usize(opcode - self.offset)
385        )
386    }
387}
388
impl<F, A, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for ModularIsEqualFiller<A, READ_LIMBS, WRITE_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Expands the `ModularIsEqualRecord` stored at the front of the core row
    /// into the full `ModularIsEqualCoreCols` layout, issuing the required
    /// range-check lookups along the way.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        let (adapter_row, mut core_row) = row_slice.split_at_mut(A::WIDTH);
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY:
        // - row_slice is guaranteed by the caller to have at least A::WIDTH +
        //   ModularIsEqualCoreCols::width() elements
        // - caller ensures core_row contains a valid record written by the executor during trace
        //   generation
        let record: &ModularIsEqualRecord<READ_LIMBS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let cols: &mut ModularIsEqualCoreCols<F, READ_LIMBS> = core_row.borrow_mut();
        // diff_idx is the most significant limb where the operand differs from N;
        // it equals READ_LIMBS when the operand is exactly N (only b on setup rows).
        let (b_cmp, b_diff_idx) =
            run_unsigned_less_than::<READ_LIMBS>(&record.b, &self.modulus_limbs);
        let (c_cmp, c_diff_idx) =
            run_unsigned_less_than::<READ_LIMBS>(&record.c, &self.modulus_limbs);

        // Operands must be reduced: b < N (except on setup rows) and c < N.
        // These asserts also guarantee the limb subtractions below cannot underflow.
        if !record.is_setup {
            assert!(b_cmp, "{:?} >= {:?}", record.b, self.modulus_limbs);
        }
        assert!(c_cmp, "{:?} >= {:?}", record.c, self.modulus_limbs);

        // Writing in reverse order: `record` was borrowed from the same buffer as
        // `cols` (see get_record_from_slice above), so columns are filled from the
        // back of the struct first; record.b / record.c are only consumed at the end.
        cols.c_lt_mark = if b_diff_idx == c_diff_idx {
            F::ONE
        } else {
            F::TWO
        };

        cols.c_lt_diff =
            F::from_canonical_u8(self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx]);
        if !record.is_setup {
            cols.b_lt_diff =
                F::from_canonical_u8(self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx]);
            // Range check diff - 1 for both operands (proves each diff is nonzero).
            self.bitwise_lookup_chip.request_range(
                (self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx] - 1) as u32,
                (self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx] - 1) as u32,
            );
        } else {
            // On setup rows b == N, so there is no b diff and the lookup is skipped.
            cols.b_lt_diff = F::ZERO;
        }

        // lt_marker: 1 at b_diff_idx, c_lt_mark at c_diff_idx, 0 elsewhere.
        // On setup rows b_diff_idx == READ_LIMBS, so no index receives a 1.
        cols.lt_marker = from_fn(|i| {
            if i == b_diff_idx {
                F::ONE
            } else if i == c_diff_idx {
                cols.c_lt_mark
            } else {
                F::ZERO
            }
        });

        cols.c = record.c.map(F::from_canonical_u8);
        cols.b = record.b.map(F::from_canonical_u8);
        let sub_air = IsEqArraySubAir::<READ_LIMBS>;
        sub_air.generate_subrow(
            (&cols.b, &cols.c),
            (&mut cols.eq_marker, &mut cols.cmp_result),
        );

        cols.is_setup = F::from_bool(record.is_setup);
        cols.is_valid = F::ONE;
    }
}
458
impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_LIMBS: usize>
    VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_LIMBS>
{
    /// Wraps a `ModularIsEqualExecutor` over the 2-operand RV32 is-equal-mod
    /// adapter, writing RV32_REGISTER_NUM_LIMBS result limbs.
    pub fn new(
        adapter: Rv32IsEqualModAdapterExecutor<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>,
        offset: usize,
        modulus_limbs: [u8; TOTAL_LIMBS],
    ) -> Self {
        Self(ModularIsEqualExecutor::new(adapter, offset, modulus_limbs))
    }
}
470
/// Pre-decoded instruction data used by the fast (E1/E2) execution paths.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ModularIsEqualPreCompute<const READ_LIMBS: usize> {
    // Destination address in the register address space.
    a: u8,
    // Addresses of the two registers holding the operand pointers.
    rs_addrs: [u8; 2],
    // Modulus N as little-endian byte limbs.
    modulus_limbs: [u8; READ_LIMBS],
}
478
479impl<const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize>
480    VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
481{
482    fn pre_compute_impl<F: PrimeField32>(
483        &self,
484        pc: u32,
485        inst: &Instruction<F>,
486        data: &mut ModularIsEqualPreCompute<TOTAL_READ_SIZE>,
487    ) -> Result<bool, StaticProgramError> {
488        let Instruction {
489            opcode,
490            a,
491            b,
492            c,
493            d,
494            e,
495            ..
496        } = inst;
497
498        let local_opcode =
499            Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.0.offset));
500
501        // Validate instruction format
502        let a = a.as_canonical_u32();
503        let b = b.as_canonical_u32();
504        let c = c.as_canonical_u32();
505        let d = d.as_canonical_u32();
506        let e = e.as_canonical_u32();
507        if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS {
508            return Err(StaticProgramError::InvalidInstruction(pc));
509        }
510
511        if !matches!(
512            local_opcode,
513            Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ
514        ) {
515            return Err(StaticProgramError::InvalidInstruction(pc));
516        }
517
518        let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8);
519        *data = ModularIsEqualPreCompute {
520            a: a as u8,
521            rs_addrs,
522            modulus_limbs: self.0.modulus_limbs,
523        };
524
525        let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ;
526
527        Ok(is_setup)
528    }
529}
530
// Selects the monomorphization of the given execute function according to
// whether this instruction is the setup variant (the IS_SETUP const generic).
macro_rules! dispatch {
    ($execute_impl:ident, $is_setup:ident) => {
        Ok(if $is_setup {
            $execute_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true>
        } else {
            $execute_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false>
        })
    };
}
540
impl<F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize> Executor<F>
    for VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        std::mem::size_of::<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>()
    }

    /// Decodes `inst` into the raw `data` buffer and returns the matching
    /// monomorphized E1 execute function.
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut ModularIsEqualPreCompute<TOTAL_READ_SIZE> = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_impl, is_setup)
    }

    /// Same as `pre_compute`, but returns the tail-call-optimized handler.
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut ModularIsEqualPreCompute<TOTAL_READ_SIZE> = data.borrow_mut();
        let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;

        dispatch!(execute_e1_tco_handler, is_setup)
    }
}
579
impl<F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize>
    MeteredExecutor<F> for VmModularIsEqualExecutor<NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        std::mem::size_of::<E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>>>()
    }

    /// Like `Executor::pre_compute`, but also stores `chip_idx` so the E2 path
    /// can report trace-height changes for metering.
    fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>> =
            data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;

        dispatch!(execute_e2_impl, is_setup)
    }

    /// Same as `metered_pre_compute`, but returns the tail-call-optimized handler.
    #[cfg(feature = "tco")]
    fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>> =
            data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;

        let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;

        dispatch!(execute_e2_tco_handler, is_setup)
    }
}
623
/// E1 (unmetered) execution wrapper: reinterprets the raw pre-compute bytes
/// and forwards to the shared implementation.
///
/// # Safety
/// `pre_compute` must contain a valid `ModularIsEqualPreCompute<TOTAL_READ_SIZE>`
/// previously written by `pre_compute_impl`, and the safety requirements of
/// `execute_e12_impl` must hold.
#[create_tco_handler]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
    const IS_SETUP: bool,
>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ModularIsEqualPreCompute<TOTAL_READ_SIZE> = pre_compute.borrow();

    execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>(
        pre_compute,
        vm_state,
    );
}
643
/// E2 (metered) execution wrapper: reports a height change of 1 row for this
/// chip, then forwards to the shared implementation.
///
/// # Safety
/// `pre_compute` must contain a valid `E2PreCompute<ModularIsEqualPreCompute<..>>`
/// previously written by `metered_pre_compute`, and the safety requirements of
/// `execute_e12_impl` must hold.
#[create_tco_handler]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
    const IS_SETUP: bool,
>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ModularIsEqualPreCompute<TOTAL_READ_SIZE>> =
        pre_compute.borrow();
    // Each executed instruction contributes exactly one trace row to this chip.
    vm_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>(
        &pre_compute.data,
        vm_state,
    );
}
666
/// Shared E1/E2 body: reads the two operand pointers from registers, loads both
/// TOTAL_READ_SIZE-byte operands from guest memory in LANE_SIZE chunks, writes
/// the equality flag into limb 0 of register `a`, and advances pc / instret.
///
/// # Safety
/// `pre_compute` must have been produced by `pre_compute_impl` for the current
/// instruction; guest-memory reads/writes through `vm_read`/`vm_write` are
/// performed at the recorded addresses without further validation here.
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const NUM_LANES: usize,
    const LANE_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
    const IS_SETUP: bool,
>(
    pre_compute: &ModularIsEqualPreCompute<TOTAL_READ_SIZE>,
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Read register values (little-endian pointers into the memory AS)
    let rs_vals = pre_compute
        .rs_addrs
        .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32)));

    // Read memory values: NUM_LANES reads of LANE_SIZE bytes each, concatenated
    // into one TOTAL_READ_SIZE operand per pointer.
    let [b, c]: [[u8; TOTAL_READ_SIZE]; 2] = rs_vals.map(|address| {
        debug_assert!(address as usize + TOTAL_READ_SIZE - 1 < (1 << POINTER_MAX_BITS));
        from_fn::<_, NUM_LANES, _>(|i| {
            vm_state.vm_read::<_, LANE_SIZE>(RV32_MEMORY_AS, address + (i * LANE_SIZE) as u32)
        })
        .concat()
        .try_into()
        .unwrap()
    });

    // Operands must be reduced (b only when not the setup instruction).
    if !IS_SETUP {
        let (b_cmp, _) = run_unsigned_less_than::<TOTAL_READ_SIZE>(&b, &pre_compute.modulus_limbs);
        debug_assert!(b_cmp, "{:?} >= {:?}", b, pre_compute.modulus_limbs);
    }

    let (c_cmp, _) = run_unsigned_less_than::<TOTAL_READ_SIZE>(&c, &pre_compute.modulus_limbs);
    debug_assert!(c_cmp, "{:?} >= {:?}", c, pre_compute.modulus_limbs);

    // Compute result: only limb 0 carries the equality flag.
    let mut write_data = [0u8; RV32_REGISTER_NUM_LIMBS];
    write_data[0] = (b == c) as u8;

    // Write result to register
    vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data);

    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
    vm_state.instret += 1;
}
712
713// Returns (cmp_result, diff_idx)
714#[inline(always)]
715pub(super) fn run_unsigned_less_than<const NUM_LIMBS: usize>(
716    x: &[u8; NUM_LIMBS],
717    y: &[u8; NUM_LIMBS],
718) -> (bool, usize) {
719    for i in (0..NUM_LIMBS).rev() {
720        if x[i] != y[i] {
721            return (x[i] < y[i], i);
722        }
723    }
724    (false, NUM_LIMBS)
725}