openvm_rv32im_circuit/shift/
execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::ShiftOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::ShiftExecutor;
#[allow(unused_imports)]
use crate::{adapters::imm_to_bytes, common::*};

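/// Operands pre-decoded from a shift instruction into the interpreter's
/// pre-compute buffer: `a` and `b` are the rd/rs1 register pointers, and `c`
/// holds either the decoded immediate or the rs2 register pointer, depending
/// on the address space of operand `e`.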
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct ShiftPreCompute {
    c: u32,
    a: u8,
    b: u8,
}

impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> ShiftExecutor<A, NUM_LIMBS, LIMB_BITS> {
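    /// Validates the instruction's address spaces and fills `data`; returns
    /// whether the shift amount comes from an immediate and which shift opcode
    /// to execute.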
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut ShiftPreCompute,
    ) -> Result<(bool, ShiftOpcode), StaticProgramError> {
        let Instruction {
            opcode, a, b, c, e, ..
        } = inst;
        let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let e_u32 = e.as_canonical_u32();
        if inst.d.as_canonical_u32() != RV32_REGISTER_AS
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();
        *data = ShiftPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        // `d` is always expected to be RV32_REGISTER_AS.
        Ok((is_imm, shift_opcode))
    }
}

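// Selects the executor function monomorphized for the (immediate vs. register)
// operand kind and the concrete shift opcode.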
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $shift_opcode:ident) => {
        match ($is_imm, $shift_opcode) {
            (true, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, true, SllOp>),
            (false, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, false, SllOp>),
            (true, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, true, SrlOp>),
            (false, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, false, SrlOp>),
            (true, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, true, SraOp>),
            (false, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, false, SraOp>),
        }
    };
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> InterpreterExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn pre_compute_size(&self) -> usize {
        size_of::<ShiftPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        // `d` is always expected to be RV32_REGISTER_AS.
        dispatch!(execute_e1_impl, is_imm, shift_opcode)
    }

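    // Under the `tco` (tail-call) feature, return the handler generated by
    // `#[create_handler]` instead of a plain execute function.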
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut ShiftPreCompute = data.borrow_mut();
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?;
        // `d` is always expected to be RV32_REGISTER_AS.
        dispatch!(execute_e1_handler, is_imm, shift_opcode)
    }
}

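// Ahead-of-time (AOT) x86-64 code generation for shift instructions.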
#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> AotExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, _instruction: &Instruction<F>) -> bool {
        true
    }
    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
        // Interpret the low 24 bits of the field element as a signed value and
        // truncate it to i16.
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let mut asm_str = String::new();
        let a: i16 = to_i16(inst.a);
        let b: i16 = to_i16(inst.b);
        let c: i16 = to_i16(inst.c);
        let e: i16 = to_i16(inst.e);
        assert!(a % 4 == 0, "instruction.a must be a multiple of 4");
        assert!(b % 4 == 0, "instruction.b must be a multiple of 4");

        // note: for shifts we use REG_B, since the hardware requires the shift
        // count in `cl` and we don't want to override the already-written [b:4]_1.
        // [a:4]_1 <- [b:4]_1

        // Use the overriding x86 register for rd when one is mapped; otherwise
        // fall back to REG_A_W.
        let str_reg_a = if RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].is_some() {
            RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap()
        } else {
            REG_A_W
        };

        if e == 0 {
            // e == RV32_IMM_AS: the shift amount is an immediate.
            // [a:4]_1 <- [b:4]_1 (shift) c
            let mut asm_opcode = String::new();
            if inst.opcode == ShiftOpcode::SLL.global_opcode() {
                asm_opcode += "shl";
            } else if inst.opcode == ShiftOpcode::SRL.global_opcode() {
                asm_opcode += "shr";
            } else if inst.opcode == ShiftOpcode::SRA.global_opcode() {
                asm_opcode += "sar";
            }

            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, str_reg_a, true);
            asm_str += delta_str_b;
            asm_str += &format!("   {asm_opcode} {reg_b}, {c}\n");
            asm_str += &gpr_to_xmm(reg_b, (a / 4) as u8);
        } else {
            // [a:4]_1 <- [b:4]_1 (shift) [c:4]_1
            let mut asm_opcode = String::new();
            if inst.opcode == ShiftOpcode::SLL.global_opcode() {
                asm_opcode += "shlx";
            } else if inst.opcode == ShiftOpcode::SRL.global_opcode() {
                asm_opcode += "shrx";
            } else if inst.opcode == ShiftOpcode::SRA.global_opcode() {
                asm_opcode += "sarx";
            }

            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, REG_B_W, false);
            // after this force write, we set [a:4]_1 <- [b:4]_1
            asm_str += delta_str_b;

            let (reg_c, delta_str_c) = &xmm_to_gpr((c / 4) as u8, REG_C_W, false);
            asm_str += delta_str_c;

            asm_str += &format!("   {asm_opcode} {str_reg_a}, {reg_b}, {reg_c}\n");

            asm_str += &gpr_to_xmm(str_reg_a, (a / 4) as u8);
        }

        // fall through to the next instruction
        Ok(asm_str)
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<ShiftPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx: MeteredExecutionCtxTrait>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
        // `d` is always expected to be RV32_REGISTER_AS.
        dispatch!(execute_e2_impl, is_imm, shift_opcode)
    }

215
216    #[cfg(feature = "tco")]
217    fn metered_handler<Ctx: MeteredExecutionCtxTrait>(
218        &self,
219        chip_idx: usize,
220        pc: u32,
221        inst: &Instruction<F>,
222        data: &mut [u8],
223    ) -> Result<Handler<F, Ctx>, StaticProgramError> {
224        let data: &mut E2PreCompute<ShiftPreCompute> = data.borrow_mut();
225        data.chip_idx = chip_idx as u32;
226        let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?;
227        // `d` is always expected to be RV32_REGISTER_AS.
228        dispatch!(execute_e2_handler, is_imm, shift_opcode)
229    }
230}
231
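// Metered AOT codegen: emit the same shift assembly, then account for the
// trace-height contributions of this chip and of the register accesses it
// performs.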
#[cfg(feature = "aot")]
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for ShiftExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let is_imm = inst.e.as_canonical_u32() == RV32_IMM_AS;
        let mut asm_str = self.generate_x86_asm(inst, pc)?;
        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // write [a:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        // read [b:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        if !is_imm {
            // read [c:4]_1
            asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }
        Ok(asm_str)
    }
}
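
// Shared execution core for the plain (E1) and metered (E2) interpreter paths:
// read rs1, take rs2 from the immediate or a register, apply the shift, write
// rd, and advance the PC by the default step.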
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: &ShiftPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let rs2 = u32::from_le_bytes(rs2);

    // Execute the shift operation
    let rd = <OP as ShiftOp>::compute(rs1, rs2);
    // Write the result back to memory
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

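// E1 handler: reinterpret the raw pre-compute bytes as `ShiftPreCompute` and
// run the shared core.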
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &ShiftPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<ShiftPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, exec_state);
}

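// E2 (metered) handler: record one row of height for the owning chip, then run
// the shared core.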
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const IS_IMM: bool,
    OP: ShiftOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<ShiftPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<ShiftPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, exec_state);
}

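// Per-opcode shift semantics over little-endian 4-byte register values.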
trait ShiftOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4];
}
struct SllOp;
struct SrlOp;
struct SraOp;
impl ShiftOp for SllOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        // Only the low 5 bits of `rs2` are used; the other bits are ignored.
        (rs1 << (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SrlOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = u32::from_le_bytes(rs1);
        // Only the low 5 bits of `rs2` are used; the other bits are ignored.
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}
impl ShiftOp for SraOp {
    fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] {
        let rs1 = i32::from_le_bytes(rs1);
        // Only the low 5 bits of `rs2` are used; the other bits are ignored.
        (rs1 >> (rs2 & 0x1F)).to_le_bytes()
    }
}
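
// A minimal unit-test sketch (not part of the original file) exercising the
// `ShiftOp` implementations above; the module and test names are illustrative.
// Expected values follow RV32 SLL/SRL/SRA semantics with the shift amount taken
// from the low 5 bits of rs2.
#[cfg(test)]
mod shift_op_tests {
    use super::*;

    #[test]
    fn shift_ops_match_rv32_semantics() {
        let x = 0x8000_00f0u32.to_le_bytes();
        // SLL: logical left shift.
        assert_eq!(SllOp::compute(x, 4), (0x8000_00f0u32 << 4).to_le_bytes());
        // SRL: logical right shift (zero-filled).
        assert_eq!(SrlOp::compute(x, 4), (0x8000_00f0u32 >> 4).to_le_bytes());
        // SRA: arithmetic right shift (sign-filled).
        assert_eq!(
            SraOp::compute(x, 4),
            ((0x8000_00f0u32 as i32) >> 4).to_le_bytes()
        );
        // Only the low 5 bits of rs2 matter: shifting by 33 behaves like shifting by 1.
        assert_eq!(SllOp::compute(x, 33), SllOp::compute(x, 1));
    }
}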