openvm_rv32im_circuit/mulh/execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::MulHOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

#[allow(unused_imports)]
use crate::common::*;
use crate::MulHExecutor;

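/// Pre-decoded operands: `a` is the rd register address, `b` and `c` the
/// rs1/rs2 register addresses in the RV32 register address space.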
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MulHPreCompute {
    a: u8,
    b: u8,
    c: u8,
}

impl<A, const LIMB_BITS: usize> MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        inst: &Instruction<F>,
        data: &mut MulHPreCompute,
    ) -> Result<MulHOpcode, StaticProgramError> {
        *data = MulHPreCompute {
            a: inst.a.as_canonical_u32() as u8,
            b: inst.b.as_canonical_u32() as u8,
            c: inst.c.as_canonical_u32() as u8,
        };
        Ok(MulHOpcode::from_usize(
            inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET),
        ))
    }
}

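// Maps the decoded MulH opcode to a handler monomorphized over the matching
// MulHOperation implementation below.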
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        match $local_opcode {
            MulHOpcode::MULH => Ok($execute_impl::<_, _, MulHOp>),
            MulHOpcode::MULHSU => Ok($execute_impl::<_, _, MulHSuOp>),
            MulHOpcode::MULHU => Ok($execute_impl::<_, _, MulHUOp>),
        }
    };
}

impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<MulHPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    #[inline(always)]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut MulHPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
        dispatch!(execute_e1_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut MulHPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

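// AOT path: emits x86 assembly directly. Guest registers are staged between
// xmm registers and GPRs via the xmm_to_gpr / gpr_to_xmm helpers (presumably
// re-exported from `crate::common`).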
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
        inst.opcode == MulHOpcode::MULH.global_opcode()
            || inst.opcode == MulHOpcode::MULHSU.global_opcode()
            || inst.opcode == MulHOpcode::MULHU.global_opcode()
    }

    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
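        // Operands arrive as field elements; recover a small signed register
        // byte offset by sign-extending the low 24 bits.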
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };

        let a = to_i16(inst.a);
        let b = to_i16(inst.b);
        let c = to_i16(inst.c);

        if a % 4 != 0 || b % 4 != 0 || c % 4 != 0 {
            return Err(AotError::InvalidInstruction);
        }

        let opcode = MulHOpcode::from_usize(inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET));

        let mut asm = String::new();

        let (_, delta_str_b) = &xmm_to_gpr((b / 4) as u8, "eax", true);
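        // The two xmm_to_gpr calls stage rs1 into eax and rs2 into a scratch
        // GPR. The single-operand mul/imul forms below leave the 64-bit
        // product in edx:eax, so the high word we want ends up in edx.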
        let (gpr_reg_c, delta_str_c) = &xmm_to_gpr((c / 4) as u8, REG_A_W, false);
        asm += delta_str_b;
        asm += delta_str_c;
        match opcode {
            MulHOpcode::MULH => {
                asm += &format!(" imul {gpr_reg_c}\n");
                asm += &gpr_to_xmm("edx", (a / 4) as u8);
            }
            MulHOpcode::MULHSU => {
                asm += &format!(" mov {REG_B_W}, eax\n");
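                // imul treats both operands as signed. When rs2's sign bit is
                // set, the unsigned rs2 exceeds its signed reading by 2^32, so
                // the true high word is the signed high word plus rs1; the
                // saved copy of rs1 and the sar/and mask below apply exactly
                // that correction.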
                asm += &format!(" imul {gpr_reg_c}\n");
                asm += " mov eax, edx\n";
                asm += &format!(" mov edx, {gpr_reg_c}\n");
                asm += " sar edx, 31\n";
                asm += &format!(" and edx, {REG_B_W}\n");
                asm += " add eax, edx\n";
                asm += &gpr_to_xmm("eax", (a / 4) as u8);
            }
            MulHOpcode::MULHU => {
                asm += &format!(" mul {gpr_reg_c}\n");
                asm += &gpr_to_xmm("edx", (a / 4) as u8);
            }
        }
        Ok(asm)
    }
}

impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<MulHPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        _pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;

        asm_str += &update_height_change_asm(chip_idx, 1)?;
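        // MULH performs three register accesses (rs1 and rs2 reads plus the
        // rd write), so the register adapter height is bumped once per access.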
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;

        Ok(asm_str)
    }
}

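// Shared execution body: read rs1 and rs2 from the register file, compute the
// upper 32 bits of the product via OP, write rd, and advance the pc.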
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
    pre_compute: &MulHPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1: [u8; RV32_REGISTER_NUM_LIMBS] =
        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2: [u8; RV32_REGISTER_NUM_LIMBS] =
        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
    let rd = <OP as MulHOperation>::compute(rs1, rs2);
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

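// Plain (e1) handler: no metering, just reinterpret the pre-compute bytes and
// run the shared body.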
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &MulHPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<MulHPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}

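// Metered (e2) handler: additionally reports a height change of 1 for this
// chip before executing.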
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: MulHOperation>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<MulHPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<MulHPreCompute>>()).borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}

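// Each implementation returns the upper 32 bits of the 64-bit product,
// matching the RV32M MULH / MULHSU / MULHU semantics.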
trait MulHOperation {
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
}
struct MulHOp;
struct MulHSuOp;
struct MulHUOp;
impl MulHOperation for MulHOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let rs1 = i32::from_le_bytes(rs1) as i64;
        let rs2 = i32::from_le_bytes(rs2) as i64;
        ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes()
    }
}
impl MulHOperation for MulHSuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let rs1 = i32::from_le_bytes(rs1) as i64;
        let rs2 = u32::from_le_bytes(rs2) as i64;
        ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes()
    }
}
impl MulHOperation for MulHUOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1) as i64;
        let rs2 = u32::from_le_bytes(rs2) as i64;
        ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes()
    }
}
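
// A minimal sanity-check sketch (not part of the original file): the module
// name and cases below are illustrative only, exercising the three high-half
// variants against hand-computed RV32M results.
#[cfg(test)]
mod mulh_op_sketch_tests {
    use super::*;

    fn le(x: u32) -> [u8; 4] {
        x.to_le_bytes()
    }

    #[test]
    fn high_half_matches_rv32m_semantics() {
        // MULH: (-1) * 1 = -1, so the upper word is all ones.
        assert_eq!(
            <MulHOp as MulHOperation>::compute(le(0xFFFF_FFFF), le(1)),
            le(0xFFFF_FFFF)
        );
        // MULHU: (2^32 - 1)^2 = 2^64 - 2^33 + 1, upper word 0xFFFF_FFFE.
        assert_eq!(
            <MulHUOp as MulHOperation>::compute(le(0xFFFF_FFFF), le(0xFFFF_FFFF)),
            le(0xFFFF_FFFE)
        );
        // MULHSU: (-1) * (2^32 - 1) = -(2^32 - 1), upper word all ones.
        assert_eq!(
            <MulHSuOp as MulHOperation>::compute(le(0xFFFF_FFFF), le(0xFFFF_FFFF)),
            le(0xFFFF_FFFF)
        );
    }
}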