openvm_bigint_circuit/
shift.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_bigint_transpiler::Rv32Shift256Opcode;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
use openvm_rv32im_circuit::ShiftExecutor;
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use crate::{
    common::{bytes_to_u64_array, u64_array_to_bytes},
    Rv32Shift256Executor, INT256_NUM_LIMBS,
};

type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;

impl Rv32Shift256Executor {
    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
        Self(ShiftExecutor::new(adapter, offset))
    }
}

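/// Operands pre-decoded from the instruction: `a`, `b`, and `c` are the rd,
/// rs1, and rs2 register indices, respectively.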
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    a: u8,
    b: u8,
    c: u8,
}

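// Maps the decoded shift opcode to an execute function monomorphized for that
// variant, so the returned function pointer is specialized per opcode.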
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            ShiftOpcode::SLL => $execute_impl::<_, _, SllOp>,
            ShiftOpcode::SRA => $execute_impl::<_, _, SraOp>,
            ShiftOpcode::SRL => $execute_impl::<_, _, SrlOp>,
        })
    };
}

impl<F: PrimeField32> InterpreterExecutor<F> for Rv32Shift256Executor {
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotExecutor<F> for Rv32Shift256Executor {}

impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32Shift256Executor {
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotMeteredExecutor<F> for Rv32Shift256Executor {}

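/// Shared body of the e1 (plain) and e2 (metered) handlers: the three register
/// operands hold little-endian heap pointers, both 256-bit inputs are read from
/// memory, the shift is computed natively, and the result is written back
/// before the pc is advanced by one instruction.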
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: &ShiftPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
    let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let rd = OP::compute(rs1, rs2);
    exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<ShiftPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: ShiftOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<ShiftPreCompute>>())
            .borrow();
    // Account one unit of height to this chip for metering before running the
    // shared body.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}

impl Rv32Shift256Executor {
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut ShiftPreCompute,
    ) -> Result<ShiftOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let e_u32 = e.as_canonical_u32();
        // Register pointers must live in the register address space and the
        // 256-bit operands in the memory address space.
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = ShiftPreCompute {
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
            c: c.as_canonical_u32() as u8,
        };
        let local_opcode =
            ShiftOpcode::from_usize(opcode.local_opcode_idx(Rv32Shift256Opcode::CLASS_OFFSET));
        Ok(local_opcode)
    }
}

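/// 256-bit shift over little-endian byte limbs. Only the low byte of `rs2` is
/// used as the shift amount.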
trait ShiftOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
}
struct SllOp;
struct SrlOp;
struct SraOp;
impl ShiftOp for SllOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
        let mut rd = [0u64; 4];
        // Only the low 8 bits of the shift amount are used.
        let shift = (rs2_u64[0] & 0xff) as u32;
        let index_offset = (shift / u64::BITS) as usize;
        let bit_offset = shift % u64::BITS;
        // Shift limb by limb from least to most significant, carrying the bits
        // that spill into the next higher limb.
        let mut carry = 0u64;
        for i in index_offset..4 {
            let curr = rs1_u64[i - index_offset];
            rd[i] = (curr << bit_offset) + carry;
            if bit_offset > 0 {
                carry = curr >> (u64::BITS - bit_offset);
            }
        }
        u64_array_to_bytes(rd)
    }
}
impl ShiftOp for SrlOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
        // Logical right shift - fill with 0
        shift_right(rs1, rs2, 0)
    }
}
impl ShiftOp for SraOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
        // Arithmetic right shift - fill with sign bit
        if rs1[INT256_NUM_LIMBS - 1] & 0x80 > 0 {
            shift_right(rs1, rs2, u64::MAX)
        } else {
            shift_right(rs1, rs2, 0)
        }
    }
}

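/// Right shift shared by SRL and SRA. `init_value` supplies the fill pattern
/// for the vacated high bits: 0 for a logical shift, `u64::MAX` for an
/// arithmetic shift of a negative value.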
#[inline(always)]
fn shift_right(
    rs1: [u8; INT256_NUM_LIMBS],
    rs2: [u8; INT256_NUM_LIMBS],
    init_value: u64,
) -> [u8; INT256_NUM_LIMBS] {
    let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
    let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
    let mut rd = [init_value; 4];
    // Only the low 8 bits of the shift amount are used.
    let shift = (rs2_u64[0] & 0xff) as u32;
    let index_offset = (shift / u64::BITS) as usize;
    let bit_offset = shift % u64::BITS;
    // Seed the carry with the fill bits that enter the most significant limb.
    let mut carry = if bit_offset > 0 {
        init_value << (u64::BITS - bit_offset)
    } else {
        0
    };
    // Shift limb by limb from most to least significant, carrying the bits
    // that spill into the next lower limb.
    for i in (index_offset..4).rev() {
        let curr = rs1_u64[i];
        rd[i - index_offset] = (curr >> bit_offset) + carry;
        if bit_offset > 0 {
            carry = curr << (u64::BITS - bit_offset);
        }
    }
    u64_array_to_bytes(rd)
}

#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{
        shift::{ShiftOp, SllOp, SraOp, SrlOp},
        INT256_NUM_LIMBS,
    };

    #[test]
    fn test_shift_op() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            let limbs_a: [u8; INT256_NUM_LIMBS] = rng.gen();
            let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS];
            let shift: u8 = rng.gen();
            limbs_b[0] = shift;
            let a = U256::from_le_bytes(limbs_a);
            {
                let res = SllOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a << shift);
            }
            {
                let res = SraOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a.arithmetic_shr(shift as usize));
            }
            {
                let res = SrlOp::compute(limbs_a, limbs_b);
                assert_eq!(U256::from_le_bytes(res), a >> shift);
            }
        }
    }
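
    // A minimal additional sketch, not part of the original suite: pins down the
    // `bit_offset == 0` branches (shift amounts of 0 and multiples of 64) plus
    // the maximum shift after the 8-bit mask, 255.
    #[test]
    fn test_shift_op_limb_boundaries() {
        let mut rng = StdRng::from_seed([7; 32]);
        for shift in [0u8, 64, 128, 192, 255] {
            let limbs_a: [u8; INT256_NUM_LIMBS] = rng.gen();
            let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS];
            limbs_b[0] = shift;
            let a = U256::from_le_bytes(limbs_a);
            assert_eq!(U256::from_le_bytes(SllOp::compute(limbs_a, limbs_b)), a << shift);
            assert_eq!(U256::from_le_bytes(SrlOp::compute(limbs_a, limbs_b)), a >> shift);
            assert_eq!(
                U256::from_le_bytes(SraOp::compute(limbs_a, limbs_b)),
                a.arithmetic_shr(shift as usize)
            );
        }
    }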
}