use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::ShiftExecutor;
#[allow(unused_imports)]
use crate::{adapters::imm_to_bytes, common::*};

/// Pre-decoded operands for a shift instruction: `a` and `b` are the rd and rs1 register
/// byte offsets, while `c` holds either the rs2 register offset or the decoded immediate,
/// depending on the address space in operand `e`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    c: u32,
    a: u8,
    b: u8,
}

impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
    /// Validates the instruction's address spaces, fills `data` with the decoded operands,
    /// and returns whether the shift amount is an immediate together with the shift opcode.
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut ShiftPreCompute,
    ) -> Result<(bool, ShiftOpcode), StaticProgramError> {
        let Instruction {
            opcode, a, b, c, e, ..
        } = inst;
        let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let e_u32 = e.as_canonical_u32();
        if inst.d.as_canonical_u32() != RV32_REGISTER_AS
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();
        *data = ShiftPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok((is_imm, shift_opcode))
    }
}

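// Maps a runtime (is_imm, shift_opcode) pair to a handler monomorphized over the IS_IMM
// flag and the concrete ShiftOp, so the hot path avoids per-instruction branching on the
// operand kind and operation.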
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $shift_opcode:ident) => {
        match ($is_imm, $shift_opcode) {
            (true, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, true, SllOp>),
            (false, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, false, SllOp>),
            (true, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, true, SrlOp>),
            (false, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, false, SrlOp>),
            (true, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, true, SraOp>),
            (false, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, false, SraOp>),
        }
    };
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> InterpreterExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        // The non-tco path returns a plain `ExecuteFunc`, so dispatch the `_impl` variant here.
        dispatch!(execute_e1_impl, is_imm, shift_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, is_imm, shift_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> AotExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, _instruction: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
        // Sign-extend the low 24 bits of the canonical field value and truncate to i16.
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let mut asm_str = String::new();
        let a: i16 = to_i16(inst.a);
        let b: i16 = to_i16(inst.b);
        let c: i16 = to_i16(inst.c);
        let e: i16 = to_i16(inst.e);
        assert!(a % 4 == 0, "instruction.a must be a multiple of 4");
        assert!(b % 4 == 0, "instruction.b must be a multiple of 4");

        let str_reg_a = RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap_or(REG_A_W);

        if e == 0 {
            // Immediate shift amount: emit shl/shr/sar with the immediate encoded directly.
            let mut asm_opcode = String::new();
            if inst.opcode == ShiftOpcode::SLL.global_opcode() {
                asm_opcode += "shl";
            } else if inst.opcode == ShiftOpcode::SRL.global_opcode() {
                asm_opcode += "shr";
            } else if inst.opcode == ShiftOpcode::SRA.global_opcode() {
                asm_opcode += "sar";
            }

            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, str_reg_a, true);
            asm_str += delta_str_b;
            asm_str += &format!(" {asm_opcode} {reg_b}, {c}\n");
            asm_str += &gpr_to_xmm(reg_b, (a / 4) as u8);
        } else {
            // Register shift amount: emit the three-operand BMI2 forms shlx/shrx/sarx.
            let mut asm_opcode = String::new();
            if inst.opcode == ShiftOpcode::SLL.global_opcode() {
                asm_opcode += "shlx";
            } else if inst.opcode == ShiftOpcode::SRL.global_opcode() {
                asm_opcode += "shrx";
            } else if inst.opcode == ShiftOpcode::SRA.global_opcode() {
                asm_opcode += "sarx";
            }

            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, REG_B_W, false);
            asm_str += delta_str_b;

            let (reg_c, delta_str_c) = &xmm_to_gpr((c / 4) as u8, REG_C_W, false);
            asm_str += delta_str_c;

            asm_str += &format!(" {asm_opcode} {str_reg_a}, {reg_b}, {reg_c}\n");

            asm_str += &gpr_to_xmm(str_reg_a, (a / 4) as u8);
        }

        Ok(asm_str)
    }
}

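// Illustrative sketch, not part of the original module: exercises the same 24-bit
// sign-extension arithmetic as the `to_i16` closure in `generate_x86_asm` above,
// using a hypothetical negative operand encoding.
#[cfg(all(test, feature = "aot"))]
mod to_i16_sign_extension_tests {
    #[test]
    fn low_24_bits_sign_extend_to_i16() {
        // 0xFF_FFFC is the 24-bit two's-complement encoding of -4.
        let raw: u64 = 0xFF_FFFC;
        let c_u24 = (raw & 0xFFFFFF) as u32;
        let c_i24 = ((c_u24 << 8) as i32) >> 8;
        assert_eq!(c_i24 as i16, -4);
    }
}
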
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
        // As above, the non-tco path returns an `ExecuteFunc`, so use the `_impl` variant.
        dispatch!(execute_e2_impl, is_imm, shift_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError> {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, is_imm, shift_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let is_imm = inst.e.as_canonical_u32() == RV32_IMM_AS;
        let mut asm_str = self.generate_x86_asm(inst, pc)?;
        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // The rs1 read and the rd write each touch the register address space.
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        if !is_imm {
            // The rs2 read also touches the register address space unless it is an immediate.
            asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }
        Ok(asm_str)
    }
}
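
/// Shared E1/E2 execution body: reads rs1 (and rs2 when it is a register operand),
/// applies the shift, writes rd, and advances the PC by `DEFAULT_PC_STEP`.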
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &ShiftPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let rs2 = u32::from_le_bytes(rs2);

    let rd = <OP as ShiftOp>::compute(rs1, rs2);
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<ShiftPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, exec_state);
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<ShiftPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, exec_state);
}

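/// Word-level shift semantics shared by the interpreter paths; only the low five bits of
/// rs2 select the shift amount, matching RV32 SLL/SRL/SRA.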
trait ShiftOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4];
}
struct SllOp;
struct SrlOp;
struct SraOp;
impl ShiftOp for SllOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        (rs1 << (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SrlOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SraOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = i32::from_le_bytes(rs1);
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}
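
// Illustrative sketch, not part of the original module: sanity checks for the ShiftOp
// implementations above, assuming the standard RV32 rule that only the low five bits of
// rs2 select the shift amount.
#[cfg(test)]
mod shift_op_tests {
    use super::*;

    #[test]
    fn sll_masks_the_shift_amount_to_five_bits() {
        let rs1 = 1u32.to_le_bytes();
        // A shift amount of 33 is masked to 1.
        assert_eq!(u32::from_le_bytes(SllOp::compute(rs1, 33)), 2);
    }

    #[test]
    fn srl_is_a_logical_shift() {
        let rs1 = 0x8000_0000u32.to_le_bytes();
        assert_eq!(u32::from_le_bytes(SrlOp::compute(rs1, 31)), 1);
    }

    #[test]
    fn sra_is_an_arithmetic_shift() {
        let rs1 = (-8i32).to_le_bytes();
        assert_eq!(i32::from_le_bytes(SraOp::compute(rs1, 2)), -2);
    }
}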