openvm_bigint_circuit/shift.rs

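//! Execution of the 256-bit shift instructions (`SLL`, `SRL`, `SRA`). The
//! executor wraps the rv32im `ShiftExecutor` behind a heap adapter; the
//! functions below implement the pure (`Executor`) and metered
//! (`MeteredExecutor`) interpreter paths on 32-byte little-endian operands.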
use std::borrow::{Borrow, BorrowMut};

use openvm_bigint_transpiler::Rv32Shift256Opcode;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
use openvm_rv32im_circuit::ShiftExecutor;
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use crate::{
    common::{bytes_to_u64_array, u64_array_to_bytes},
    Rv32Shift256Executor, INT256_NUM_LIMBS,
};

type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;

impl Rv32Shift256Executor {
    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
        Self(ShiftExecutor::new(adapter, offset))
    }
}

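/// Operand addresses decoded once per instruction: `a`, `b`, and `c` are the
/// byte addresses of the `rd`, `rs1`, and `rs2` RV32 registers, respectively
/// (see `pre_compute_impl`).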
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    a: u8,
    b: u8,
    c: u8,
}

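// Monomorphizes the given execute function over the concrete shift operation,
// so the returned function pointer contains no per-instruction branching.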
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            ShiftOpcode::SLL => $execute_impl::<_, _, SllOp>,
            ShiftOpcode::SRA => $execute_impl::<_, _, SraOp>,
            ShiftOpcode::SRL => $execute_impl::<_, _, SrlOp>,
        })
    };
}

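// E1 (pure) execution: `pre_compute` packs the register addresses into
// `ShiftPreCompute` and returns the execute function specialized for the
// decoded opcode.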
impl<F: PrimeField32> Executor<F> for Rv32Shift256Executor {
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_tco_handler, local_opcode)
    }
}

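// E2 (metered) execution: same decoding as above, but the pre-compute data is
// prefixed with the chip index so each executed instruction can report a
// height increment of 1 to the metering context.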
impl<F: PrimeField32> MeteredExecutor<F> for Rv32Shift256Executor {
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_tco_handler, local_opcode)
    }
}

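// Shared body of the E1/E2 interpreters: dereference the three register
// pointers, load both 256-bit operands from guest memory, compute the shift,
// write the result back, and advance the program counter.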
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &ShiftPreCompute,
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
    let rd_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let rd = OP::compute(rs1, rs2);
    vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
    vm_state.instret += 1;
}

#[create_tco_handler]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
}
#[create_tco_handler]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> = pre_compute.borrow();
    vm_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
}

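// Decodes an instruction into `ShiftPreCompute`, rejecting any instruction
// whose register/memory address spaces differ from the RV32 defaults.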
impl Rv32Shift256Executor {
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut ShiftPreCompute,
    ) -> Result<ShiftOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = ShiftPreCompute {
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
            c: c.as_canonical_u32() as u8,
        };
        let local_opcode =
            ShiftOpcode::from_usize(opcode.local_opcode_idx(Rv32Shift256Opcode::CLASS_OFFSET));
        Ok(local_opcode)
    }
}

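// The three shift flavors share one limb strategy: treat the 32-byte operand
// as four little-endian u64 limbs, split the shift amount into a whole-limb
// offset and a residual bit offset, then move limbs while carrying the bits
// that cross limb boundaries.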
trait ShiftOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
}
struct SllOp;
struct SrlOp;
struct SraOp;
impl ShiftOp for SllOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
        let mut rd = [0u64; 4];
        // Only the low 8 bits of the shift amount are used: a 256-bit value
        // shifts by at most 255.
        let shift = (rs2_u64[0] & 0xff) as u32;
        let index_offset = (shift / u64::BITS) as usize;
        let bit_offset = shift % u64::BITS;
        let mut carry = 0u64;
        for i in index_offset..4 {
            let curr = rs1_u64[i - index_offset];
            rd[i] = (curr << bit_offset) + carry;
            if bit_offset > 0 {
                carry = curr >> (u64::BITS - bit_offset);
            }
        }
        u64_array_to_bytes(rd)
    }
}
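// Worked example for the left shift above: for shift = 100, index_offset =
// 100 / 64 = 1 and bit_offset = 100 % 64 = 36, so rd[0] stays zero, rd[1] =
// rs1[0] << 36, and each later limb is (rs1[i - 1] << 36) + (rs1[i - 2] >> 28).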
impl ShiftOp for SrlOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
        // Logical right shift - fill with 0
        shift_right(rs1, rs2, 0)
    }
}
impl ShiftOp for SraOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
        // Arithmetic right shift - fill with sign bit
        if rs1[INT256_NUM_LIMBS - 1] & 0x80 > 0 {
            shift_right(rs1, rs2, u64::MAX)
        } else {
            shift_right(rs1, rs2, 0)
        }
    }
}

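// Common implementation of SRL and SRA: `init_value` is the fill pattern for
// the vacated high limbs and bits (0 for logical shifts, u64::MAX when
// sign-extending a negative value), propagated from the top limb downward.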
#[inline(always)]
fn shift_right(
    rs1: [u8; INT256_NUM_LIMBS],
    rs2: [u8; INT256_NUM_LIMBS],
    init_value: u64,
) -> [u8; INT256_NUM_LIMBS] {
    let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
    let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
    let mut rd = [init_value; 4];
    let shift = (rs2_u64[0] & 0xff) as u32;
    let index_offset = (shift / u64::BITS) as usize;
    let bit_offset = shift % u64::BITS;
    let mut carry = if bit_offset > 0 {
        init_value << (u64::BITS - bit_offset)
    } else {
        0
    };
    for i in (index_offset..4).rev() {
        let curr = rs1_u64[i];
        rd[i - index_offset] = (curr >> bit_offset) + carry;
        if bit_offset > 0 {
            carry = curr << (u64::BITS - bit_offset);
        }
    }
    u64_array_to_bytes(rd)
}

#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{
        shift::{ShiftOp, SllOp, SraOp, SrlOp},
        INT256_NUM_LIMBS,
    };

    #[test]
    fn test_shift_op() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            let limbs_a: [u8; INT256_NUM_LIMBS] = rng.gen();
            let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS];
            let shift: u8 = rng.gen();
            limbs_b[0] = shift;
            let a = U256::from_le_bytes(limbs_a);
            {
                let res = SllOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a << shift);
            }
            {
                let res = SraOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a.arithmetic_shr(shift as usize));
            }
            {
                let res = SrlOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a >> shift);
            }
        }
    }
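
    // A minimal supplementary sketch over deterministic edge cases, assuming
    // the same alloy_primitives U256 semantics as test_shift_op above: shift
    // amounts 0 and 64 exercise the bit_offset == 0 branch, 63 and 255 the
    // maximal bit offset, and the 0xab fill sets the top bit so SRA takes the
    // sign-extension path.
    #[test]
    fn test_shift_op_edge_cases() {
        let limbs_a: [u8; INT256_NUM_LIMBS] = [0xab; INT256_NUM_LIMBS];
        let a = U256::from_le_bytes(limbs_a);
        for shift in [0u8, 1, 63, 64, 255] {
            let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS];
            limbs_b[0] = shift;
            assert_eq!(U256::from_le_bytes(SllOp::compute(limbs_a, limbs_b)), a << shift);
            assert_eq!(U256::from_le_bytes(SrlOp::compute(limbs_a, limbs_b)), a >> shift);
            assert_eq!(
                U256::from_le_bytes(SraOp::compute(limbs_a, limbs_b)),
                a.arithmetic_shr(shift as usize)
            );
        }
    }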
}