openvm_rv32im_circuit/mulh/
execution.rs1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction,
10 program::DEFAULT_PC_STEP,
11 riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12 LocalOpcode,
13};
14use openvm_rv32im_transpiler::MulHOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::MulHExecutor;
18
19#[derive(AlignedBytesBorrow, Clone)]
20#[repr(C)]
21struct MulHPreCompute {
22 a: u8,
23 b: u8,
24 c: u8,
25}
26
27impl<A, const LIMB_BITS: usize> MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28 #[inline(always)]
29 fn pre_compute_impl<F: PrimeField32>(
30 &self,
31 inst: &Instruction<F>,
32 data: &mut MulHPreCompute,
33 ) -> Result<MulHOpcode, StaticProgramError> {
34 *data = MulHPreCompute {
35 a: inst.a.as_canonical_u32() as u8,
36 b: inst.b.as_canonical_u32() as u8,
37 c: inst.c.as_canonical_u32() as u8,
38 };
39 Ok(MulHOpcode::from_usize(
40 inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET),
41 ))
42 }
43}
44
/// Selects the monomorphized execute function for a `MulHOpcode`.
///
/// `$handler` must be a function generic over `<F, CTX, OP>`. `F` and `CTX` are
/// left to inference, and `OP` picks the `MulHOperation` type for the variant.
/// Expands to an `Ok(..)` expression; the match has no wildcard arm.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            MulHOpcode::MULH => Ok($handler::<_, _, MulHOp>),
            MulHOpcode::MULHSU => Ok($handler::<_, _, MulHSuOp>),
            MulHOpcode::MULHU => Ok($handler::<_, _, MulHUOp>),
        }
    };
}
54
55impl<F, A, const LIMB_BITS: usize> Executor<F>
56 for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
57where
58 F: PrimeField32,
59{
60 #[inline(always)]
61 fn pre_compute_size(&self) -> usize {
62 size_of::<MulHPreCompute>()
63 }
64
65 #[inline(always)]
66 fn pre_compute<Ctx: ExecutionCtxTrait>(
67 &self,
68 _pc: u32,
69 inst: &Instruction<F>,
70 data: &mut [u8],
71 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
72 let pre_compute: &mut MulHPreCompute = data.borrow_mut();
73 let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
74 dispatch!(execute_e1_impl, local_opcode)
75 }
76
77 #[cfg(feature = "tco")]
78 fn handler<Ctx>(
79 &self,
80 _pc: u32,
81 inst: &Instruction<F>,
82 data: &mut [u8],
83 ) -> Result<Handler<F, Ctx>, StaticProgramError>
84 where
85 Ctx: ExecutionCtxTrait,
86 {
87 let pre_compute: &mut MulHPreCompute = data.borrow_mut();
88 let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
89 dispatch!(execute_e1_tco_handler, local_opcode)
90 }
91}
92
93impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
94 for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
95where
96 F: PrimeField32,
97{
98 fn metered_pre_compute_size(&self) -> usize {
99 size_of::<E2PreCompute<MulHPreCompute>>()
100 }
101
102 fn metered_pre_compute<Ctx>(
103 &self,
104 chip_idx: usize,
105 _pc: u32,
106 inst: &Instruction<F>,
107 data: &mut [u8],
108 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
109 where
110 Ctx: MeteredExecutionCtxTrait,
111 {
112 let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
113 pre_compute.chip_idx = chip_idx as u32;
114 let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
115 dispatch!(execute_e2_impl, local_opcode)
116 }
117
118 #[cfg(feature = "tco")]
119 fn metered_handler<Ctx>(
120 &self,
121 chip_idx: usize,
122 _pc: u32,
123 inst: &Instruction<F>,
124 data: &mut [u8],
125 ) -> Result<Handler<F, Ctx>, StaticProgramError>
126 where
127 Ctx: MeteredExecutionCtxTrait,
128 {
129 let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
130 pre_compute.chip_idx = chip_idx as u32;
131 let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
132 dispatch!(execute_e2_tco_handler, local_opcode)
133 }
134}
135
136#[inline(always)]
137unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
138 pre_compute: &MulHPreCompute,
139 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
140) {
141 let rs1: [u8; RV32_REGISTER_NUM_LIMBS] =
142 vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
143 let rs2: [u8; RV32_REGISTER_NUM_LIMBS] =
144 vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
145 let rd = <OP as MulHOperation>::compute(rs1, rs2);
146 vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
147
148 vm_state.pc += DEFAULT_PC_STEP;
149 vm_state.instret += 1;
150}
151
152#[create_tco_handler]
153unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
154 pre_compute: &[u8],
155 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
156) {
157 let pre_compute: &MulHPreCompute = pre_compute.borrow();
158 execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
159}
160
161#[create_tco_handler]
162unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: MulHOperation>(
163 pre_compute: &[u8],
164 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
165) {
166 let pre_compute: &E2PreCompute<MulHPreCompute> = pre_compute.borrow();
167 vm_state
168 .ctx
169 .on_height_change(pre_compute.chip_idx as usize, 1);
170 execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
171}
172
/// Computes the high 32 bits of a 32 x 32 -> 64-bit product, as required by the
/// RV32M `MULH`, `MULHSU` and `MULHU` instructions.
///
/// Operands and result are 4-byte little-endian register values. The literal
/// `4` must equal `RV32_REGISTER_NUM_LIMBS`, because `execute_e12_impl` passes
/// register reads of that width straight into `compute`.
trait MulHOperation {
    /// Returns the upper half of `rs1 * rs2` under this variant's signedness.
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
}

/// `MULH`: both operands signed.
struct MulHOp;
/// `MULHSU`: `rs1` signed, `rs2` unsigned.
struct MulHSuOp;
/// `MULHU`: both operands unsigned.
struct MulHUOp;

impl MulHOperation for MulHOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        // |i32 * i32| <= 2^62, so the exact product always fits in i64.
        let rs1 = i64::from(i32::from_le_bytes(rs1));
        let rs2 = i64::from(i32::from_le_bytes(rs2));
        // The arithmetic shift keeps the sign; the cast keeps exactly bits 32..64.
        (((rs1 * rs2) >> 32) as u32).to_le_bytes()
    }
}

impl MulHOperation for MulHSuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        // The product lies in [-2^63 + 2^31, 2^63 - 2^32), so it fits in i64.
        let rs1 = i64::from(i32::from_le_bytes(rs1));
        let rs2 = i64::from(u32::from_le_bytes(rs2));
        (((rs1 * rs2) >> 32) as u32).to_le_bytes()
    }
}

impl MulHOperation for MulHUOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        // (2^32 - 1)^2 < 2^64, so u64 holds the exact product. This avoids
        // relying on i64 wrap-around plus an arithmetic shift happening to give
        // the right low 32 bits.
        let rs1 = u64::from(u32::from_le_bytes(rs1));
        let rs2 = u64::from(u32::from_le_bytes(rs2));
        (((rs1 * rs2) >> 32) as u32).to_le_bytes()
    }
}