openvm_rv32im_circuit/shift/execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::ShiftExecutor;
use crate::adapters::imm_to_bytes;

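/// Pre-decoded operands for one shift instruction, reinterpretable from a raw
/// byte buffer via `AlignedBytesBorrow`. `c` holds the decoded immediate when
/// operand `e` is the immediate address space, otherwise the rs2 address; `a`
/// and `b` are the rd and rs1 addresses in the register address space.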
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    c: u32,
    a: u8,
    b: u8,
}

impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
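    /// Decodes `inst` into `data`, validating the address spaces: `d` must be
    /// the register address space, and `e` must be either the immediate or the
    /// register address space. Returns whether operand `c` is an immediate,
    /// along with the concrete [`ShiftOpcode`].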
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut ShiftPreCompute,
    ) -> Result<(bool, ShiftOpcode), StaticProgramError> {
        let Instruction {
            opcode, a, b, c, e, ..
        } = inst;
        let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let e_u32 = e.as_canonical_u32();
        if inst.d.as_canonical_u32() != RV32_REGISTER_AS
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();
        *data = ShiftPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        // `d` is always expected to be RV32_REGISTER_AS.
        Ok((is_imm, shift_opcode))
    }
}

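// Monomorphizes the chosen execute function over the (is_imm, opcode) pair, so
// the interpreter dispatches through a single function pointer with no
// per-instruction branching on opcode or operand kind.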
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $shift_opcode:ident) => {
        match ($is_imm, $shift_opcode) {
            (true, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, true, SllOp>),
            (false, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, false, SllOp>),
            (true, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, true, SrlOp>),
            (false, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, false, SrlOp>),
            (true, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, true, SraOp>),
            (false, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, false, SraOp>),
        }
    };
}

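// Plain (E1) execution: decode once into the pre-compute buffer, then hand the
// interpreter a monomorphized function pointer (or, with the `tco` feature
// enabled, a tail-call handler).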
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> Executor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        // `d` is always expected to be RV32_REGISTER_AS.
        dispatch!(execute_e1_impl, is_imm, shift_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        // `d` is always expected to be RV32_REGISTER_AS.
        dispatch!(execute_e1_tco_handler, is_imm, shift_opcode)
    }
}

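// Metered (E2) execution: identical to E1, but the pre-compute record is
// wrapped in `E2PreCompute` so each executed shift can report a trace-height
// increment of 1 for this chip.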
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MeteredExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
        // `d` is always expected to be RV32_REGISTER_AS.
        dispatch!(execute_e2_impl, is_imm, shift_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError> {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
        // `d` is always expected to be RV32_REGISTER_AS.
        dispatch!(execute_e2_tco_handler, is_imm, shift_opcode)
    }
}

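// Shared body of the E1 and E2 interpreters: read rs1 (and rs2, unless it is
// an immediate) from the register address space, apply the shift, write rd
// back, and advance `instret` and the PC.
//
// Safety: `pre_compute` must have been initialized by `pre_compute_impl` so
// that the register addresses it contains are valid.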
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &ShiftPreCompute,
    state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let rs2 = u32::from_le_bytes(rs2);

    // Execute the shift operation
    let rd = <OP as ShiftOp>::compute(rs1, rs2);
    // Write the result back to memory
    state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    state.instret += 1;
    state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
}

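// Thin entry points for the dispatch table: each reinterprets the opaque
// pre-compute bytes as its typed record (the E2 variant also bumps this
// chip's height) and defers to `execute_e12_impl`.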
#[create_tco_handler]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &[u8],
    state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, state);
}

#[create_tco_handler]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &[u8],
    state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> = pre_compute.borrow();
    state.ctx.on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, state);
}

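// Compile-time selection of the shift semantics. Each operation takes rs1 as
// little-endian bytes and a raw rs2 word, masking rs2 to its low 5 bits as
// RV32 SLL/SRL/SRA require.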
trait ShiftOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4];
}
struct SllOp;
struct SrlOp;
struct SraOp;
impl ShiftOp for SllOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        // Only the low 5 bits of `rs2` select the shift amount; the rest are ignored.
        (rs1 << (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SrlOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        // Only the low 5 bits of `rs2` select the shift amount; the rest are ignored.
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SraOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        // Interpret `rs1` as signed so the right shift sign-extends.
        let rs1 = i32::from_le_bytes(rs1);
        // Only the low 5 bits of `rs2` select the shift amount; the rest are ignored.
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}
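
// A minimal sanity-check sketch for the `ShiftOp` impls (not part of the
// original file; the module and test names are illustrative only):
#[cfg(test)]
mod shift_op_tests {
    use super::*;

    #[test]
    fn sll_masks_shift_amount_to_low_five_bits() {
        // 33 & 0x1F == 1, so shifting by 33 behaves like shifting by 1.
        assert_eq!(SllOp::compute(1u32.to_le_bytes(), 33), 2u32.to_le_bytes());
    }

    #[test]
    fn srl_is_logical() {
        // Logical right shift fills with zeros even when the sign bit is set.
        assert_eq!(
            SrlOp::compute(0x8000_0000u32.to_le_bytes(), 31),
            1u32.to_le_bytes()
        );
    }

    #[test]
    fn sra_sign_extends() {
        // Arithmetic right shift preserves the sign: -8 >> 1 == -4.
        assert_eq!(
            SraOp::compute((-8i32).to_le_bytes(), 1),
            (-4i32).to_le_bytes()
        );
    }
}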