openvm_rv32im_circuit/shift/execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
12    LocalOpcode,
13};
14use openvm_rv32im_transpiler::ShiftOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use super::ShiftExecutor;
18use crate::adapters::imm_to_bytes;
19
20#[derive(AlignedBytesBorrow, Clone)]
21#[repr(C)]
22struct ShiftPreCompute {
23    c: u32,
24    a: u8,
25    b: u8,
26}
27
28impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
29    #[inline(always)]
30    fn pre_compute_impl<F: PrimeField32>(
31        &self,
32        pc: u32,
33        inst: &Instruction<F>,
34        data: &mut ShiftPreCompute,
35    ) -> Result<(bool, ShiftOpcode), StaticProgramError> {
36        let Instruction {
37            opcode, a, b, c, e, ..
38        } = inst;
39        let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));
40        let e_u32 = e.as_canonical_u32();
41        if inst.d.as_canonical_u32() != RV32_REGISTER_AS
42            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
43        {
44            return Err(StaticProgramError::InvalidInstruction(pc));
45        }
46        let is_imm = e_u32 == RV32_IMM_AS;
47        let c_u32 = c.as_canonical_u32();
48        *data = ShiftPreCompute {
49            c: if is_imm {
50                u32::from_le_bytes(imm_to_bytes(c_u32))
51            } else {
52                c_u32
53            },
54            a: a.as_canonical_u32() as u8,
55            b: b.as_canonical_u32() as u8,
56        };
57        // `d` is always expected to be RV32_REGISTER_AS.
58        Ok((is_imm, shift_opcode))
59    }
60}
61
// Maps the runtime pair (immediate?, opcode) to the monomorphized handler
// `$execute_impl::<F, Ctx, IS_IMM, Op>`, wrapped in `Ok`. `F` and `Ctx` are inferred from
// the surrounding function's return type.
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $shift_opcode:ident) => {
        if $is_imm {
            match $shift_opcode {
                ShiftOpcode::SLL => Ok($execute_impl::<_, _, true, SllOp>),
                ShiftOpcode::SRL => Ok($execute_impl::<_, _, true, SrlOp>),
                ShiftOpcode::SRA => Ok($execute_impl::<_, _, true, SraOp>),
            }
        } else {
            match $shift_opcode {
                ShiftOpcode::SLL => Ok($execute_impl::<_, _, false, SllOp>),
                ShiftOpcode::SRL => Ok($execute_impl::<_, _, false, SrlOp>),
                ShiftOpcode::SRA => Ok($execute_impl::<_, _, false, SraOp>),
            }
        }
    };
}
74
75impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> Executor<F>
76    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
77where
78    F: PrimeField32,
79{
80    fn pre_compute_size(&self) -> usize {
81        size_of::<ShiftPreCompute>()
82    }
83
84    #[cfg(not(feature = "tco"))]
85    fn pre_compute<Ctx: ExecutionCtxTrait>(
86        &self,
87        pc: u32,
88        inst: &Instruction<F>,
89        data: &mut [u8],
90    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
91        let data: &mut ShiftPreCompute = data.borrow_mut();
92        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
93        // `d` is always expected to be RV32_REGISTER_AS.
94        dispatch!(execute_e1_handler, is_imm, shift_opcode)
95    }
96
97    #[cfg(feature = "tco")]
98    fn handler<Ctx>(
99        &self,
100        pc: u32,
101        inst: &Instruction<F>,
102        data: &mut [u8],
103    ) -> Result<Handler<F, Ctx>, StaticProgramError>
104    where
105        Ctx: ExecutionCtxTrait,
106    {
107        let data: &mut ShiftPreCompute = data.borrow_mut();
108        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
109        // `d` is always expected to be RV32_REGISTER_AS.
110        dispatch!(execute_e1_handler, is_imm, shift_opcode)
111    }
112}
113
114impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MeteredExecutor<F>
115    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
116where
117    F: PrimeField32,
118{
119    fn metered_pre_compute_size(&self) -> usize {
120        size_of::<E2PreCompute<ShiftPreCompute>>()
121    }
122
123    #[cfg(not(feature = "tco"))]
124    fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
125        &self,
126        chip_idx: usize,
127        pc: u32,
128        inst: &Instruction<F>,
129        data: &mut [u8],
130    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
131        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
132        data.chip_idx = chip_idx as u32;
133        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
134        // `d` is always expected to be RV32_REGISTER_AS.
135        dispatch!(execute_e2_handler, is_imm, shift_opcode)
136    }
137
138    #[cfg(feature = "tco")]
139    fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
140        &self,
141        chip_idx: usize,
142        pc: u32,
143        inst: &Instruction<F>,
144        data: &mut [u8],
145    ) -> Result<Handler<F, Ctx>, StaticProgramError> {
146        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
147        data.chip_idx = chip_idx as u32;
148        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
149        // `d` is always expected to be RV32_REGISTER_AS.
150        dispatch!(execute_e2_handler, is_imm, shift_opcode)
151    }
152}
153
154#[inline(always)]
155unsafe fn execute_e12_impl<
156    F: PrimeField32,
157    CTX: ExecutionCtxTrait,
158    const IS_IMM: bool,
159    OP: ShiftOp,
160>(
161    pre_compute: &ShiftPreCompute,
162    instret: &mut u64,
163    pc: &mut u32,
164    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
165) {
166    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
167    let rs2 = if IS_IMM {
168        pre_compute.c.to_le_bytes()
169    } else {
170        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
171    };
172    let rs2 = u32::from_le_bytes(rs2);
173
174    // Execute the shift operation
175    let rd = <OP as ShiftOp>::compute(rs1, rs2);
176    // Write the result back to memory
177    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
178
179    *instret += 1;
180    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
181}
182
183#[create_handler]
184#[inline(always)]
185unsafe fn execute_e1_impl<
186    F: PrimeField32,
187    CTX: ExecutionCtxTrait,
188    const IS_IMM: bool,
189    OP: ShiftOp,
190>(
191    pre_compute: &[u8],
192    instret: &mut u64,
193    pc: &mut u32,
194    _instret_end: u64,
195    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
196) {
197    let pre_compute: &ShiftPreCompute = pre_compute.borrow();
198    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, instret, pc, exec_state);
199}
200
201#[create_handler]
202#[inline(always)]
203unsafe fn execute_e2_impl<
204    F: PrimeField32,
205    CTX: MeteredExecutionCtxTrait,
206    const IS_IMM: bool,
207    OP: ShiftOp,
208>(
209    pre_compute: &[u8],
210    instret: &mut u64,
211    pc: &mut u32,
212    _arg: u64,
213    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
214) {
215    let pre_compute: &E2PreCompute<ShiftPreCompute> = pre_compute.borrow();
216    exec_state
217        .ctx
218        .on_height_change(pre_compute.chip_idx as usize, 1);
219    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, instret, pc, exec_state);
220}
221
/// A 32-bit shift applied to the little-endian bytes of `rs1`.
///
/// Only the low five bits of `rs2` choose the shift amount; higher bits are ignored, so
/// no amount can overshoot the 32-bit width. The same implementation serves both the
/// register and immediate forms of each opcode.
trait ShiftOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4];
}

/// Logical left shift (`ShiftOpcode::SLL`).
struct SllOp;
/// Logical (zero-filling) right shift (`ShiftOpcode::SRL`).
struct SrlOp;
/// Arithmetic (sign-filling) right shift (`ShiftOpcode::SRA`).
struct SraOp;

impl ShiftOp for SllOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        // `wrapping_shl` shifts by `rs2 & 31` for a 32-bit operand.
        u32::from_le_bytes(rs1).wrapping_shl(rs2).to_le_bytes()
    }
}

impl ShiftOp for SrlOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        // Unsigned operand, so vacated high bits are filled with zeros.
        u32::from_le_bytes(rs1).wrapping_shr(rs2).to_le_bytes()
    }
}

impl ShiftOp for SraOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        // Signed operand, so vacated high bits replicate the sign bit.
        i32::from_le_bytes(rs1).wrapping_shr(rs2).to_le_bytes()
    }
}
248}