1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction,
10 program::DEFAULT_PC_STEP,
11 riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
12 LocalOpcode,
13};
14use openvm_rv32im_transpiler::ShiftOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use super::ShiftExecutor;
18use crate::adapters::imm_to_bytes;
19
/// Operands of a shift instruction, decoded once at pre-compute time so the
/// hot execution path only reads plain integers.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    // Third operand: the decoded immediate value when `e == RV32_IMM_AS`,
    // otherwise the rs2 register pointer passed to `vm_read`.
    c: u32,
    // rd register pointer in RV32_REGISTER_AS (write target).
    a: u8,
    // rs1 register pointer in RV32_REGISTER_AS (read source).
    b: u8,
}
27
28impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
29 #[inline(always)]
30 fn pre_compute_impl<F: PrimeField32>(
31 &self,
32 pc: u32,
33 inst: &Instruction<F>,
34 data: &mut ShiftPreCompute,
35 ) -> Result<(bool, ShiftOpcode), StaticProgramError> {
36 let Instruction {
37 opcode, a, b, c, e, ..
38 } = inst;
39 let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));
40 let e_u32 = e.as_canonical_u32();
41 if inst.d.as_canonical_u32() != RV32_REGISTER_AS
42 || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
43 {
44 return Err(StaticProgramError::InvalidInstruction(pc));
45 }
46 let is_imm = e_u32 == RV32_IMM_AS;
47 let c_u32 = c.as_canonical_u32();
48 *data = ShiftPreCompute {
49 c: if is_imm {
50 u32::from_le_bytes(imm_to_bytes(c_u32))
51 } else {
52 c_u32
53 },
54 a: a.as_canonical_u32() as u8,
55 b: b.as_canonical_u32() as u8,
56 };
57 Ok((is_imm, shift_opcode))
59 }
60}
61
/// Expands to a `match` that picks the monomorphized executor function for a
/// given (shift opcode, register-vs-immediate) combination.
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $shift_opcode:ident) => {
        match ($shift_opcode, $is_imm) {
            (ShiftOpcode::SLL, true) => Ok($execute_impl::<_, _, true, SllOp>),
            (ShiftOpcode::SLL, false) => Ok($execute_impl::<_, _, false, SllOp>),
            (ShiftOpcode::SRL, true) => Ok($execute_impl::<_, _, true, SrlOp>),
            (ShiftOpcode::SRL, false) => Ok($execute_impl::<_, _, false, SrlOp>),
            (ShiftOpcode::SRA, true) => Ok($execute_impl::<_, _, true, SraOp>),
            (ShiftOpcode::SRA, false) => Ok($execute_impl::<_, _, false, SraOp>),
        }
    };
}
74
75impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> Executor<F>
76 for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
77where
78 F: PrimeField32,
79{
80 fn pre_compute_size(&self) -> usize {
81 size_of::<ShiftPreCompute>()
82 }
83
84 fn pre_compute<Ctx: ExecutionCtxTrait>(
85 &self,
86 pc: u32,
87 inst: &Instruction<F>,
88 data: &mut [u8],
89 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
90 let data: &mut ShiftPreCompute = data.borrow_mut();
91 let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
92 dispatch!(execute_e1_impl, is_imm, shift_opcode)
94 }
95
96 #[cfg(feature = "tco")]
97 fn handler<Ctx>(
98 &self,
99 pc: u32,
100 inst: &Instruction<F>,
101 data: &mut [u8],
102 ) -> Result<Handler<F, Ctx>, StaticProgramError>
103 where
104 Ctx: ExecutionCtxTrait,
105 {
106 let data: &mut ShiftPreCompute = data.borrow_mut();
107 let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
108 dispatch!(execute_e1_tco_handler, is_imm, shift_opcode)
110 }
111}
112
113impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MeteredExecutor<F>
114 for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
115where
116 F: PrimeField32,
117{
118 fn metered_pre_compute_size(&self) -> usize {
119 size_of::<E2PreCompute<ShiftPreCompute>>()
120 }
121
122 fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
123 &self,
124 chip_idx: usize,
125 pc: u32,
126 inst: &Instruction<F>,
127 data: &mut [u8],
128 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
129 let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
130 data.chip_idx = chip_idx as u32;
131 let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
132 dispatch!(execute_e2_impl, is_imm, shift_opcode)
134 }
135
136 #[cfg(feature = "tco")]
137 fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
138 &self,
139 chip_idx: usize,
140 pc: u32,
141 inst: &Instruction<F>,
142 data: &mut [u8],
143 ) -> Result<Handler<F, Ctx>, StaticProgramError> {
144 let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
145 data.chip_idx = chip_idx as u32;
146 let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
147 dispatch!(execute_e2_tco_handler, is_imm, shift_opcode)
149 }
150}
151
152unsafe fn execute_e12_impl<
153 F: PrimeField32,
154 CTX: ExecutionCtxTrait,
155 const IS_IMM: bool,
156 OP: ShiftOp,
157>(
158 pre_compute: &ShiftPreCompute,
159 state: &mut VmExecState<F, GuestMemory, CTX>,
160) {
161 let rs1 = state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
162 let rs2 = if IS_IMM {
163 pre_compute.c.to_le_bytes()
164 } else {
165 state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
166 };
167 let rs2 = u32::from_le_bytes(rs2);
168
169 let rd = <OP as ShiftOp>::compute(rs1, rs2);
171 state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
173
174 state.instret += 1;
175 state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
176}
177
178#[create_tco_handler]
179unsafe fn execute_e1_impl<
180 F: PrimeField32,
181 CTX: ExecutionCtxTrait,
182 const IS_IMM: bool,
183 OP: ShiftOp,
184>(
185 pre_compute: &[u8],
186 state: &mut VmExecState<F, GuestMemory, CTX>,
187) {
188 let pre_compute: &ShiftPreCompute = pre_compute.borrow();
189 execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, state);
190}
191
192#[create_tco_handler]
193unsafe fn execute_e2_impl<
194 F: PrimeField32,
195 CTX: MeteredExecutionCtxTrait,
196 const IS_IMM: bool,
197 OP: ShiftOp,
198>(
199 pre_compute: &[u8],
200 state: &mut VmExecState<F, GuestMemory, CTX>,
201) {
202 let pre_compute: &E2PreCompute<ShiftPreCompute> = pre_compute.borrow();
203 state.ctx.on_height_change(pre_compute.chip_idx as usize, 1);
204 execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, state);
205}
206
/// Strategy trait: one RV32 shift flavour, operating on little-endian register
/// bytes.
trait ShiftOp {
    /// Applies the shift to `rs1`, using only the low five bits of `rs2` as
    /// the shift amount (RV32 semantics).
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4];
}

/// Logical left shift (`sll`/`slli`).
struct SllOp;
/// Logical right shift (`srl`/`srli`).
struct SrlOp;
/// Arithmetic (sign-extending) right shift (`sra`/`srai`).
struct SraOp;

impl ShiftOp for SllOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let value = u32::from_le_bytes(rs1);
        let amount = rs2 % 32; // same as masking with 0x1F
        (value << amount).to_le_bytes()
    }
}

impl ShiftOp for SrlOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let value = u32::from_le_bytes(rs1);
        let amount = rs2 % 32;
        (value >> amount).to_le_bytes()
    }
}

impl ShiftOp for SraOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        // Reinterpret as i32 so `>>` shifts in copies of the sign bit.
        let value = i32::from_le_bytes(rs1);
        let amount = rs2 % 32;
        (value >> amount).to_le_bytes()
    }
}