1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32BranchEqual256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7 instruction::Instruction,
8 program::DEFAULT_PC_STEP,
9 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10 LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
13use openvm_rv32im_circuit::BranchEqualExecutor;
14use openvm_rv32im_transpiler::BranchEqualOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{common::bytes_to_u64_array, Rv32BranchEqual256Executor, INT256_NUM_LIMBS};
18
/// Adapter type accepted by `Rv32BranchEqual256Executor::new`.
///
/// NOTE(review): the const parameters `2` and `INT256_NUM_LIMBS` presumably mean two heap
/// reads of `INT256_NUM_LIMBS` bytes each — confirm against `Rv32HeapBranchAdapterExecutor`.
type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;
20
21impl Rv32BranchEqual256Executor {
22 pub fn new(adapter_step: AdapterExecutor, offset: usize, pc_step: u32) -> Self {
23 Self(BranchEqualExecutor::new(adapter_step, offset, pc_step))
24 }
25}
26
/// Decoded operands of a BEQ256/BNE256 instruction, filled once per instruction by
/// `pre_compute_impl` and later read back by `execute_e12_impl`.
///
/// The struct is borrowed directly out of the executor's `&mut [u8]` pre-compute buffer
/// (via `AlignedBytesBorrow`). `#[repr(C)]` fixes declaration order as the byte layout,
/// so do not reorder fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchEqPreCompute {
    /// Signed PC offset applied when the branch is taken (decoded from operand `c`).
    imm: isize,
    /// Address, in the register address space, of the register holding the pointer to the
    /// first 256-bit operand (operand `a`, truncated to `u8`).
    a: u8,
    /// Address, in the register address space, of the register holding the pointer to the
    /// second 256-bit operand (operand `b`, truncated to `u8`).
    b: u8,
}
34
/// Expands to `Ok(handler)`, where `handler` is `$handler` monomorphised for the decoded
/// `BranchEqualOpcode` in `$opcode`: `BEQ` selects `IS_NE = false`, `BNE` selects
/// `IS_NE = true`. The remaining generic parameters are inferred from the caller's
/// expected return type.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            BranchEqualOpcode::BEQ => Ok($handler::<_, _, false>),
            BranchEqualOpcode::BNE => Ok($handler::<_, _, true>),
        }
    };
}
43
44impl<F: PrimeField32> Executor<F> for Rv32BranchEqual256Executor {
45 fn pre_compute_size(&self) -> usize {
46 size_of::<BranchEqPreCompute>()
47 }
48
49 fn pre_compute<Ctx>(
50 &self,
51 pc: u32,
52 inst: &Instruction<F>,
53 data: &mut [u8],
54 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
55 where
56 Ctx: ExecutionCtxTrait,
57 {
58 let data: &mut BranchEqPreCompute = data.borrow_mut();
59 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
60 dispatch!(execute_e1_impl, local_opcode)
61 }
62
63 #[cfg(feature = "tco")]
64 fn handler<Ctx>(
65 &self,
66 pc: u32,
67 inst: &Instruction<F>,
68 data: &mut [u8],
69 ) -> Result<Handler<F, Ctx>, StaticProgramError>
70 where
71 Ctx: ExecutionCtxTrait,
72 {
73 let data: &mut BranchEqPreCompute = data.borrow_mut();
74 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
75 dispatch!(execute_e1_tco_handler, local_opcode)
76 }
77}
78
79impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchEqual256Executor {
80 fn metered_pre_compute_size(&self) -> usize {
81 size_of::<E2PreCompute<BranchEqPreCompute>>()
82 }
83
84 fn metered_pre_compute<Ctx>(
85 &self,
86 chip_idx: usize,
87 pc: u32,
88 inst: &Instruction<F>,
89 data: &mut [u8],
90 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
91 where
92 Ctx: MeteredExecutionCtxTrait,
93 {
94 let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
95 data.chip_idx = chip_idx as u32;
96 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
97 dispatch!(execute_e2_impl, local_opcode)
98 }
99
100 #[cfg(feature = "tco")]
101 fn metered_handler<Ctx>(
102 &self,
103 chip_idx: usize,
104 pc: u32,
105 inst: &Instruction<F>,
106 data: &mut [u8],
107 ) -> Result<Handler<F, Ctx>, StaticProgramError>
108 where
109 Ctx: MeteredExecutionCtxTrait,
110 {
111 let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
112 data.chip_idx = chip_idx as u32;
113 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
114 dispatch!(execute_e2_tco_handler, local_opcode)
115 }
116}
117
118#[inline(always)]
119unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
120 pre_compute: &BranchEqPreCompute,
121 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
122) {
123 let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
124 let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
125 let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
126 let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
127 let cmp_result = u256_eq(rs1, rs2);
128 if cmp_result ^ IS_NE {
129 vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
130 } else {
131 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
132 }
133
134 vm_state.instret += 1;
135}
136
// Plain (E1) execution entry point for BEQ256/BNE256.
//
// `pre_compute` is the raw byte record that `Executor::pre_compute` / `handler` filled
// through `pre_compute_impl`. It is reinterpreted as a `BranchEqPreCompute`, and the shared
// `execute_e12_impl` does the actual work.
// NOTE(review): `#[create_tco_handler]` presumably generates the `execute_e1_tco_handler`
// referenced under the "tco" feature. Keep this signature stable for the macro.
#[create_tco_handler]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchEqPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, vm_state);
}
145
// Metered (E2) execution entry point for BEQ256/BNE256.
//
// `pre_compute` is the raw byte record that `MeteredExecutor::metered_pre_compute` /
// `metered_handler` filled. It is reinterpreted as an `E2PreCompute<BranchEqPreCompute>`.
// The call reports a height change of 1 for `chip_idx` to the metering context, then runs
// the shared `execute_e12_impl` on the inner record.
// NOTE(review): `#[create_tco_handler]` presumably generates the `execute_e2_tco_handler`
// referenced under the "tco" feature. Keep this signature stable for the macro.
#[create_tco_handler]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchEqPreCompute> = pre_compute.borrow();
    vm_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, vm_state);
}
157
158impl Rv32BranchEqual256Executor {
159 fn pre_compute_impl<F: PrimeField32>(
160 &self,
161 pc: u32,
162 inst: &Instruction<F>,
163 data: &mut BranchEqPreCompute,
164 ) -> Result<BranchEqualOpcode, StaticProgramError> {
165 let Instruction {
166 opcode,
167 a,
168 b,
169 c,
170 d,
171 e,
172 ..
173 } = inst;
174 let c = c.as_canonical_u32();
175 let imm = if F::ORDER_U32 - c < c {
176 -((F::ORDER_U32 - c) as isize)
177 } else {
178 c as isize
179 };
180 let e_u32 = e.as_canonical_u32();
181 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
182 return Err(StaticProgramError::InvalidInstruction(pc));
183 }
184 *data = BranchEqPreCompute {
185 imm,
186 a: a.as_canonical_u32() as u8,
187 b: b.as_canonical_u32() as u8,
188 };
189 let local_opcode = BranchEqualOpcode::from_usize(
190 opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET),
191 );
192 Ok(local_opcode)
193 }
194}
195
196fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
197 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
198 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
199 for i in 0..4 {
200 if rs1_u64[i] != rs2_u64[i] {
201 return false;
202 }
203 }
204 true
205}