1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_bigint_transpiler::Rv32Shift256Opcode;
7use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
8use openvm_circuit_primitives_derive::AlignedBytesBorrow;
9use openvm_instructions::{
10 instruction::Instruction,
11 program::DEFAULT_PC_STEP,
12 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
13 LocalOpcode,
14};
15use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
16use openvm_rv32im_circuit::ShiftExecutor;
17use openvm_rv32im_transpiler::ShiftOpcode;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::{
21 common::{bytes_to_u64_array, u64_array_to_bytes},
22 Rv32Shift256Executor, INT256_NUM_LIMBS,
23};
24
/// Heap adapter configured for two 256-bit (INT256_NUM_LIMBS-limb) operand reads
/// and one 256-bit result write.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
26
27impl Rv32Shift256Executor {
28 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
29 Self(ShiftExecutor::new(adapter, offset))
30 }
31}
32
/// Pre-decoded operands for one 256-bit shift instruction.
/// `repr(C)` so the struct can be reinterpreted in place from a raw byte
/// buffer via `AlignedBytesBorrow`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    // Register index whose value is the pointer to the destination (rd).
    a: u8,
    // Register index whose value is the pointer to the first operand (rs1).
    b: u8,
    // Register index whose value is the pointer to the second operand /
    // shift amount (rs2).
    c: u8,
}
40
/// Selects the concrete handler function for a decoded shift opcode by
/// monomorphizing `$execute_impl` over the matching `ShiftOp` marker type.
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            ShiftOpcode::SLL => $execute_impl::<_, _, SllOp>,
            ShiftOpcode::SRA => $execute_impl::<_, _, SraOp>,
            ShiftOpcode::SRL => $execute_impl::<_, _, SrlOp>,
        })
    };
}
50
impl<F: PrimeField32> InterpreterExecutor<F> for Rv32Shift256Executor {
    /// Number of scratch bytes required to hold the decoded instruction.
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    /// Decodes `inst` into `data` and returns the execute function for the
    /// decoded shift variant (non-tail-call build).
    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        // Reinterpret the scratch buffer as a ShiftPreCompute (AlignedBytesBorrow).
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }

    /// Same decoding as `pre_compute`, but returns a tail-call-optimized
    /// handler (TCO build).
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        // Reinterpret the scratch buffer as a ShiftPreCompute (AlignedBytesBorrow).
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}
86
// AOT execution: empty impl, all trait items use their defaults.
#[cfg(feature = "aot")]
impl<F: PrimeField32> AotExecutor<F> for Rv32Shift256Executor {}
89
impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32Shift256Executor {
    /// Scratch bytes for metered execution: the decoded instruction wrapped
    /// with its chip index (`E2PreCompute`).
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    /// Decodes `inst` into `data`, records `chip_idx` for height metering,
    /// and returns the metered execute function (non-tail-call build).
    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        // Reinterpret the scratch buffer as the metered pre-compute wrapper.
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }

    /// Same decoding as `metered_pre_compute`, but returns a
    /// tail-call-optimized handler (TCO build).
    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        // Reinterpret the scratch buffer as the metered pre-compute wrapper.
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}
129
// Metered AOT execution: empty impl, all trait items use their defaults.
#[cfg(feature = "aot")]
impl<F: PrimeField32> AotMeteredExecutor<F> for Rv32Shift256Executor {}
132
// Shared execution core for both plain (E1) and metered (E2) paths:
// reads the three pointer registers, loads both 256-bit operands from guest
// memory, computes the shift, writes the result, and advances the PC.
//
// Safety: caller must ensure `pre_compute` was produced by
// `pre_compute_impl` for the instruction at the current PC.
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &ShiftPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // b/c/a are register indices; each register holds a 4-byte
    // little-endian pointer into the memory address space.
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
    let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let rd = OP::compute(rs1, rs2);
    exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
    // Advance to the next instruction.
    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}
150
// E1 (plain execution) entry point: reinterprets the raw pre-compute bytes
// as a ShiftPreCompute and forwards to the shared core.
//
// Safety: `pre_compute` must point to at least
// `size_of::<ShiftPreCompute>()` valid, suitably aligned bytes written by
// the matching `pre_compute` call.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<ShiftPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}
161
// E2 (metered execution) entry point: like E1, but also reports a height
// increment of 1 for this chip before executing.
//
// Safety: `pre_compute` must point to at least
// `size_of::<E2PreCompute<ShiftPreCompute>>()` valid, suitably aligned
// bytes written by the matching `metered_pre_compute` call.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<ShiftPreCompute>>())
            .borrow();
    // One row is added to this chip's trace per executed instruction.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}
176
impl Rv32Shift256Executor {
    /// Validates the instruction's address spaces, stores the three register
    /// operands into `data`, and returns the decoded local shift opcode.
    ///
    /// Errors with `InvalidInstruction(pc)` when `d` is not the register
    /// address space or `e` is not the memory address space.
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut ShiftPreCompute,
    ) -> Result<ShiftOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        // NOTE(review): the casts below truncate to u8 — assumes a/b/c always
        // fit in one byte (register indices); confirm upstream validation.
        *data = ShiftPreCompute {
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
            c: c.as_canonical_u32() as u8,
        };
        let local_opcode =
            ShiftOpcode::from_usize(opcode.local_opcode_idx(Rv32Shift256Opcode::CLASS_OFFSET));
        Ok(local_opcode)
    }
}
207
/// Strategy trait with one implementation per shift variant (SLL/SRL/SRA),
/// operating on 256-bit values as little-endian byte-limb arrays.
trait ShiftOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
}
/// Marker type for shift-left-logical.
struct SllOp;
/// Marker type for shift-right-logical.
struct SrlOp;
/// Marker type for shift-right-arithmetic.
struct SraOp;
214impl ShiftOp for SllOp {
215 #[inline(always)]
216 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
217 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
218 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
219 let mut rd = [0u64; 4];
220 let shift = (rs2_u64[0] & 0xff) as u32;
222 let index_offset = (shift / u64::BITS) as usize;
223 let bit_offset = shift % u64::BITS;
224 let mut carry = 0u64;
225 for i in index_offset..4 {
226 let curr = rs1_u64[i - index_offset];
227 rd[i] = (curr << bit_offset) + carry;
228 if bit_offset > 0 {
229 carry = curr >> (u64::BITS - bit_offset);
230 }
231 }
232 u64_array_to_bytes(rd)
233 }
234}
235impl ShiftOp for SrlOp {
236 #[inline(always)]
237 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
238 shift_right(rs1, rs2, 0)
240 }
241}
242impl ShiftOp for SraOp {
243 #[inline(always)]
244 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
245 if rs1[INT256_NUM_LIMBS - 1] & 0x80 > 0 {
247 shift_right(rs1, rs2, u64::MAX)
248 } else {
249 shift_right(rs1, rs2, 0)
250 }
251 }
252}
253
254#[inline(always)]
255fn shift_right(
256 rs1: [u8; INT256_NUM_LIMBS],
257 rs2: [u8; INT256_NUM_LIMBS],
258 init_value: u64,
259) -> [u8; INT256_NUM_LIMBS] {
260 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
261 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
262 let mut rd = [init_value; 4];
263 let shift = (rs2_u64[0] & 0xff) as u32;
264 let index_offset = (shift / u64::BITS) as usize;
265 let bit_offset = shift % u64::BITS;
266 let mut carry = if bit_offset > 0 {
267 init_value << (u64::BITS - bit_offset)
268 } else {
269 0
270 };
271 for i in (index_offset..4).rev() {
272 let curr = rs1_u64[i];
273 rd[i - index_offset] = (curr >> bit_offset) + carry;
274 if bit_offset > 0 {
275 carry = curr << (u64::BITS - bit_offset);
276 }
277 }
278 u64_array_to_bytes(rd)
279}
280
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{
        shift::{ShiftOp, SllOp, SraOp, SrlOp},
        INT256_NUM_LIMBS,
    };

    /// Cross-checks each shift variant against `alloy_primitives::U256`
    /// reference semantics on random operands and shift amounts in 0..=255.
    #[test]
    fn test_shift_op() {
        // Fixed seed keeps the test deterministic across runs.
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            let limbs_a: [u8; INT256_NUM_LIMBS] = rng.gen();
            let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS];
            // Only the low limb carries the shift amount, matching the
            // implementations' mod-256 semantics.
            let shift: u8 = rng.gen();
            limbs_b[0] = shift;
            let a = U256::from_le_bytes(limbs_a);
            {
                let res = SllOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a << shift);
            }
            {
                let res = SraOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a.arithmetic_shr(shift as usize));
            }
            {
                let res = SrlOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a >> shift);
            }
        }
    }
}