use std::borrow::{Borrow, BorrowMut};

use openvm_bigint_transpiler::Rv32BranchEqual256Opcode;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
use openvm_rv32im_circuit::BranchEqualExecutor;
use openvm_rv32im_transpiler::BranchEqualOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use crate::{common::bytes_to_u64_array, Rv32BranchEqual256Executor, INT256_NUM_LIMBS};

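/// Heap-branch adapter over two pointer registers, each dereferenced into an
/// `INT256_NUM_LIMBS`-byte operand in the memory address space.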
type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;

impl Rv32BranchEqual256Executor {
    pub fn new(adapter_step: AdapterExecutor, offset: usize, pc_step: u32) -> Self {
        Self(BranchEqualExecutor::new(adapter_step, offset, pc_step))
    }
}

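/// Instruction data decoded ahead of execution: the sign-extended branch offset and
/// the byte offsets of the two pointer registers in the register address space.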
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchEqPreCompute {
    imm: isize,
    a: u8,
    b: u8,
}

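// Selects the handler monomorphized over the `IS_NE` const generic:
// BEQ takes the branch when the operands are equal, BNE when they differ.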
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        match $local_opcode {
            BranchEqualOpcode::BEQ => Ok($execute_impl::<_, _, false>),
            BranchEqualOpcode::BNE => Ok($execute_impl::<_, _, true>),
        }
    };
}

impl<F: PrimeField32> InterpreterExecutor<F> for Rv32BranchEqual256Executor {
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchEqPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchEqPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchEqPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotExecutor<F> for Rv32BranchEqual256Executor {}

impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32BranchEqual256Executor {
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchEqPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotMeteredExecutor<F> for Rv32BranchEqual256Executor {}

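/// Shared core of the E1 and E2 (metered) execution paths: loads the two pointer
/// registers, dereferences each into a 256-bit operand from memory, compares the
/// operands, and either jumps by the precomputed immediate or falls through by
/// `DEFAULT_PC_STEP`.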
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &BranchEqPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let mut pc = exec_state.pc();
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let cmp_result = u256_eq(rs1, rs2);
    if cmp_result ^ IS_NE {
        pc = (pc as isize + pre_compute.imm) as u32;
    } else {
        pc = pc.wrapping_add(DEFAULT_PC_STEP);
    }
    exec_state.set_pc(pc);
}

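/// E1 (unmetered) entry point: reinterprets the raw pre-compute bytes as
/// `BranchEqPreCompute` and runs the shared implementation.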
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchEqPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<BranchEqPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, exec_state);
}

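/// E2 (metered) entry point: records a height increment of 1 for this chip,
/// then runs the shared implementation.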
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchEqPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<BranchEqPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, exec_state);
}

impl Rv32BranchEqual256Executor {
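    /// Decodes the instruction into [`BranchEqPreCompute`]: maps the field-element
    /// immediate to a signed offset (e.g. `c = p - 8` decodes to `imm = -8`), checks
    /// that the address spaces are register/memory as expected, and returns the
    /// local BEQ/BNE opcode.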
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BranchEqPreCompute,
    ) -> Result<BranchEqualOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let c = c.as_canonical_u32();
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = BranchEqPreCompute {
            imm,
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        let local_opcode = BranchEqualOpcode::from_usize(
            opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET),
        );
        Ok(local_opcode)
    }
}

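/// Equality of two 256-bit values, compared limb-wise as four `u64` words.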
fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
    let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
    let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
    for i in 0..4 {
        if rs1_u64[i] != rs2_u64[i] {
            return false;
        }
    }
    true
}