1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32Shift256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7 instruction::Instruction,
8 program::DEFAULT_PC_STEP,
9 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10 LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::ShiftExecutor;
14use openvm_rv32im_transpiler::ShiftOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18 common::{bytes_to_u64_array, u64_array_to_bytes},
19 Rv32Shift256Executor, INT256_NUM_LIMBS,
20};
21
/// Heap adapter executor used by the 256-bit shift executor, parameterised with
/// `2` and with `INT256_NUM_LIMBS` for both of its size parameters.
// NOTE(review): presumably "2" is the number of heap reads and the two sizes are
// the read/write block widths — confirm against `Rv32HeapAdapterExecutor`.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32Shift256Executor {
25 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26 Self(ShiftExecutor::new(adapter, offset))
27 }
28}
29
/// Per-instruction data decoded once by `pre_compute_impl` and read back by the
/// execution functions. Each field is an instruction operand truncated to `u8`
/// and later used as an address in the register address space.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    // Register holding the memory pointer the result is written to.
    a: u8,
    // Register holding the memory pointer of the value being shifted.
    b: u8,
    // Register holding the memory pointer of the shift amount.
    c: u8,
}
37
// Picks the instantiation of the given generic execution function that matches
// a decoded `ShiftOpcode`, and wraps it in `Ok`.
//
// Usage: `dispatch!(function_name, opcode_variable)`; the first two generic
// parameters of the function are left for inference, the third is the
// `ShiftOp` marker type for the opcode.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        Ok(match $opcode {
            ShiftOpcode::SLL => $handler::<_, _, SllOp>,
            ShiftOpcode::SRL => $handler::<_, _, SrlOp>,
            ShiftOpcode::SRA => $handler::<_, _, SraOp>,
        })
    };
}
47
48impl<F: PrimeField32> Executor<F> for Rv32Shift256Executor {
49 fn pre_compute_size(&self) -> usize {
50 size_of::<ShiftPreCompute>()
51 }
52
53 fn pre_compute<Ctx>(
54 &self,
55 pc: u32,
56 inst: &Instruction<F>,
57 data: &mut [u8],
58 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
59 where
60 Ctx: ExecutionCtxTrait,
61 {
62 let data: &mut ShiftPreCompute = data.borrow_mut();
63 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
64 dispatch!(execute_e1_impl, local_opcode)
65 }
66
67 #[cfg(feature = "tco")]
68 fn handler<Ctx>(
69 &self,
70 pc: u32,
71 inst: &Instruction<F>,
72 data: &mut [u8],
73 ) -> Result<Handler<F, Ctx>, StaticProgramError>
74 where
75 Ctx: ExecutionCtxTrait,
76 {
77 let data: &mut ShiftPreCompute = data.borrow_mut();
78 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
79 dispatch!(execute_e1_tco_handler, local_opcode)
80 }
81}
82
83impl<F: PrimeField32> MeteredExecutor<F> for Rv32Shift256Executor {
84 fn metered_pre_compute_size(&self) -> usize {
85 size_of::<E2PreCompute<ShiftPreCompute>>()
86 }
87
88 fn metered_pre_compute<Ctx>(
89 &self,
90 chip_idx: usize,
91 pc: u32,
92 inst: &Instruction<F>,
93 data: &mut [u8],
94 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
95 where
96 Ctx: MeteredExecutionCtxTrait,
97 {
98 let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
99 data.chip_idx = chip_idx as u32;
100 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
101 dispatch!(execute_e2_impl, local_opcode)
102 }
103
104 #[cfg(feature = "tco")]
105 fn metered_handler<Ctx>(
106 &self,
107 chip_idx: usize,
108 pc: u32,
109 inst: &Instruction<F>,
110 data: &mut [u8],
111 ) -> Result<Handler<F, Ctx>, StaticProgramError>
112 where
113 Ctx: MeteredExecutionCtxTrait,
114 {
115 let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
116 data.chip_idx = chip_idx as u32;
117 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
118 dispatch!(execute_e2_tco_handler, local_opcode)
119 }
120}
121
122#[inline(always)]
123unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
124 pre_compute: &ShiftPreCompute,
125 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
126) {
127 let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
128 let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
129 let rd_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
130 let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
131 let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
132 let rd = OP::compute(rs1, rs2);
133 vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
134 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
135 vm_state.instret += 1;
136}
137
138#[create_tco_handler]
139unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
140 pre_compute: &[u8],
141 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
142) {
143 let pre_compute: &ShiftPreCompute = pre_compute.borrow();
144 execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
145}
146#[create_tco_handler]
147unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: ShiftOp>(
148 pre_compute: &[u8],
149 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
150) {
151 let pre_compute: &E2PreCompute<ShiftPreCompute> = pre_compute.borrow();
152 vm_state
153 .ctx
154 .on_height_change(pre_compute.chip_idx as usize, 1);
155 execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
156}
157
158impl Rv32Shift256Executor {
159 fn pre_compute_impl<F: PrimeField32>(
160 &self,
161 pc: u32,
162 inst: &Instruction<F>,
163 data: &mut ShiftPreCompute,
164 ) -> Result<ShiftOpcode, StaticProgramError> {
165 let Instruction {
166 opcode,
167 a,
168 b,
169 c,
170 d,
171 e,
172 ..
173 } = inst;
174 let e_u32 = e.as_canonical_u32();
175 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
176 return Err(StaticProgramError::InvalidInstruction(pc));
177 }
178 *data = ShiftPreCompute {
179 a: a.as_canonical_u32() as u8,
180 b: b.as_canonical_u32() as u8,
181 c: c.as_canonical_u32() as u8,
182 };
183 let local_opcode =
184 ShiftOpcode::from_usize(opcode.local_opcode_idx(Rv32Shift256Opcode::CLASS_OFFSET));
185 Ok(local_opcode)
186 }
187}
188
/// A 256-bit shift operation selected at compile time.
///
/// `compute` takes the value to shift (`rs1`) and the shift amount (`rs2`),
/// both as `INT256_NUM_LIMBS` bytes, and returns the shifted value in the same
/// form. The implementations in this file decode the bytes with
/// `bytes_to_u64_array` and use only the low 8 bits of the first decoded word
/// of `rs2` as the shift amount.
trait ShiftOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
}
/// Marker type for the left shift (vacated low bits are zero).
struct SllOp;
/// Marker type for the right shift that fills vacated high bits with zeros.
struct SrlOp;
/// Marker type for the right shift that fills vacated high bits with the sign
/// bit of the input.
struct SraOp;
195impl ShiftOp for SllOp {
196 #[inline(always)]
197 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
198 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
199 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
200 let mut rd = [0u64; 4];
201 let shift = (rs2_u64[0] & 0xff) as u32;
203 let index_offset = (shift / u64::BITS) as usize;
204 let bit_offset = shift % u64::BITS;
205 let mut carry = 0u64;
206 for i in index_offset..4 {
207 let curr = rs1_u64[i - index_offset];
208 rd[i] = (curr << bit_offset) + carry;
209 if bit_offset > 0 {
210 carry = curr >> (u64::BITS - bit_offset);
211 }
212 }
213 u64_array_to_bytes(rd)
214 }
215}
216impl ShiftOp for SrlOp {
217 #[inline(always)]
218 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
219 shift_right(rs1, rs2, 0)
221 }
222}
223impl ShiftOp for SraOp {
224 #[inline(always)]
225 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
226 if rs1[INT256_NUM_LIMBS - 1] & 0x80 > 0 {
228 shift_right(rs1, rs2, u64::MAX)
229 } else {
230 shift_right(rs1, rs2, 0)
231 }
232 }
233}
234
235#[inline(always)]
236fn shift_right(
237 rs1: [u8; INT256_NUM_LIMBS],
238 rs2: [u8; INT256_NUM_LIMBS],
239 init_value: u64,
240) -> [u8; INT256_NUM_LIMBS] {
241 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
242 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
243 let mut rd = [init_value; 4];
244 let shift = (rs2_u64[0] & 0xff) as u32;
245 let index_offset = (shift / u64::BITS) as usize;
246 let bit_offset = shift % u64::BITS;
247 let mut carry = if bit_offset > 0 {
248 init_value << (u64::BITS - bit_offset)
249 } else {
250 0
251 };
252 for i in (index_offset..4).rev() {
253 let curr = rs1_u64[i];
254 rd[i - index_offset] = (curr >> bit_offset) + carry;
255 if bit_offset > 0 {
256 carry = curr << (u64::BITS - bit_offset);
257 }
258 }
259 u64_array_to_bytes(rd)
260}
261
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{
        shift::{ShiftOp, SllOp, SraOp, SrlOp},
        INT256_NUM_LIMBS,
    };

    /// Compares all three shift operations against `U256` for 10000 random
    /// values and random single-byte shift amounts drawn from a fixed seed.
    #[test]
    fn test_shift_op() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            // Draw order (value first, then amount) fixes the generated cases.
            let value_bytes: [u8; INT256_NUM_LIMBS] = rng.gen();
            let amount: u8 = rng.gen();

            let mut amount_bytes = [0u8; INT256_NUM_LIMBS];
            amount_bytes[0] = amount;
            let value = U256::from_le_bytes(value_bytes);

            let left = U256::from_le_bytes(SllOp::compute(value_bytes, amount_bytes));
            assert_eq!(left, value << amount);

            let arithmetic = U256::from_le_bytes(SraOp::compute(value_bytes, amount_bytes));
            assert_eq!(arithmetic, value.arithmetic_shr(amount as usize));

            let logical = U256::from_le_bytes(SrlOp::compute(value_bytes, amount_bytes));
            assert_eq!(logical, value >> amount);
        }
    }
}