openvm_rv32im_circuit/branch_eq/execution.rs

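//! Execution of the RV32 BEQ/BNE instructions: plain interpreter (E1), metered
//! interpreter (E2), and optional x86-64 AOT code generation.
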
use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
};
use openvm_rv32im_transpiler::BranchEqualOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::BranchEqualExecutor;
#[cfg(feature = "aot")]
use crate::common::{
    update_adapter_heights_asm, update_height_change_asm, xmm_to_gpr, REG_A_W, REG_B_W,
};

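/// Pre-decoded operands for a BEQ/BNE instruction: the branch immediate as a signed
/// offset and the byte offsets of the two source registers.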
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchEqualPreCompute {
    imm: isize,
    a: u8,
    b: u8,
}

impl<A, const NUM_LIMBS: usize> BranchEqualExecutor<A, NUM_LIMBS> {
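    /// Decodes `inst` into `data`, converting the field-element immediate into a signed
    /// branch offset. Returns `true` if the opcode is BNE, `false` for BEQ.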
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BranchEqualPreCompute,
    ) -> Result<bool, StaticProgramError> {
        let data: &mut BranchEqualPreCompute = data.borrow_mut();
        let &Instruction {
            opcode, a, b, c, d, ..
        } = inst;
        let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let c = c.as_canonical_u32();
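        // The immediate is stored as a field element; encodings closer to the modulus
        // than to zero represent negative offsets.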
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        if d.as_canonical_u32() != RV32_REGISTER_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = BranchEqualPreCompute {
            imm,
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok(local_opcode == BranchEqualOpcode::BNE)
    }
}

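// Picks the BEQ (`IS_NE = false`) or BNE (`IS_NE = true`) monomorphization of a handler.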
macro_rules! dispatch {
    ($execute_impl:ident, $is_bne:ident) => {
        if $is_bne {
            Ok($execute_impl::<_, _, true>)
        } else {
            Ok($execute_impl::<_, _, false>)
        }
    };
}

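// E1 (plain interpreter) execution: decode the instruction once into the pre-compute
// buffer, then return the matching const-generic handler.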
impl<F, A, const NUM_LIMBS: usize> InterpreterExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchEqualPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    #[inline(always)]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut BranchEqualPreCompute = data.borrow_mut();
        let is_bne = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, is_bne)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchEqualPreCompute = data.borrow_mut();
        let is_bne = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, is_bne)
    }
}

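// AOT execution: emit x86-64 assembly that compares the two guest registers and either
// jumps to the label for the branch target pc or falls through to the next instruction.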
#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize> AotExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
        let &Instruction {
            opcode, a, b, c, d, ..
        } = inst;
        let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let c = c.as_canonical_u32();
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        let next_pc = (pc as isize + imm) as u32;
        if d.as_canonical_u32() != RV32_REGISTER_AS {
            return Err(AotError::InvalidInstruction);
        }
        let a = a.as_canonical_u32() as u8;
        let b = b.as_canonical_u32() as u8;

        let mut asm_str = String::new();
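        // a and b are byte offsets into the guest register file; RV32 registers are 4
        // bytes wide, so divide by 4 to get register indices.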
        let a_reg = a / 4;
        let b_reg = b / 4;

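        // Move both guest register values out of their xmm slots into scratch GPRs, then
        // compare them once; the resulting flags drive the BEQ or BNE jump below.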
        let (reg_a, delta_str_a) = &xmm_to_gpr(a_reg, REG_A_W, false);
        asm_str += delta_str_a;
        let (reg_b, delta_str_b) = &xmm_to_gpr(b_reg, REG_B_W, false);
        asm_str += delta_str_b;
        asm_str += &format!(" cmp {reg_a}, {reg_b}\n");
        let not_jump_label = format!(".asm_execute_pc_{pc}_not_jump");
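        // BEQ skips the jump when the operands differ; BNE skips it when they are equal.
        // The taken path jumps directly to the label for next_pc (pc + imm).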
        match local_opcode {
            BranchEqualOpcode::BEQ => {
                asm_str += &format!(" jne {not_jump_label}\n");
                asm_str += &format!(" jmp asm_execute_pc_{next_pc}\n");
            }
            BranchEqualOpcode::BNE => {
                asm_str += &format!(" je {not_jump_label}\n");
                asm_str += &format!(" jmp asm_execute_pc_{next_pc}\n");
            }
        }
        asm_str += &format!("{not_jump_label}:\n");

        Ok(asm_str)
    }

    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
}

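// E2 (metered) execution: identical decoding, but the pre-compute also records the chip
// index so the handler can report trace-height changes during execution.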
impl<F, A, const NUM_LIMBS: usize> InterpreterMeteredExecutor<F>
    for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchEqualPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, is_bne)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, is_bne)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize> AotMeteredExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = String::from("");

        asm_str += &update_height_change_asm(chip_idx, 1)?;
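        // The branch reads rs1 and rs2 from the register address space, so the register
        // adapter heights are bumped once per read.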
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;

        asm_str += &self.generate_x86_asm(inst, pc)?;
        Ok(asm_str)
    }
}

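// Shared branch logic for both execution modes: read rs1 and rs2 as 4-byte register
// values, compare them, and either add the signed immediate to the pc or advance it by
// DEFAULT_PC_STEP.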
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &BranchEqualPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let mut pc = exec_state.pc();
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs2 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    if (rs1 == rs2) ^ IS_NE {
        pc = (pc as isize + pre_compute.imm) as u32;
    } else {
        pc = pc.wrapping_add(DEFAULT_PC_STEP);
    }
    exec_state.set_pc(pc);
}

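// E1 handler: reinterpret the raw pre-compute bytes as `BranchEqualPreCompute` and run
// the shared implementation.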
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchEqualPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<BranchEqualPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, exec_state);
}

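// E2 handler: same as E1, but first reports a height change of 1 for this chip.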
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchEqualPreCompute> = std::slice::from_raw_parts(
        pre_compute,
        size_of::<E2PreCompute<BranchEqualPreCompute>>(),
    )
    .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, exec_state);
}