1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32BranchEqual256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7 instruction::Instruction,
8 program::DEFAULT_PC_STEP,
9 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10 LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
13use openvm_rv32im_circuit::BranchEqualExecutor;
14use openvm_rv32im_transpiler::BranchEqualOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{common::bytes_to_u64_array, Rv32BranchEqual256Executor, INT256_NUM_LIMBS};
18
19type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;
20
21impl Rv32BranchEqual256Executor {
22 pub fn new(adapter_step: AdapterExecutor, offset: usize, pc_step: u32) -> Self {
23 Self(BranchEqualExecutor::new(adapter_step, offset, pc_step))
24 }
25}
26
27#[derive(AlignedBytesBorrow, Clone)]
28#[repr(C)]
29struct BranchEqPreCompute {
30 imm: isize,
31 a: u8,
32 b: u8,
33}
34
// Selects the monomorphized handler for an already-decoded branch opcode.
//
// `$handler` must name a function with two type parameters followed by one
// `const bool` parameter; `BEQ` instantiates it with `false` and `BNE` with
// `true`. Every arm expands to `Ok(..)`, so the macro is intended to be used
// as the tail expression of a function that returns `Result`.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            BranchEqualOpcode::BNE => Ok($handler::<_, _, true>),
            BranchEqualOpcode::BEQ => Ok($handler::<_, _, false>),
        }
    };
}
43
44impl<F: PrimeField32> Executor<F> for Rv32BranchEqual256Executor {
45 fn pre_compute_size(&self) -> usize {
46 size_of::<BranchEqPreCompute>()
47 }
48
49 #[cfg(not(feature = "tco"))]
50 fn pre_compute<Ctx>(
51 &self,
52 pc: u32,
53 inst: &Instruction<F>,
54 data: &mut [u8],
55 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
56 where
57 Ctx: ExecutionCtxTrait,
58 {
59 let data: &mut BranchEqPreCompute = data.borrow_mut();
60 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
61 dispatch!(execute_e1_handler, local_opcode)
62 }
63
64 #[cfg(feature = "tco")]
65 fn handler<Ctx>(
66 &self,
67 pc: u32,
68 inst: &Instruction<F>,
69 data: &mut [u8],
70 ) -> Result<Handler<F, Ctx>, StaticProgramError>
71 where
72 Ctx: ExecutionCtxTrait,
73 {
74 let data: &mut BranchEqPreCompute = data.borrow_mut();
75 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
76 dispatch!(execute_e1_handler, local_opcode)
77 }
78}
79
80impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchEqual256Executor {
81 fn metered_pre_compute_size(&self) -> usize {
82 size_of::<E2PreCompute<BranchEqPreCompute>>()
83 }
84
85 #[cfg(not(feature = "tco"))]
86 fn metered_pre_compute<Ctx>(
87 &self,
88 chip_idx: usize,
89 pc: u32,
90 inst: &Instruction<F>,
91 data: &mut [u8],
92 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
93 where
94 Ctx: MeteredExecutionCtxTrait,
95 {
96 let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
97 data.chip_idx = chip_idx as u32;
98 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
99 dispatch!(execute_e2_handler, local_opcode)
100 }
101
102 #[cfg(feature = "tco")]
103 fn metered_handler<Ctx>(
104 &self,
105 chip_idx: usize,
106 pc: u32,
107 inst: &Instruction<F>,
108 data: &mut [u8],
109 ) -> Result<Handler<F, Ctx>, StaticProgramError>
110 where
111 Ctx: MeteredExecutionCtxTrait,
112 {
113 let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
114 data.chip_idx = chip_idx as u32;
115 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
116 dispatch!(execute_e2_handler, local_opcode)
117 }
118}
119
120#[inline(always)]
121unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
122 pre_compute: &BranchEqPreCompute,
123 instret: &mut u64,
124 pc: &mut u32,
125 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
126) {
127 let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
128 let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
129 let rs1 =
130 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
131 let rs2 =
132 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
133 let cmp_result = u256_eq(rs1, rs2);
134 if cmp_result ^ IS_NE {
135 *pc = (*pc as isize + pre_compute.imm) as u32;
136 } else {
137 *pc = pc.wrapping_add(DEFAULT_PC_STEP);
138 }
139
140 *instret += 1;
141}
142
143#[create_handler]
144#[inline(always)]
145unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
146 pre_compute: &[u8],
147 instret: &mut u64,
148 pc: &mut u32,
149 _instret_end: u64,
150 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
151) {
152 let pre_compute: &BranchEqPreCompute = pre_compute.borrow();
153 execute_e12_impl::<F, CTX, IS_NE>(pre_compute, instret, pc, exec_state);
154}
155
156#[create_handler]
157#[inline(always)]
158unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
159 pre_compute: &[u8],
160 instret: &mut u64,
161 pc: &mut u32,
162 _arg: u64,
163 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
164) {
165 let pre_compute: &E2PreCompute<BranchEqPreCompute> = pre_compute.borrow();
166 exec_state
167 .ctx
168 .on_height_change(pre_compute.chip_idx as usize, 1);
169 execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, instret, pc, exec_state);
170}
171
172impl Rv32BranchEqual256Executor {
173 fn pre_compute_impl<F: PrimeField32>(
174 &self,
175 pc: u32,
176 inst: &Instruction<F>,
177 data: &mut BranchEqPreCompute,
178 ) -> Result<BranchEqualOpcode, StaticProgramError> {
179 let Instruction {
180 opcode,
181 a,
182 b,
183 c,
184 d,
185 e,
186 ..
187 } = inst;
188 let c = c.as_canonical_u32();
189 let imm = if F::ORDER_U32 - c < c {
190 -((F::ORDER_U32 - c) as isize)
191 } else {
192 c as isize
193 };
194 let e_u32 = e.as_canonical_u32();
195 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
196 return Err(StaticProgramError::InvalidInstruction(pc));
197 }
198 *data = BranchEqPreCompute {
199 imm,
200 a: a.as_canonical_u32() as u8,
201 b: b.as_canonical_u32() as u8,
202 };
203 let local_opcode = BranchEqualOpcode::from_usize(
204 opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET),
205 );
206 Ok(local_opcode)
207 }
208}
209
210fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
211 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
212 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
213 for i in 0..4 {
214 if rs1_u64[i] != rs2_u64[i] {
215 return false;
216 }
217 }
218 true
219}