openvm_bigint_circuit/
branch_eq.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32BranchEqual256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7    instruction::Instruction,
8    program::DEFAULT_PC_STEP,
9    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10    LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
13use openvm_rv32im_circuit::BranchEqualExecutor;
14use openvm_rv32im_transpiler::BranchEqualOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{common::bytes_to_u64_array, Rv32BranchEqual256Executor, INT256_NUM_LIMBS};
18
19type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;
20
21impl Rv32BranchEqual256Executor {
22    pub fn new(adapter_step: AdapterExecutor, offset: usize, pc_step: u32) -> Self {
23        Self(BranchEqualExecutor::new(adapter_step, offset, pc_step))
24    }
25}
26
27#[derive(AlignedBytesBorrow, Clone)]
28#[repr(C)]
29struct BranchEqPreCompute {
30    imm: isize,
31    a: u8,
32    b: u8,
33}
34
/// Expands to `Ok(handler)` where `handler` is `$handler` instantiated with
/// `IS_NE = false` for `BEQ` and `IS_NE = true` for `BNE`.
///
/// Both arguments must be identifiers: the handler function name and a local
/// variable holding the decoded [`BranchEqualOpcode`].
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            BranchEqualOpcode::BNE => Ok($handler::<_, _, true>),
            BranchEqualOpcode::BEQ => Ok($handler::<_, _, false>),
        }
    };
}
43
44impl<F: PrimeField32> InterpreterExecutor<F> for Rv32BranchEqual256Executor {
45    fn pre_compute_size(&self) -> usize {
46        size_of::<BranchEqPreCompute>()
47    }
48
49    #[cfg(not(feature = "tco"))]
50    fn pre_compute<Ctx>(
51        &self,
52        pc: u32,
53        inst: &Instruction<F>,
54        data: &mut [u8],
55    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
56    where
57        Ctx: ExecutionCtxTrait,
58    {
59        let data: &mut BranchEqPreCompute = data.borrow_mut();
60        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
61        dispatch!(execute_e1_handler, local_opcode)
62    }
63
64    #[cfg(feature = "tco")]
65    fn handler<Ctx>(
66        &self,
67        pc: u32,
68        inst: &Instruction<F>,
69        data: &mut [u8],
70    ) -> Result<Handler<F, Ctx>, StaticProgramError>
71    where
72        Ctx: ExecutionCtxTrait,
73    {
74        let data: &mut BranchEqPreCompute = data.borrow_mut();
75        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
76        dispatch!(execute_e1_handler, local_opcode)
77    }
78}
79
// Marks the executor as usable through `AotExecutor` when the `aot` feature is on.
// The impl body is empty, so every trait item comes from the trait's defaults.
#[cfg(feature = "aot")]
impl<F> AotExecutor<F> for Rv32BranchEqual256Executor where F: PrimeField32 {}
82
83impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32BranchEqual256Executor {
84    fn metered_pre_compute_size(&self) -> usize {
85        size_of::<E2PreCompute<BranchEqPreCompute>>()
86    }
87
88    #[cfg(not(feature = "tco"))]
89    fn metered_pre_compute<Ctx>(
90        &self,
91        chip_idx: usize,
92        pc: u32,
93        inst: &Instruction<F>,
94        data: &mut [u8],
95    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
96    where
97        Ctx: MeteredExecutionCtxTrait,
98    {
99        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
100        data.chip_idx = chip_idx as u32;
101        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
102        dispatch!(execute_e2_handler, local_opcode)
103    }
104
105    #[cfg(feature = "tco")]
106    fn metered_handler<Ctx>(
107        &self,
108        chip_idx: usize,
109        pc: u32,
110        inst: &Instruction<F>,
111        data: &mut [u8],
112    ) -> Result<Handler<F, Ctx>, StaticProgramError>
113    where
114        Ctx: MeteredExecutionCtxTrait,
115    {
116        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
117        data.chip_idx = chip_idx as u32;
118        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
119        dispatch!(execute_e2_handler, local_opcode)
120    }
121}
122
// Marks the executor as usable through `AotMeteredExecutor` when the `aot` feature
// is on. The impl body is empty, so every trait item comes from the trait's defaults.
#[cfg(feature = "aot")]
impl<F> AotMeteredExecutor<F> for Rv32BranchEqual256Executor where F: PrimeField32 {}
125
126#[inline(always)]
127unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
128    pre_compute: &BranchEqPreCompute,
129    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
130) {
131    let mut pc = exec_state.pc();
132    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
133    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
134    let rs1 =
135        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
136    let rs2 =
137        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
138    let cmp_result = u256_eq(rs1, rs2);
139    if cmp_result ^ IS_NE {
140        pc = (pc as isize + pre_compute.imm) as u32;
141    } else {
142        pc = pc.wrapping_add(DEFAULT_PC_STEP);
143    }
144    exec_state.set_pc(pc);
145}
146
147#[create_handler]
148#[inline(always)]
149unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
150    pre_compute: *const u8,
151    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
152) {
153    let pre_compute: &BranchEqPreCompute =
154        std::slice::from_raw_parts(pre_compute, size_of::<BranchEqPreCompute>()).borrow();
155    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, exec_state);
156}
157
158#[create_handler]
159#[inline(always)]
160unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
161    pre_compute: *const u8,
162    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
163) {
164    let pre_compute: &E2PreCompute<BranchEqPreCompute> =
165        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<BranchEqPreCompute>>())
166            .borrow();
167    exec_state
168        .ctx
169        .on_height_change(pre_compute.chip_idx as usize, 1);
170    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, exec_state);
171}
172
173impl Rv32BranchEqual256Executor {
174    fn pre_compute_impl<F: PrimeField32>(
175        &self,
176        pc: u32,
177        inst: &Instruction<F>,
178        data: &mut BranchEqPreCompute,
179    ) -> Result<BranchEqualOpcode, StaticProgramError> {
180        let Instruction {
181            opcode,
182            a,
183            b,
184            c,
185            d,
186            e,
187            ..
188        } = inst;
189        let c = c.as_canonical_u32();
190        let imm = if F::ORDER_U32 - c < c {
191            -((F::ORDER_U32 - c) as isize)
192        } else {
193            c as isize
194        };
195        let e_u32 = e.as_canonical_u32();
196        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
197            return Err(StaticProgramError::InvalidInstruction(pc));
198        }
199        *data = BranchEqPreCompute {
200            imm,
201            a: a.as_canonical_u32() as u8,
202            b: b.as_canonical_u32() as u8,
203        };
204        let local_opcode = BranchEqualOpcode::from_usize(
205            opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET),
206        );
207        Ok(local_opcode)
208    }
209}
210
211fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
212    let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
213    let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
214    for i in 0..4 {
215        if rs1_u64[i] != rs2_u64[i] {
216            return false;
217        }
218    }
219    true
220}