1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_bigint_transpiler::Rv32Shift256Opcode;
7use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
8use openvm_circuit_primitives_derive::AlignedBytesBorrow;
9use openvm_instructions::{
10 instruction::Instruction,
11 program::DEFAULT_PC_STEP,
12 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
13 LocalOpcode,
14};
15use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
16use openvm_rv32im_circuit::ShiftExecutor;
17use openvm_rv32im_transpiler::ShiftOpcode;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::{
21 common::{bytes_to_u64_array, u64_array_to_bytes},
22 Rv32Shift256Executor, INT256_NUM_LIMBS,
23};
24
// Heap adapter: reads 2 operands of INT256_NUM_LIMBS bytes each from memory and
// writes one INT256_NUM_LIMBS-byte result back.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
26
27impl Rv32Shift256Executor {
28 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
29 Self(ShiftExecutor::new(adapter, offset))
30 }
31}
32
/// Pre-decoded operand fields for one 256-bit shift instruction.
///
/// Each field is a register-file address (used as the offset for a
/// `RV32_REGISTER_AS` read) holding a 4-byte pointer into guest memory.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    a: u8, // register holding the destination (rd) pointer
    b: u8, // register holding the first source (rs1) pointer
    c: u8, // register holding the second source (rs2) pointer
}
40
// Selects the handler monomorphized for the decoded shift variant:
// `$execute_impl` is the generic handler name and the third type parameter
// (`SllOp`/`SraOp`/`SrlOp`) picks the concrete shift operation.
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            ShiftOpcode::SLL => $execute_impl::<_, _, SllOp>,
            ShiftOpcode::SRA => $execute_impl::<_, _, SraOp>,
            ShiftOpcode::SRL => $execute_impl::<_, _, SrlOp>,
        })
    };
}
50
impl<F: PrimeField32> Executor<F> for Rv32Shift256Executor {
    // Size of the per-instruction pre-compute buffer this executor fills.
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    /// Decodes `inst` into `data` and returns the E1 execute function
    /// specialized for the decoded shift variant.
    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }

    /// Tail-call-optimized variant: same decoding, returns a `Handler`
    /// instead of an `ExecuteFunc`.
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}
86
impl<F: PrimeField32> MeteredExecutor<F> for Rv32Shift256Executor {
    // Metered pre-compute additionally stores the chip index (E2PreCompute wrapper).
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    /// Decodes `inst` into `data` (recording `chip_idx` for metering) and
    /// returns the E2 execute function for the decoded shift variant.
    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }

    /// Tail-call-optimized variant of `metered_pre_compute`.
    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}
126
/// Shared E1/E2 execution body for one 256-bit shift instruction:
/// dereferences the three register pointers, loads both 256-bit operands,
/// applies `OP`, stores the result, and advances `pc`/`instret`.
///
/// NOTE(review): the order of `vm_read`/`vm_write` calls presumably matters
/// for the recorded memory trace — do not reorder.
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &ShiftPreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Registers b/c/a each hold a 4-byte little-endian pointer into guest memory.
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
    let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    // Load both 256-bit operands as little-endian limb arrays.
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let rd = OP::compute(rs1, rs2);
    exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    *instret += 1;
}
146
/// E1 (unmetered) entry point: reinterprets the raw pre-compute bytes as
/// [`ShiftPreCompute`] and delegates to the shared body.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, instret, pc, exec_state);
}
159
/// E2 (metered) entry point: reads the [`E2PreCompute`] wrapper, reports a
/// height increment of 1 for this chip, then runs the shared body.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> = pre_compute.borrow();
    // Each executed shift contributes one row to this chip's trace.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, instret, pc, exec_state);
}
175
176impl Rv32Shift256Executor {
177 fn pre_compute_impl<F: PrimeField32>(
178 &self,
179 pc: u32,
180 inst: &Instruction<F>,
181 data: &mut ShiftPreCompute,
182 ) -> Result<ShiftOpcode, StaticProgramError> {
183 let Instruction {
184 opcode,
185 a,
186 b,
187 c,
188 d,
189 e,
190 ..
191 } = inst;
192 let e_u32 = e.as_canonical_u32();
193 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
194 return Err(StaticProgramError::InvalidInstruction(pc));
195 }
196 *data = ShiftPreCompute {
197 a: a.as_canonical_u32() as u8,
198 b: b.as_canonical_u32() as u8,
199 c: c.as_canonical_u32() as u8,
200 };
201 let local_opcode =
202 ShiftOpcode::from_usize(opcode.local_opcode_idx(Rv32Shift256Opcode::CLASS_OFFSET));
203 Ok(local_opcode)
204 }
205}
206
/// Strategy trait: one 256-bit shift variant over little-endian limb arrays.
/// The shift amount is taken from the low byte of `rs2` (see the impls).
trait ShiftOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
}
// Shift left logical.
struct SllOp;
// Shift right logical (zero fill).
struct SrlOp;
// Shift right arithmetic (sign fill).
struct SraOp;
213impl ShiftOp for SllOp {
214 #[inline(always)]
215 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
216 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
217 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
218 let mut rd = [0u64; 4];
219 let shift = (rs2_u64[0] & 0xff) as u32;
221 let index_offset = (shift / u64::BITS) as usize;
222 let bit_offset = shift % u64::BITS;
223 let mut carry = 0u64;
224 for i in index_offset..4 {
225 let curr = rs1_u64[i - index_offset];
226 rd[i] = (curr << bit_offset) + carry;
227 if bit_offset > 0 {
228 carry = curr >> (u64::BITS - bit_offset);
229 }
230 }
231 u64_array_to_bytes(rd)
232 }
233}
impl ShiftOp for SrlOp {
    /// 256-bit logical right shift: vacated high bits are filled with zeros.
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
        shift_right(rs1, rs2, 0)
    }
}
241impl ShiftOp for SraOp {
242 #[inline(always)]
243 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
244 if rs1[INT256_NUM_LIMBS - 1] & 0x80 > 0 {
246 shift_right(rs1, rs2, u64::MAX)
247 } else {
248 shift_right(rs1, rs2, 0)
249 }
250 }
251}
252
253#[inline(always)]
254fn shift_right(
255 rs1: [u8; INT256_NUM_LIMBS],
256 rs2: [u8; INT256_NUM_LIMBS],
257 init_value: u64,
258) -> [u8; INT256_NUM_LIMBS] {
259 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
260 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
261 let mut rd = [init_value; 4];
262 let shift = (rs2_u64[0] & 0xff) as u32;
263 let index_offset = (shift / u64::BITS) as usize;
264 let bit_offset = shift % u64::BITS;
265 let mut carry = if bit_offset > 0 {
266 init_value << (u64::BITS - bit_offset)
267 } else {
268 0
269 };
270 for i in (index_offset..4).rev() {
271 let curr = rs1_u64[i];
272 rd[i - index_offset] = (curr >> bit_offset) + carry;
273 if bit_offset > 0 {
274 carry = curr << (u64::BITS - bit_offset);
275 }
276 }
277 u64_array_to_bytes(rd)
278}
279
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{
        shift::{ShiftOp, SllOp, SraOp, SrlOp},
        INT256_NUM_LIMBS,
    };

    /// Cross-checks every shift variant against `alloy_primitives::U256`
    /// reference semantics on random 256-bit operands.
    #[test]
    fn test_shift_op() {
        // Fixed seed keeps the test deterministic across runs.
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            let limbs_a: [u8; INT256_NUM_LIMBS] = rng.gen();
            // Only the low byte of the shift operand is significant; the rest
            // of limbs_b stays zero.
            let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS];
            let shift: u8 = rng.gen();
            limbs_b[0] = shift;
            let a = U256::from_le_bytes(limbs_a);
            {
                let res = SllOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a << shift);
            }
            {
                let res = SraOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a.arithmetic_shr(shift as usize));
            }
            {
                let res = SrlOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a >> shift);
            }
        }
    }
}