openvm_rv32im_circuit/mulh/execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::MulHOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

#[allow(unused_imports)]
use crate::common::*;
use crate::MulHExecutor;

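/// Operand pointers decoded from the instruction: register byte offsets for rd (`a`),
/// rs1 (`b`), and rs2 (`c`).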
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MulHPreCompute {
    a: u8,
    b: u8,
    c: u8,
}

impl<A, const LIMB_BITS: usize> MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
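    /// Decodes the instruction operands into [`MulHPreCompute`] and resolves the local
    /// MULH/MULHSU/MULHU opcode.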
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        inst: &Instruction<F>,
        data: &mut MulHPreCompute,
    ) -> Result<MulHOpcode, StaticProgramError> {
        *data = MulHPreCompute {
            a: inst.a.as_canonical_u32() as u8,
            b: inst.b.as_canonical_u32() as u8,
            c: inst.c.as_canonical_u32() as u8,
        };
        Ok(MulHOpcode::from_usize(
            inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET),
        ))
    }
}

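// Selects the execute function monomorphized on the operation type for the resolved opcode.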
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        match $local_opcode {
            MulHOpcode::MULH => Ok($execute_impl::<_, _, MulHOp>),
            MulHOpcode::MULHSU => Ok($execute_impl::<_, _, MulHSuOp>),
            MulHOpcode::MULHU => Ok($execute_impl::<_, _, MulHUOp>),
        }
    };
}

impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<MulHPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    #[inline(always)]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut MulHPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
        dispatch!(execute_e1_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut MulHPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
        inst.opcode == MulHOpcode::MULH.global_opcode()
            || inst.opcode == MulHOpcode::MULHSU.global_opcode()
            || inst.opcode == MulHOpcode::MULHU.global_opcode()
    }

    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
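        // Decode each operand field element as a sign-extended 24-bit value, narrowed to i16.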
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };

        let a = to_i16(inst.a);
        let b = to_i16(inst.b);
        let c = to_i16(inst.c);

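        // Register operands address the register file in 4-byte slots, so every pointer
        // must be a multiple of 4.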
        if a % 4 != 0 || b % 4 != 0 || c % 4 != 0 {
            return Err(AotError::InvalidInstruction);
        }

        let opcode = MulHOpcode::from_usize(inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET));

        let mut asm = String::new();

        /*
            One-operand mul/imul implicitly take the multiplicand in `eax` and always
            store the high 32 bits of the result in `edx`.
            REG_C_W can't be used here because it is `edx` and would be overwritten.
        */
        let (_, delta_str_b) = &xmm_to_gpr((b / 4) as u8, "eax", true);
        let (gpr_reg_c, delta_str_c) = &xmm_to_gpr((c / 4) as u8, REG_A_W, false);
        asm += delta_str_b;
        asm += delta_str_c;
        match opcode {
            MulHOpcode::MULH => {
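                // signed × signed: one-operand imul leaves the high 32 bits of the product in edx.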
                asm += &format!("   imul {gpr_reg_c}\n");
                asm += &gpr_to_xmm("edx", (a / 4) as u8);
            }
            MulHOpcode::MULHSU => {
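                // MULHSU needs the high word of signed(rs1) × unsigned(rs2). Compute the
                // signed × signed high word with imul, then add rs1 back in when rs2's sign
                // bit is set, since unsigned(rs2) = signed(rs2) + 2^32 in that case.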
                // edx:eax are free to clobber here, since one-operand mul/imul overwrite them anyway
                asm += &format!("   mov {REG_B_W}, eax\n");
                asm += &format!("   imul {gpr_reg_c}\n");
                asm += "   mov eax, edx\n";
                asm += &format!("   mov edx, {gpr_reg_c}\n");
                asm += "   sar edx, 31\n";
                asm += &format!("   and edx, {REG_B_W}\n");
                asm += "   add eax, edx\n";
                asm += &gpr_to_xmm("eax", (a / 4) as u8);
            }
            MulHOpcode::MULHU => {
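                // unsigned × unsigned: one-operand mul leaves the high 32 bits of the product in edx.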
                asm += &format!("   mul {gpr_reg_c}\n");
                asm += &gpr_to_xmm("edx", (a / 4) as u8);
            }
        }
        Ok(asm)
    }
}

impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<MulHPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;

        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // read [b:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        // read [c:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        // write [a:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;

        Ok(asm_str)
    }
}

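/// Shared execution body: reads rs1 and rs2 from the register address space, computes the
/// high word of their product, writes it to rd, and advances the PC by one step.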
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
    pre_compute: &MulHPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1: [u8; RV32_REGISTER_NUM_LIMBS] =
        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2: [u8; RV32_REGISTER_NUM_LIMBS] =
        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
    let rd = <OP as MulHOperation>::compute(rs1, rs2);
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

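/// E1 (unmetered) entry point: reinterprets the raw pre-compute bytes as [`MulHPreCompute`].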
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &MulHPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<MulHPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}

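/// E2 (metered) entry point: records a height change of 1 for this chip, then executes.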
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: MulHOperation>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<MulHPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<MulHPreCompute>>()).borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}

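/// Computes the upper 32 bits of the 64-bit product of two RV32 registers.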
trait MulHOperation {
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
}
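/// MULH: signed rs1 × signed rs2.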
struct MulHOp;
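/// MULHSU: signed rs1 × unsigned rs2.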
struct MulHSuOp;
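/// MULHU: unsigned rs1 × unsigned rs2.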
struct MulHUOp;
impl MulHOperation for MulHOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let rs1 = i32::from_le_bytes(rs1) as i64;
        let rs2 = i32::from_le_bytes(rs2) as i64;
        ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes()
    }
}
impl MulHOperation for MulHSuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let rs1 = i32::from_le_bytes(rs1) as i64;
        let rs2 = u32::from_le_bytes(rs2) as i64;
        ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes()
    }
}
impl MulHOperation for MulHUOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1) as i64;
        let rs2 = u32::from_le_bytes(rs2) as i64;
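        // The full u32 × u32 product can exceed i64::MAX; wrapping_mul followed by an
        // arithmetic shift still yields the correct upper 32 bits after truncation to u32.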
        ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes()
    }
}