// openvm_rv32im_circuit/mul/execution.rs
use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction,
10 program::DEFAULT_PC_STEP,
11 riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12 LocalOpcode,
13};
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17#[cfg(feature = "aot")]
18use crate::common::*;
19use crate::MultiplicationExecutor;
20
/// Pre-decoded operand data for a MUL instruction, stored in the
/// interpreter's raw pre-compute buffer and borrowed back as bytes.
///
/// `a`, `b`, `c` are the instruction operands truncated to `u8`; at runtime
/// they are used as addresses into the RV32 register address space
/// (`b`/`c` are read, `a` is written — see `execute_e12_impl`).
///
/// NOTE: `repr(C)` + `AlignedBytesBorrow` make the field order part of the
/// serialized layout — do not reorder or resize fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultiPreCompute {
    a: u8,
    b: u8,
    c: u8,
}
28
29impl<A, const LIMB_BITS: usize> MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
30 fn pre_compute_impl<F: PrimeField32>(
31 &self,
32 pc: u32,
33 inst: &Instruction<F>,
34 data: &mut MultiPreCompute,
35 ) -> Result<(), StaticProgramError> {
36 assert_eq!(
37 MulOpcode::from_usize(inst.opcode.local_opcode_idx(self.offset)),
38 MulOpcode::MUL
39 );
40 if inst.d.as_canonical_u32() != RV32_REGISTER_AS {
41 return Err(StaticProgramError::InvalidInstruction(pc));
42 }
43
44 *data = MultiPreCompute {
45 a: inst.a.as_canonical_u32() as u8,
46 b: inst.b.as_canonical_u32() as u8,
47 c: inst.c.as_canonical_u32() as u8,
48 };
49 Ok(())
50 }
51}
52
53impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
54 for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
55where
56 F: PrimeField32,
57{
58 fn pre_compute_size(&self) -> usize {
59 size_of::<MultiPreCompute>()
60 }
61 #[cfg(not(feature = "tco"))]
62 fn pre_compute<Ctx>(
63 &self,
64 pc: u32,
65 inst: &Instruction<F>,
66 data: &mut [u8],
67 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
68 where
69 Ctx: ExecutionCtxTrait,
70 {
71 let pre_compute: &mut MultiPreCompute = data.borrow_mut();
72 self.pre_compute_impl(pc, inst, pre_compute)?;
73 Ok(execute_e1_impl)
74 }
75
76 #[cfg(feature = "tco")]
77 fn handler<Ctx>(
78 &self,
79 pc: u32,
80 inst: &Instruction<F>,
81 data: &mut [u8],
82 ) -> Result<Handler<F, Ctx>, StaticProgramError>
83 where
84 Ctx: ExecutionCtxTrait,
85 {
86 let pre_compute: &mut MultiPreCompute = data.borrow_mut();
87 self.pre_compute_impl(pc, inst, pre_compute)?;
88 Ok(execute_e1_handler)
89 }
90}
91
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
        inst.opcode == MulOpcode::MUL.global_opcode()
    }

    /// Emits x86 assembly computing `[a] = [b] * [c]` (low 32 bits) over the
    /// xmm-backed guest register file.
    ///
    /// Returns `AotError::InvalidInstruction` if any operand is not a
    /// 4-byte-aligned register offset.
    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
        // Interpret the low 24 bits of a field operand as a sign-extended i16.
        let to_i16 = |op: F| -> i16 {
            let low24 = (op.as_canonical_u64() & 0xFFFFFF) as u32;
            (((low24 << 8) as i32) >> 8) as i16
        };
        let a = to_i16(inst.a);
        let b = to_i16(inst.b);
        let c = to_i16(inst.c);

        // Operands are byte offsets into the register file; registers are
        // 4 bytes wide, so every offset must be 4-aligned.
        if a % 4 != 0 || b % 4 != 0 || c % 4 != 0 {
            return Err(AotError::InvalidInstruction);
        }

        let mut asm_str = String::new();

        // Scratch GPR for operand `a`, honoring any fixed RISC-V -> x86
        // register mapping. (Was is_some()/unwrap(); unwrap_or is equivalent.)
        let str_reg_a = RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap_or(REG_A_W);

        if a == c {
            // rd aliases rs2: load rs2 into rd's scratch register so the
            // product ends up where it will be stored.
            let (gpr_reg_c, load_c) = xmm_to_gpr((c / 4) as u8, str_reg_a, true);
            asm_str += &load_c;
            let (gpr_reg_b, load_b) = xmm_to_gpr((b / 4) as u8, REG_C_W, false);
            asm_str += &load_b;
            asm_str += &format!(" imul {gpr_reg_c}, {gpr_reg_b}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_c, (a / 4) as u8);
        } else {
            let (gpr_reg_b, load_b) = xmm_to_gpr((b / 4) as u8, str_reg_a, true);
            asm_str += &load_b;
            let (gpr_reg_c, load_c) = xmm_to_gpr((c / 4) as u8, REG_C_W, false);
            asm_str += &load_c;
            asm_str += &format!(" imul {gpr_reg_b}, {gpr_reg_c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, (a / 4) as u8);
        }

        Ok(asm_str)
    }
}
144
145impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
146 for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
147where
148 F: PrimeField32,
149{
150 fn metered_pre_compute_size(&self) -> usize {
151 size_of::<E2PreCompute<MultiPreCompute>>()
152 }
153
154 #[cfg(not(feature = "tco"))]
155 fn metered_pre_compute<Ctx>(
156 &self,
157 chip_idx: usize,
158 pc: u32,
159 inst: &Instruction<F>,
160 data: &mut [u8],
161 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
162 where
163 Ctx: MeteredExecutionCtxTrait,
164 {
165 let pre_compute: &mut E2PreCompute<MultiPreCompute> = data.borrow_mut();
166 pre_compute.chip_idx = chip_idx as u32;
167 self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
168 Ok(execute_e2_impl)
169 }
170
171 #[cfg(feature = "tco")]
172 fn metered_handler<Ctx>(
173 &self,
174 chip_idx: usize,
175 pc: u32,
176 inst: &Instruction<F>,
177 data: &mut [u8],
178 ) -> Result<Handler<F, Ctx>, StaticProgramError>
179 where
180 Ctx: MeteredExecutionCtxTrait,
181 {
182 let pre_compute: &mut E2PreCompute<MultiPreCompute> = data.borrow_mut();
183 pre_compute.chip_idx = chip_idx as u32;
184 self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
185 Ok(execute_e2_handler)
186 }
187}
188
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    /// Metered AOT is available unconditionally for this executor.
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    /// Emits the unmetered MUL assembly followed by the metering bookkeeping
    /// instructions.
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;

        // One trace row for this chip per executed MUL...
        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // ...and one register-adapter access per operand, matching the two
        // reads and one write performed in `execute_e12_impl`.
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;

        Ok(asm_str)
    }
}
218#[inline(always)]
219unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
220 pre_compute: &MultiPreCompute,
221 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
222) {
223 let rs1: [u8; RV32_REGISTER_NUM_LIMBS] =
224 exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
225 let rs2: [u8; RV32_REGISTER_NUM_LIMBS] =
226 exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
227 let rs1 = u32::from_le_bytes(rs1);
228 let rs2 = u32::from_le_bytes(rs2);
229 let rd = rs1.wrapping_mul(rs2);
230 exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd.to_le_bytes());
231
232 let pc = exec_state.pc();
233 exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
234}
235
236#[create_handler]
237#[inline(always)]
238unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
239 pre_compute: *const u8,
240 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
241) {
242 let pre_compute: &MultiPreCompute =
243 std::slice::from_raw_parts(pre_compute, size_of::<MultiPreCompute>()).borrow();
244 execute_e12_impl(pre_compute, exec_state);
245}
246
247#[create_handler]
248#[inline(always)]
249unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
250 pre_compute: *const u8,
251 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
252) {
253 let pre_compute: &E2PreCompute<MultiPreCompute> =
254 std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<MultiPreCompute>>())
255 .borrow();
256 exec_state
257 .ctx
258 .on_height_change(pre_compute.chip_idx as usize, 1);
259 execute_e12_impl(&pre_compute.data, exec_state);
260}