openvm_rv32im_circuit/jal_lui/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
10};
11use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, JAL};
12use openvm_stark_backend::p3_field::PrimeField32;
13
14use super::core::{get_signed_imm, Rv32JalLuiExecutor};
15
/// Per-instruction data decoded once by `pre_compute_impl` and read back by the
/// JAL/LUI execute handlers.
///
/// The handlers reinterpret a raw byte buffer as this struct, so the field order and
/// types are part of the layout and must not be rearranged.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct JalLuiPreCompute {
    /// Immediate produced by `get_signed_imm` from operand `c`. For JAL it is added to
    /// the pc; for LUI it is shifted left by 12 to form the destination value.
    signed_imm: i32,
    /// Operand `a` truncated to a byte; used as the destination pointer in the
    /// register address space.
    a: u8,
}
22
23impl<A> Rv32JalLuiExecutor<A> {
24    /// Return (IS_JAL, ENABLED)
25    #[inline(always)]
26    fn pre_compute_impl<F: PrimeField32>(
27        &self,
28        inst: &Instruction<F>,
29        data: &mut JalLuiPreCompute,
30    ) -> Result<(bool, bool), StaticProgramError> {
31        let local_opcode = Rv32JalLuiOpcode::from_usize(
32            inst.opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET),
33        );
34        let is_jal = local_opcode == JAL;
35        let signed_imm = get_signed_imm(is_jal, inst.c);
36
37        *data = JalLuiPreCompute {
38            signed_imm,
39            a: inst.a.as_canonical_u32() as u8,
40        };
41        let enabled = !inst.f.is_zero();
42        Ok((is_jal, enabled))
43    }
44}
45
/// Picks the monomorphized handler that matches a runtime `(is_jal, enabled)` pair.
///
/// `$handler` must name a function generic over two inferred type parameters followed by
/// two `bool` const parameters, in the order `IS_JAL`, `ENABLED`. `$jal` and `$on` are
/// identifiers of `bool` values in scope at the call site. The macro expands to an
/// `Ok(..)` expression wrapping the selected function.
macro_rules! dispatch {
    ($handler:ident, $jal:ident, $on:ident) => {
        match ($jal, $on) {
            (false, false) => Ok($handler::<_, _, false, false>),
            (false, true) => Ok($handler::<_, _, false, true>),
            (true, false) => Ok($handler::<_, _, true, false>),
            (true, true) => Ok($handler::<_, _, true, true>),
        }
    };
}
56
57impl<F, A> InterpreterExecutor<F> for Rv32JalLuiExecutor<A>
58where
59    F: PrimeField32,
60{
61    #[inline(always)]
62    fn pre_compute_size(&self) -> usize {
63        size_of::<JalLuiPreCompute>()
64    }
65
66    #[cfg(not(feature = "tco"))]
67    fn pre_compute<Ctx: ExecutionCtxTrait>(
68        &self,
69        _pc: u32,
70        inst: &Instruction<F>,
71        data: &mut [u8],
72    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
73        let data: &mut JalLuiPreCompute = data.borrow_mut();
74        let (is_jal, enabled) = self.pre_compute_impl(inst, data)?;
75        dispatch!(execute_e1_handler, is_jal, enabled)
76    }
77
78    #[cfg(feature = "tco")]
79    fn handler<Ctx>(
80        &self,
81        _pc: u32,
82        inst: &Instruction<F>,
83        data: &mut [u8],
84    ) -> Result<Handler<F, Ctx>, StaticProgramError>
85    where
86        Ctx: ExecutionCtxTrait,
87    {
88        let data: &mut JalLuiPreCompute = data.borrow_mut();
89        let (is_jal, enabled) = self.pre_compute_impl(inst, data)?;
90        dispatch!(execute_e1_handler, is_jal, enabled)
91    }
92}
93
/// Ahead-of-time x86 assembly generation for JAL/LUI (only built with the `aot` feature).
#[cfg(feature = "aot")]
impl<F: PrimeField32, A> AotExecutor<F> for Rv32JalLuiExecutor<A> {
    /// Emits the assembly text for a single JAL or LUI instruction located at `pc`.
    ///
    /// * When operand `f` is non-zero, a `mov` of the destination value is emitted:
    ///   `pc + DEFAULT_PC_STEP` for JAL, `imm << 12` for LUI.
    /// * For JAL, a `jmp` to the label `asm_execute_pc_{pc + imm}` follows; LUI emits no
    ///   jump.
    ///
    /// This implementation always returns `Ok`.
    fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
        use crate::common::*;

        let opcode_idx = inst.opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET);
        let is_jal = Rv32JalLuiOpcode::from_usize(opcode_idx) == JAL;
        let signed_imm = get_signed_imm(is_jal, inst.c);
        let enabled = !inst.f.is_zero();

        // Operand `a` divided by 4 is used as the index into the register override map.
        // NOTE(review): presumably registers sit 4 bytes apart in the register address
        // space -- confirm against the transpiler.
        let a_reg = (inst.a.as_canonical_u32() as u8) / 4;

        // The value written to the destination is known at code-generation time.
        let rd = match is_jal {
            true => pc + DEFAULT_PC_STEP,
            false => (signed_imm as u32) << 12,
        };

        let mut asm_str = String::new();
        if enabled {
            match RISCV_TO_X86_OVERRIDE_MAP[a_reg as usize] {
                // The register has an override entry: move the value straight into it.
                Some(override_reg) => {
                    asm_str.push_str(&format!("   mov {override_reg}, {rd}\n"));
                }
                // No override entry: go through the scratch register and `gpr_to_xmm`.
                None => {
                    asm_str.push_str(&format!("   mov {REG_A_W}, {rd}\n"));
                    asm_str.push_str(&gpr_to_xmm(REG_A_W, a_reg));
                }
            }
        }

        if is_jal {
            let next_pc = pc as i32 + signed_imm;
            debug_assert!(next_pc >= 0);
            asm_str.push_str(&format!("   jmp asm_execute_pc_{next_pc}\n"));
        }

        Ok(asm_str)
    }

    /// Every JAL/LUI instruction is supported by the AOT path.
    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
}
141
142impl<F, A> InterpreterMeteredExecutor<F> for Rv32JalLuiExecutor<A>
143where
144    F: PrimeField32,
145{
146    fn metered_pre_compute_size(&self) -> usize {
147        size_of::<E2PreCompute<JalLuiPreCompute>>()
148    }
149
150    #[cfg(not(feature = "tco"))]
151    fn metered_pre_compute<Ctx>(
152        &self,
153        chip_idx: usize,
154        _pc: u32,
155        inst: &Instruction<F>,
156        data: &mut [u8],
157    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
158    where
159        Ctx: MeteredExecutionCtxTrait,
160    {
161        let data: &mut E2PreCompute<JalLuiPreCompute> = data.borrow_mut();
162        data.chip_idx = chip_idx as u32;
163        let (is_jal, enabled) = self.pre_compute_impl(inst, &mut data.data)?;
164        dispatch!(execute_e2_handler, is_jal, enabled)
165    }
166
167    #[cfg(feature = "tco")]
168    fn metered_handler<Ctx>(
169        &self,
170        chip_idx: usize,
171        _pc: u32,
172        inst: &Instruction<F>,
173        data: &mut [u8],
174    ) -> Result<Handler<F, Ctx>, StaticProgramError>
175    where
176        Ctx: MeteredExecutionCtxTrait,
177    {
178        let data: &mut E2PreCompute<JalLuiPreCompute> = data.borrow_mut();
179        data.chip_idx = chip_idx as u32;
180        let (is_jal, enabled) = self.pre_compute_impl(inst, &mut data.data)?;
181        dispatch!(execute_e2_handler, is_jal, enabled)
182    }
183}
184
/// Metered ahead-of-time assembly generation for JAL/LUI (only built with the `aot`
/// feature).
#[cfg(feature = "aot")]
impl<F: PrimeField32, A> AotMeteredExecutor<F> for Rv32JalLuiExecutor<A> {
    /// Every JAL/LUI instruction is supported by the metered AOT path.
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    /// Emits the metering prologue followed by the plain `generate_x86_asm` output.
    ///
    /// The prologue always contains the height-change update for `chip_idx` (by 1). The
    /// adapter-height update for the register address space is added only when operand
    /// `f` is non-zero, i.e. when the instruction performs its register write.
    ///
    /// # Errors
    /// Propagates any error returned by the helper generators or by `generate_x86_asm`.
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        use crate::common::{update_adapter_heights_asm, update_height_change_asm};

        let mut asm_str = update_height_change_asm(chip_idx, 1)?;

        let enabled = !inst.f.is_zero();
        if enabled {
            // write [a:4]_1
            asm_str.push_str(&update_adapter_heights_asm(config, RV32_REGISTER_AS)?);
        }

        asm_str.push_str(&self.generate_x86_asm(inst, pc)?);
        Ok(asm_str)
    }
}
211
212#[inline(always)]
213unsafe fn execute_e12_impl<
214    F: PrimeField32,
215    CTX: ExecutionCtxTrait,
216    const IS_JAL: bool,
217    const ENABLED: bool,
218>(
219    pre_compute: &JalLuiPreCompute,
220    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
221) {
222    let JalLuiPreCompute { a, signed_imm } = *pre_compute;
223    let mut pc = exec_state.pc();
224    let rd = if IS_JAL {
225        let rd_data = (pc + DEFAULT_PC_STEP).to_le_bytes();
226        let next_pc = pc as i32 + signed_imm;
227        debug_assert!(next_pc >= 0);
228        pc = next_pc as u32;
229        rd_data
230    } else {
231        let imm = signed_imm as u32;
232        let rd = imm << 12;
233        pc += DEFAULT_PC_STEP;
234        rd.to_le_bytes()
235    };
236
237    if ENABLED {
238        exec_state.vm_write(RV32_REGISTER_AS, a as u32, &rd);
239    }
240    exec_state.set_pc(pc);
241}
242
243#[create_handler]
244#[inline(always)]
245unsafe fn execute_e1_impl<
246    F: PrimeField32,
247    CTX: ExecutionCtxTrait,
248    const IS_JAL: bool,
249    const ENABLED: bool,
250>(
251    pre_compute: *const u8,
252    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
253) {
254    let pre_compute: &JalLuiPreCompute =
255        std::slice::from_raw_parts(pre_compute, size_of::<JalLuiPreCompute>()).borrow();
256    execute_e12_impl::<F, CTX, IS_JAL, ENABLED>(pre_compute, exec_state);
257}
258
259#[create_handler]
260#[inline(always)]
261unsafe fn execute_e2_impl<
262    F: PrimeField32,
263    CTX: MeteredExecutionCtxTrait,
264    const IS_JAL: bool,
265    const ENABLED: bool,
266>(
267    pre_compute: *const u8,
268    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
269) {
270    let pre_compute: &E2PreCompute<JalLuiPreCompute> =
271        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<JalLuiPreCompute>>())
272            .borrow();
273    exec_state
274        .ctx
275        .on_height_change(pre_compute.chip_idx as usize, 1);
276    execute_e12_impl::<F, CTX, IS_JAL, ENABLED>(&pre_compute.data, exec_state);
277}