openvm_rv32im_circuit/mul/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12    LocalOpcode,
13};
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17#[cfg(feature = "aot")]
18use crate::common::*;
19use crate::MultiplicationExecutor;
20
/// Pre-computed operands of a `MUL` instruction, filled in by `pre_compute_impl`.
///
/// Each field is the matching instruction operand (`inst.a`, `inst.b`, `inst.c`) narrowed
/// to a `u8`. The execute functions use them as pointers into the register address space.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultiPreCompute {
    /// Pointer of the register that receives the product.
    a: u8,
    /// Pointer of the register read as the first factor.
    b: u8,
    /// Pointer of the register read as the second factor.
    c: u8,
}
28
29impl<A, const LIMB_BITS: usize> MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
30    fn pre_compute_impl<F: PrimeField32>(
31        &self,
32        pc: u32,
33        inst: &Instruction<F>,
34        data: &mut MultiPreCompute,
35    ) -> Result<(), StaticProgramError> {
36        assert_eq!(
37            MulOpcode::from_usize(inst.opcode.local_opcode_idx(self.offset)),
38            MulOpcode::MUL
39        );
40        if inst.d.as_canonical_u32() != RV32_REGISTER_AS {
41            return Err(StaticProgramError::InvalidInstruction(pc));
42        }
43
44        *data = MultiPreCompute {
45            a: inst.a.as_canonical_u32() as u8,
46            b: inst.b.as_canonical_u32() as u8,
47            c: inst.c.as_canonical_u32() as u8,
48        };
49        Ok(())
50    }
51}
52
53impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
54    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
55where
56    F: PrimeField32,
57{
58    fn pre_compute_size(&self) -> usize {
59        size_of::<MultiPreCompute>()
60    }
61    #[cfg(not(feature = "tco"))]
62    fn pre_compute<Ctx>(
63        &self,
64        pc: u32,
65        inst: &Instruction<F>,
66        data: &mut [u8],
67    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
68    where
69        Ctx: ExecutionCtxTrait,
70    {
71        let pre_compute: &mut MultiPreCompute = data.borrow_mut();
72        self.pre_compute_impl(pc, inst, pre_compute)?;
73        Ok(execute_e1_impl)
74    }
75
76    #[cfg(feature = "tco")]
77    fn handler<Ctx>(
78        &self,
79        pc: u32,
80        inst: &Instruction<F>,
81        data: &mut [u8],
82    ) -> Result<Handler<F, Ctx>, StaticProgramError>
83    where
84        Ctx: ExecutionCtxTrait,
85    {
86        let pre_compute: &mut MultiPreCompute = data.borrow_mut();
87        self.pre_compute_impl(pc, inst, pre_compute)?;
88        Ok(execute_e1_handler)
89    }
90}
91
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    /// Reports whether `inst` can be compiled ahead of time; only `MUL` can.
    fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
        inst.opcode == MulOpcode::MUL.global_opcode()
    }

    /// Emits x86 assembly that computes `a = b * c` on the low 32 bits of the registers
    /// named by the operands of `inst`.
    ///
    /// # Errors
    /// Returns [`AotError::InvalidInstruction`] when an operand is not a multiple of
    /// `RV32_REGISTER_NUM_LIMBS`, or when the register index it denotes does not fit in a
    /// `u8` or lies outside `RISCV_TO_X86_OVERRIDE_MAP`.
    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
        // Turn a register pointer operand into a register index. Everything is checked
        // before narrowing, so a malformed operand is rejected instead of wrapping onto a
        // valid register or indexing the override map out of bounds.
        let reg_index = |operand: F| -> Result<u8, AotError> {
            let ptr = operand.as_canonical_u32() as usize;
            let index = ptr / RV32_REGISTER_NUM_LIMBS;
            let in_range =
                index < RISCV_TO_X86_OVERRIDE_MAP.len() && index <= usize::from(u8::MAX);
            if ptr % RV32_REGISTER_NUM_LIMBS == 0 && in_range {
                Ok(index as u8) // lossless: bounded by `u8::MAX` just above
            } else {
                Err(AotError::InvalidInstruction)
            }
        };
        let a = reg_index(inst.a)?;
        let b = reg_index(inst.b)?;
        let c = reg_index(inst.c)?;

        // Work directly in the override register of `a` when the map provides one,
        // otherwise in the scratch register.
        let str_reg_a = RISCV_TO_X86_OVERRIDE_MAP[usize::from(a)].unwrap_or(REG_A_W);

        let mut asm_str = String::new();

        if a == c {
            // a = b * c; commutative, so don't need to write to tmp, but should copy c to a first
            let (gpr_reg_c, delta_str_c) = xmm_to_gpr(c, str_reg_a, true);
            asm_str += &delta_str_c;
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr(b, REG_C_W, false);
            asm_str += &delta_str_b;
            asm_str += &format!("   imul {gpr_reg_c}, {gpr_reg_b}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_c, a);
        } else {
            // Load `b` into the working register of `a` first, then multiply by `c`.
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr(b, str_reg_a, true);
            asm_str += &delta_str_b;
            // The returned register name must be used: the helper decides where `c` lives.
            let (gpr_reg_c, delta_str_c) = xmm_to_gpr(c, REG_C_W, false);
            asm_str += &delta_str_c;
            asm_str += &format!("   imul {gpr_reg_b}, {gpr_reg_c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, a);
        }

        Ok(asm_str)
    }
}
144
145impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
146    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
147where
148    F: PrimeField32,
149{
150    fn metered_pre_compute_size(&self) -> usize {
151        size_of::<E2PreCompute<MultiPreCompute>>()
152    }
153
154    #[cfg(not(feature = "tco"))]
155    fn metered_pre_compute<Ctx>(
156        &self,
157        chip_idx: usize,
158        pc: u32,
159        inst: &Instruction<F>,
160        data: &mut [u8],
161    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
162    where
163        Ctx: MeteredExecutionCtxTrait,
164    {
165        let pre_compute: &mut E2PreCompute<MultiPreCompute> = data.borrow_mut();
166        pre_compute.chip_idx = chip_idx as u32;
167        self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
168        Ok(execute_e2_impl)
169    }
170
171    #[cfg(feature = "tco")]
172    fn metered_handler<Ctx>(
173        &self,
174        chip_idx: usize,
175        pc: u32,
176        inst: &Instruction<F>,
177        data: &mut [u8],
178    ) -> Result<Handler<F, Ctx>, StaticProgramError>
179    where
180        Ctx: MeteredExecutionCtxTrait,
181    {
182        let pre_compute: &mut E2PreCompute<MultiPreCompute> = data.borrow_mut();
183        pre_compute.chip_idx = chip_idx as u32;
184        self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
185        Ok(execute_e2_handler)
186    }
187}
188
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    /// Reports whether `inst` can be compiled ahead of time with metering.
    ///
    /// The metered assembly is the unmetered assembly plus bookkeeping, so the two
    /// variants accept exactly the same instructions.
    fn is_aot_metered_supported(&self, inst: &Instruction<F>) -> bool {
        self.is_aot_supported(inst)
    }

    /// Emits the unmetered `MUL` assembly followed by the metering updates: one height
    /// change for `chip_idx` and one adapter-height update per register access.
    ///
    /// # Errors
    /// Propagates any error from `generate_x86_asm` or from the metering helpers.
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;

        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // read [b:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        // read [c:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        // write [a:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;

        Ok(asm_str)
    }
}
218#[inline(always)]
219unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
220    pre_compute: &MultiPreCompute,
221    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
222) {
223    let rs1: [u8; RV32_REGISTER_NUM_LIMBS] =
224        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
225    let rs2: [u8; RV32_REGISTER_NUM_LIMBS] =
226        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
227    let rs1 = u32::from_le_bytes(rs1);
228    let rs2 = u32::from_le_bytes(rs2);
229    let rd = rs1.wrapping_mul(rs2);
230    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd.to_le_bytes());
231
232    let pc = exec_state.pc();
233    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
234}
235
/// Unmetered execute function for `MUL`: views the raw pre-compute bytes as a
/// [`MultiPreCompute`] and runs the shared multiplication body.
///
/// # Safety
/// `pre_compute` must be valid for reads of `size_of::<MultiPreCompute>()` bytes and hold
/// the operands of the instruction being executed.
/// NOTE(review): presumably this is the buffer filled by `pre_compute`/`handler`, and the
/// derived `borrow` may add an alignment requirement; verify against the derive macro.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Rebuild a byte slice of exactly the struct's size, then view it as the struct.
    let pre_compute: &MultiPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<MultiPreCompute>()).borrow();
    execute_e12_impl(pre_compute, exec_state);
}
246
/// Metered execute function for `MUL`: reports a height change of 1 for the recorded chip
/// index, then runs the shared multiplication body.
///
/// # Safety
/// `pre_compute` must be valid for reads of `size_of::<E2PreCompute<MultiPreCompute>>()`
/// bytes and hold the chip index and operands of the instruction being executed.
/// NOTE(review): presumably this is the buffer filled by `metered_pre_compute`/
/// `metered_handler`, and the derived `borrow` may add an alignment requirement; verify
/// against the derive macro.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Rebuild a byte slice of exactly the wrapper's size, then view it as the wrapper.
    let pre_compute: &E2PreCompute<MultiPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<MultiPreCompute>>())
            .borrow();
    // The metering context is notified before the instruction itself runs.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl(&pre_compute.data, exec_state);
}