openvm_rv32im_circuit/mulh/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12    LocalOpcode,
13};
14use openvm_rv32im_transpiler::MulHOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::MulHExecutor;
18
19#[derive(AlignedBytesBorrow, Clone)]
20#[repr(C)]
21struct MulHPreCompute {
22    a: u8,
23    b: u8,
24    c: u8,
25}
26
27impl<A, const LIMB_BITS: usize> MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28    #[inline(always)]
29    fn pre_compute_impl<F: PrimeField32>(
30        &self,
31        inst: &Instruction<F>,
32        data: &mut MulHPreCompute,
33    ) -> Result<MulHOpcode, StaticProgramError> {
34        *data = MulHPreCompute {
35            a: inst.a.as_canonical_u32() as u8,
36            b: inst.b.as_canonical_u32() as u8,
37            c: inst.c.as_canonical_u32() as u8,
38        };
39        Ok(MulHOpcode::from_usize(
40            inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET),
41        ))
42    }
43}
44
/// Selects the instantiation of a generic handler that matches a decoded `MulHOpcode`.
///
/// * `$handler` — identifier of a function generic over three parameters; the first two
///   are left to inference (`_`) and the third is set to the operation marker type.
/// * `$opcode` — identifier of a local holding the `MulHOpcode` to branch on.
///
/// Expands to a `match` whose every arm evaluates to `Ok(<handler instantiation>)`.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            MulHOpcode::MULH => Ok($handler::<_, _, MulHOp>),
            MulHOpcode::MULHSU => Ok($handler::<_, _, MulHSuOp>),
            MulHOpcode::MULHU => Ok($handler::<_, _, MulHUOp>),
        }
    };
}
54
55impl<F, A, const LIMB_BITS: usize> Executor<F>
56    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
57where
58    F: PrimeField32,
59{
60    #[inline(always)]
61    fn pre_compute_size(&self) -> usize {
62        size_of::<MulHPreCompute>()
63    }
64
65    #[cfg(not(feature = "tco"))]
66    #[inline(always)]
67    fn pre_compute<Ctx: ExecutionCtxTrait>(
68        &self,
69        _pc: u32,
70        inst: &Instruction<F>,
71        data: &mut [u8],
72    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
73        let pre_compute: &mut MulHPreCompute = data.borrow_mut();
74        let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
75        dispatch!(execute_e1_handler, local_opcode)
76    }
77
78    #[cfg(feature = "tco")]
79    fn handler<Ctx>(
80        &self,
81        _pc: u32,
82        inst: &Instruction<F>,
83        data: &mut [u8],
84    ) -> Result<Handler<F, Ctx>, StaticProgramError>
85    where
86        Ctx: ExecutionCtxTrait,
87    {
88        let pre_compute: &mut MulHPreCompute = data.borrow_mut();
89        let local_opcode = self.pre_compute_impl(inst, pre_compute)?;
90        dispatch!(execute_e1_handler, local_opcode)
91    }
92}
93
94impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
95    for MulHExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
96where
97    F: PrimeField32,
98{
99    fn metered_pre_compute_size(&self) -> usize {
100        size_of::<E2PreCompute<MulHPreCompute>>()
101    }
102
103    #[cfg(not(feature = "tco"))]
104    fn metered_pre_compute<Ctx>(
105        &self,
106        chip_idx: usize,
107        _pc: u32,
108        inst: &Instruction<F>,
109        data: &mut [u8],
110    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
111    where
112        Ctx: MeteredExecutionCtxTrait,
113    {
114        let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
115        pre_compute.chip_idx = chip_idx as u32;
116        let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
117        dispatch!(execute_e2_handler, local_opcode)
118    }
119
120    #[cfg(feature = "tco")]
121    fn metered_handler<Ctx>(
122        &self,
123        chip_idx: usize,
124        _pc: u32,
125        inst: &Instruction<F>,
126        data: &mut [u8],
127    ) -> Result<Handler<F, Ctx>, StaticProgramError>
128    where
129        Ctx: MeteredExecutionCtxTrait,
130    {
131        let pre_compute: &mut E2PreCompute<MulHPreCompute> = data.borrow_mut();
132        pre_compute.chip_idx = chip_idx as u32;
133        let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?;
134        dispatch!(execute_e2_handler, local_opcode)
135    }
136}
137
138#[inline(always)]
139unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
140    pre_compute: &MulHPreCompute,
141    instret: &mut u64,
142    pc: &mut u32,
143    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
144) {
145    let rs1: [u8; RV32_REGISTER_NUM_LIMBS] =
146        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
147    let rs2: [u8; RV32_REGISTER_NUM_LIMBS] =
148        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
149    let rd = <OP as MulHOperation>::compute(rs1, rs2);
150    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
151
152    *pc += DEFAULT_PC_STEP;
153    *instret += 1;
154}
155
156#[create_handler]
157#[inline(always)]
158unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: MulHOperation>(
159    pre_compute: &[u8],
160    instret: &mut u64,
161    pc: &mut u32,
162    _instret_end: u64,
163    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
164) {
165    let pre_compute: &MulHPreCompute = pre_compute.borrow();
166    execute_e12_impl::<F, CTX, OP>(pre_compute, instret, pc, exec_state);
167}
168
169#[create_handler]
170#[inline(always)]
171unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: MulHOperation>(
172    pre_compute: &[u8],
173    instret: &mut u64,
174    pc: &mut u32,
175    _arg: u64,
176    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
177) {
178    let pre_compute: &E2PreCompute<MulHPreCompute> = pre_compute.borrow();
179    exec_state
180        .ctx
181        .on_height_change(pre_compute.chip_idx as usize, 1);
182    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, instret, pc, exec_state);
183}
184
/// Computes the upper 32 bits of the 64-bit product of two 32-bit operands.
///
/// Operands and result are little-endian byte arrays. Implementors differ only in
/// whether each operand is interpreted as signed or unsigned.
trait MulHOperation {
    /// Returns bits 32..64 of `rs1 * rs2` as little-endian bytes.
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
}

/// Marker for signed × signed.
struct MulHOp;
/// Marker for signed × unsigned.
struct MulHSuOp;
/// Marker for unsigned × unsigned.
struct MulHUOp;

impl MulHOperation for MulHOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let lhs = i64::from(i32::from_le_bytes(rs1));
        let rhs = i64::from(i32::from_le_bytes(rs2));
        // |lhs * rhs| <= 2^62, so the product always fits in an i64.
        let product = lhs * rhs;
        // Arithmetic shift, then keep the low 32 bits of what remains.
        let high = (product >> 32) as u32;
        high.to_le_bytes()
    }
}

impl MulHOperation for MulHSuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let lhs = i64::from(i32::from_le_bytes(rs1));
        let rhs = i64::from(u32::from_le_bytes(rs2));
        // |lhs * rhs| < 2^63, so the product always fits in an i64.
        let product = lhs * rhs;
        let high = (product >> 32) as u32;
        high.to_le_bytes()
    }
}

impl MulHOperation for MulHUOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let lhs = u64::from(u32::from_le_bytes(rs1));
        let rhs = u64::from(u32::from_le_bytes(rs2));
        // (2^32 - 1)^2 < 2^64, so the product always fits in a u64.
        let product = lhs * rhs;
        let high = (product >> 32) as u32;
        high.to_le_bytes()
    }
}