openvm_rv32im_circuit/branch_eq/execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
};
use openvm_rv32im_transpiler::BranchEqualOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::BranchEqualExecutor;
#[cfg(feature = "aot")]
use crate::common::{
    update_adapter_heights_asm, update_height_change_asm, xmm_to_gpr, REG_A_W, REG_B_W,
};

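/// Pre-decoded operands for a BEQ/BNE instruction: the sign-decoded branch
/// offset and the byte offsets of the two register operands.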
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchEqualPreCompute {
    imm: isize,
    a: u8,
    b: u8,
}

impl<A, const NUM_LIMBS: usize> BranchEqualExecutor<A, NUM_LIMBS> {
    /// Decodes the instruction into `data` and returns `is_bne`: `true` iff the
    /// local opcode is `BNE`.
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BranchEqualPreCompute,
    ) -> Result<bool, StaticProgramError> {
        let &Instruction {
            opcode, a, b, c, d, ..
        } = inst;
        let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let c = c.as_canonical_u32();
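        // The immediate is encoded as a field element; values in the upper half
        // of the field represent negative branch offsets.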
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        if d.as_canonical_u32() != RV32_REGISTER_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = BranchEqualPreCompute {
            imm,
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok(local_opcode == BranchEqualOpcode::BNE)
    }
}

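// Selects the handler monomorphized over the `IS_NE` const generic, so the
// branch polarity is fixed at pre-compute time rather than checked per step.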
macro_rules! dispatch {
    ($execute_impl:ident, $is_bne:ident) => {
        if $is_bne {
            Ok($execute_impl::<_, _, true>)
        } else {
            Ok($execute_impl::<_, _, false>)
        }
    };
}

impl<F, A, const NUM_LIMBS: usize> InterpreterExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchEqualPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    #[inline(always)]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut BranchEqualPreCompute = data.borrow_mut();
        let is_bne = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, is_bne)
    }

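    // TCO variant: same pre-compute, but returns a tail-call `Handler` instead
    // of a plain `ExecuteFunc`.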
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchEqualPreCompute = data.borrow_mut();
        let is_bne = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, is_bne)
    }
}

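// AOT path: compile the branch to x86-64 assembly ahead of time.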
#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize> AotExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
        let &Instruction {
            opcode, a, b, c, d, ..
        } = inst;
        let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let c = c.as_canonical_u32();
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        let next_pc = (pc as isize + imm) as u32;
        if d.as_canonical_u32() != RV32_REGISTER_AS {
            return Err(AotError::InvalidInstruction);
        }
        let a = a.as_canonical_u32() as u8;
        let b = b.as_canonical_u32() as u8;

        let mut asm_str = String::new();
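        // Operands are byte offsets into the register file; RV32 registers are
        // 4 bytes wide, so divide by 4 to get the register index.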
        let a_reg = a / 4;
        let b_reg = b / 4;

        // Move both operands from their xmm slots into GPRs and compare them.
        let (reg_a, delta_str_a) = &xmm_to_gpr(a_reg, REG_A_W, false);
        asm_str += delta_str_a;
        let (reg_b, delta_str_b) = &xmm_to_gpr(b_reg, REG_B_W, false);
        asm_str += delta_str_b;
        asm_str += &format!("   cmp {reg_a}, {reg_b}\n");
        let not_jump_label = format!(".asm_execute_pc_{pc}_not_jump");
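        // If the branch condition fails, skip the taken-branch jump and fall
        // through past `not_jump_label` into the next sequential instruction.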
        match local_opcode {
            BranchEqualOpcode::BEQ => {
                asm_str += &format!("   jne {not_jump_label}\n");
                asm_str += &format!("   jmp asm_execute_pc_{next_pc}\n");
            }
            BranchEqualOpcode::BNE => {
                asm_str += &format!("   je {not_jump_label}\n");
                asm_str += &format!("   jmp asm_execute_pc_{next_pc}\n");
            }
        }
        asm_str += &format!("{not_jump_label}:\n");

        Ok(asm_str)
    }

    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
}

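// Metered interpreter path: identical execution, but each step also records
// one row of trace height for this chip.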
impl<F, A, const NUM_LIMBS: usize> InterpreterMeteredExecutor<F>
    for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchEqualPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, is_bne)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, is_bne)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize> AotMeteredExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = String::new();

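        // Record one row of height for this chip.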
        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // read [a:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        // read [b:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;

        asm_str += &self.generate_x86_asm(inst, pc)?;
        Ok(asm_str)
    }
}

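/// Shared body of the e1 and e2 handlers: reads both registers, evaluates the
/// branch condition, and updates the pc.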
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &BranchEqualPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let mut pc = exec_state.pc();
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs2 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
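    // Branch is taken when equality matches the selected polarity:
    // rs1 == rs2 for BEQ, rs1 != rs2 for BNE.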
    if (rs1 == rs2) ^ IS_NE {
        pc = (pc as isize + pre_compute.imm) as u32;
    } else {
        pc = pc.wrapping_add(DEFAULT_PC_STEP);
    }
    exec_state.set_pc(pc);
}

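// e1 handler: pure execution; reinterprets the raw pre-compute bytes as the
// typed struct written by `pre_compute_impl`.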
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchEqualPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<BranchEqualPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, exec_state);
}

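// e2 handler: metered execution; like e1, but first records one row of height
// for this chip.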
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchEqualPreCompute> = std::slice::from_raw_parts(
        pre_compute,
        size_of::<E2PreCompute<BranchEqualPreCompute>>(),
    )
    .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, exec_state);
}