use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::ShiftExecutor;
use crate::adapters::imm_to_bytes;
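/// Operands decoded once at pre-compute time: `a` and `b` are register pointers, while `c`
/// holds either the decoded immediate or the rs2 register pointer, depending on the
/// instruction's `e` address space.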
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    c: u32,
    a: u8,
    b: u8,
}
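// Shared operand decoding: validates the `d`/`e` address spaces, fills `ShiftPreCompute`,
// and returns whether `c` is an immediate together with the concrete `ShiftOpcode`.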
impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut ShiftPreCompute,
    ) -> Result<(bool, ShiftOpcode), StaticProgramError> {
        let Instruction {
            opcode, a, b, c, e, ..
        } = inst;
        let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let e_u32 = e.as_canonical_u32();
        if inst.d.as_canonical_u32() != RV32_REGISTER_AS
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();
        *data = ShiftPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok((is_imm, shift_opcode))
    }
}
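// Picks the handler monomorphized over (immediate vs. register) and the shift opcode, so the
// per-instruction hot path does not branch on either at runtime.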
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $shift_opcode:ident) => {
        match ($is_imm, $shift_opcode) {
            (true, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, true, SllOp>),
            (false, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, false, SllOp>),
            (true, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, true, SrlOp>),
            (false, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, false, SrlOp>),
            (true, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, true, SraOp>),
            (false, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, false, SraOp>),
        }
    };
}
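// Unmetered (e1) execution: decode into `ShiftPreCompute` and hand back the matching handler.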
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> Executor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, is_imm, shift_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, is_imm, shift_opcode)
    }
}
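// Metered (e2) execution: additionally records the chip index so the context can account
// one row of trace height per executed instruction.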
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MeteredExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, is_imm, shift_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError> {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, is_imm, shift_opcode)
    }
}
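// Core step shared by the e1 and e2 paths: reads rs1 (and rs2 when it is a register), applies
// the shift, writes rd, then bumps the instruction counter and advances the pc.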
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &ShiftPreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let rs2 = u32::from_le_bytes(rs2);

    let rd = <OP as ShiftOp>::compute(rs1, rs2);
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    *instret += 1;
    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
}
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, instret, pc, exec_state);
}
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, instret, pc, exec_state);
}
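// RV32 shift semantics on little-endian register words: only the low five bits of rs2 select
// the shift amount, and SRA reinterprets the word as `i32` so the shift is arithmetic.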
trait ShiftOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4];
}
struct SllOp;
struct SrlOp;
struct SraOp;
impl ShiftOp for SllOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        (rs1 << (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SrlOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SraOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = i32::from_le_bytes(rs1);
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}