openvm_rv32im_circuit/jalr/execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::{DEFAULT_PC_STEP, PC_BITS},
    riscv::RV32_REGISTER_AS,
};
use openvm_stark_backend::p3_field::PrimeField32;

use super::core::Rv32JalrExecutor;
#[cfg(feature = "aot")]
use crate::common::*;
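/// Operands decoded once per JALR instruction: the sign-extended immediate and the
/// register-file byte offsets of `rd` (`a`) and `rs1` (`b`).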
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct JalrPreCompute {
    imm_extended: u32,
    a: u8,
    b: u8,
}

impl<A> Rv32JalrExecutor<A> {
    /// Fills `data` from the instruction. Returns `true` if the write to `rd` is enabled
    /// (i.e. `f` is nonzero).
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut JalrPreCompute,
    ) -> Result<bool, StaticProgramError> {
        // `g` is 1 when the immediate is negative, so this sign-extends the 16-bit immediate.
        let imm_extended = inst.c.as_canonical_u32() + inst.g.as_canonical_u32() * 0xffff0000;
        if inst.d.as_canonical_u32() != RV32_REGISTER_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = JalrPreCompute {
            imm_extended,
            a: inst.a.as_canonical_u32() as u8,
            b: inst.b.as_canonical_u32() as u8,
        };
        let enabled = !inst.f.is_zero();
        Ok(enabled)
    }
}
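// Selects the monomorphized handler for whether the `rd` write is enabled.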
macro_rules! dispatch {
    ($execute_impl:ident, $enabled:ident) => {
        if $enabled {
            Ok($execute_impl::<_, _, true>)
        } else {
            Ok($execute_impl::<_, _, false>)
        }
    };
}
impl<F, A> InterpreterExecutor<F> for Rv32JalrExecutor<A>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<JalrPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    #[inline(always)]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut JalrPreCompute = data.borrow_mut();
        let enabled = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, enabled)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut JalrPreCompute = data.borrow_mut();
        let enabled = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, enabled)
    }
}
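// Ahead-of-time compilation: emits x86-64 assembly that performs the JALR directly.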
#[cfg(feature = "aot")]
impl<F, A> AotExecutor<F> for Rv32JalrExecutor<A>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
        let mut asm_str = String::new();
        // Decode a register operand as a signed byte offset (sign-extended from 24 bits).
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let a = to_i16(inst.a);
        let b = to_i16(inst.b);
        // Register operands must be 4-byte aligned offsets into the register file.
        if a % 4 != 0 || b % 4 != 0 {
            return Err(AotError::InvalidInstruction);
        }
        let imm_extended = inst.c.as_canonical_u32() + inst.g.as_canonical_u32() * 0xffff0000;
        let write_rd = !inst.f.is_zero();

        // Load rs1 into a GPR, add the immediate, and clear bit 0 to form the target pc.
        let (gpr_reg_b, delta_b) = xmm_to_gpr((b / 4) as u8, REG_B_W, true);
        asm_str += &delta_b;
        asm_str += &format!("   add {gpr_reg_b}, {imm_extended}\n");
        asm_str += &format!("   and {gpr_reg_b}, -2\n"); // clear bit 0 per RISC-V jalr

        let gpr_reg_b_64 = convert_x86_reg(&gpr_reg_b, Width::W64).unwrap();

        // If rd is written, it receives the return address (the next pc).
        if write_rd {
            let next_pc = pc.wrapping_add(DEFAULT_PC_STEP);
            asm_str += &format!("   mov {REG_A_W}, {next_pc}\n");
            asm_str += &gpr_to_xmm(REG_A_W, (a / 4) as u8);
        }

        // Jump indirectly through the pc map: load the entry for the target guest pc and
        // jump to the corresponding host code.
        asm_str += &format!("   lea {REG_C}, [rip + map_pc_base]\n");
        asm_str += &format!("   movsxd {REG_A}, [{REG_C} + {gpr_reg_b_64}]\n");
        asm_str += &format!("   add {REG_A}, {REG_C}\n");
        asm_str += &format!("   jmp {REG_A}\n");
        Ok(asm_str)
    }
}
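// Metered (E2) interpretation: same decode as E1, plus the chip index so execution can
// record trace-height changes.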
impl<F, A> InterpreterMeteredExecutor<F> for Rv32JalrExecutor<A>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<JalrPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<JalrPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let enabled = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, enabled)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<JalrPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let enabled = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, enabled)
    }
}
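// Metered AOT: prepend trace-height accounting for the register read (and the optional
// `rd` write) to the generated JALR assembly.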
#[cfg(feature = "aot")]
impl<F, A> AotMeteredExecutor<F> for Rv32JalrExecutor<A>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let enabled = !inst.f.is_zero();
        let mut asm_str = update_height_change_asm(chip_idx, 1)?;
        // read [b:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        if enabled {
            // write [a:4]_1
            asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }
        asm_str += &self.generate_x86_asm(inst, pc)?;
        Ok(asm_str)
    }
}
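/// Shared JALR semantics for the E1/E2 paths: jump to `rs1 + imm` with bit 0 cleared and,
/// when enabled, write the return address (the next pc) to `rd`.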
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const ENABLED: bool>(
    pre_compute: &JalrPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pc = exec_state.pc();
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs1 = u32::from_le_bytes(rs1);
    let to_pc = rs1.wrapping_add(pre_compute.imm_extended);
    // Clear bit 0 of the target address, per the RISC-V jalr specification.
    let to_pc = to_pc - (to_pc & 1);
    debug_assert!(to_pc < (1 << PC_BITS));
    // The return address written to rd is the next pc (pc + DEFAULT_PC_STEP).
    let rd = (pc + DEFAULT_PC_STEP).to_le_bytes();

    if ENABLED {
        exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
    }

    exec_state.set_pc(to_pc);
}
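// E1 handler: reinterpret the raw pre-compute bytes and run the shared implementation.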
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const ENABLED: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &JalrPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<JalrPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, ENABLED>(pre_compute, exec_state);
}
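// E2 handler: additionally records a height increment of 1 for this chip before executing.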
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const ENABLED: bool>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<JalrPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<JalrPreCompute>>()).borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, ENABLED>(&pre_compute.data, exec_state);
}