openvm_rv32im_circuit/base_alu/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12    LocalOpcode,
13};
14use openvm_rv32im_transpiler::BaseAluOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{adapters::imm_to_bytes, BaseAluExecutor};
18
19#[derive(AlignedBytesBorrow, Clone)]
20#[repr(C)]
21pub(super) struct BaseAluPreCompute {
22    c: u32,
23    a: u8,
24    b: u8,
25}
26
27impl<A, const LIMB_BITS: usize> BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28    /// Return `is_imm`, true if `e` is RV32_IMM_AS.
29    #[inline(always)]
30    pub(super) fn pre_compute_impl<F: PrimeField32>(
31        &self,
32        pc: u32,
33        inst: &Instruction<F>,
34        data: &mut BaseAluPreCompute,
35    ) -> Result<bool, StaticProgramError> {
36        let Instruction { a, b, c, d, e, .. } = inst;
37        let e_u32 = e.as_canonical_u32();
38        if (d.as_canonical_u32() != RV32_REGISTER_AS)
39            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
40        {
41            return Err(StaticProgramError::InvalidInstruction(pc));
42        }
43        let is_imm = e_u32 == RV32_IMM_AS;
44        let c_u32 = c.as_canonical_u32();
45        *data = BaseAluPreCompute {
46            c: if is_imm {
47                u32::from_le_bytes(imm_to_bytes(c_u32))
48            } else {
49                c_u32
50            },
51            a: a.as_canonical_u32() as u8,
52            b: b.as_canonical_u32() as u8,
53        };
54        Ok(is_imm)
55    }
56}
57
/// Selects the monomorphised handler for a base-ALU instruction.
///
/// * `$execute_impl` - name of the generic handler function to instantiate.
/// * `$is_imm` - identifier of a `bool`: whether the second operand is an immediate.
/// * `$opcode` / `$offset` - the global opcode and the offset used to recover the local
///   `BaseAluOpcode`.
///
/// Expands to `Ok(handler)`. The match stays directly inside `Ok(..)` so that the two elided
/// generic arguments of the handler are inferred from the caller's return type.
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $opcode:expr, $offset:expr) => {{
        let local_opcode = BaseAluOpcode::from_usize($opcode.local_opcode_idx($offset));
        Ok(match (local_opcode, $is_imm) {
            (BaseAluOpcode::ADD, true) => $execute_impl::<_, _, true, AddOp>,
            (BaseAluOpcode::ADD, false) => $execute_impl::<_, _, false, AddOp>,

            (BaseAluOpcode::SUB, true) => $execute_impl::<_, _, true, SubOp>,
            (BaseAluOpcode::SUB, false) => $execute_impl::<_, _, false, SubOp>,

            (BaseAluOpcode::XOR, true) => $execute_impl::<_, _, true, XorOp>,
            (BaseAluOpcode::XOR, false) => $execute_impl::<_, _, false, XorOp>,

            (BaseAluOpcode::OR, true) => $execute_impl::<_, _, true, OrOp>,
            (BaseAluOpcode::OR, false) => $execute_impl::<_, _, false, OrOp>,

            (BaseAluOpcode::AND, true) => $execute_impl::<_, _, true, AndOp>,
            (BaseAluOpcode::AND, false) => $execute_impl::<_, _, false, AndOp>,
        })
    }};
}
79
80impl<F, A, const LIMB_BITS: usize> Executor<F>
81    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
82where
83    F: PrimeField32,
84{
85    #[inline(always)]
86    fn pre_compute_size(&self) -> usize {
87        size_of::<BaseAluPreCompute>()
88    }
89
90    #[cfg(not(feature = "tco"))]
91    fn pre_compute<Ctx>(
92        &self,
93        pc: u32,
94        inst: &Instruction<F>,
95        data: &mut [u8],
96    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
97    where
98        Ctx: ExecutionCtxTrait,
99    {
100        let data: &mut BaseAluPreCompute = data.borrow_mut();
101        let is_imm = self.pre_compute_impl(pc, inst, data)?;
102
103        dispatch!(execute_e1_handler, is_imm, inst.opcode, self.offset)
104    }
105
106    #[cfg(feature = "tco")]
107    fn handler<Ctx>(
108        &self,
109        pc: u32,
110        inst: &Instruction<F>,
111        data: &mut [u8],
112    ) -> Result<Handler<F, Ctx>, StaticProgramError>
113    where
114        Ctx: ExecutionCtxTrait,
115    {
116        let data: &mut BaseAluPreCompute = data.borrow_mut();
117        let is_imm = self.pre_compute_impl(pc, inst, data)?;
118
119        dispatch!(execute_e1_handler, is_imm, inst.opcode, self.offset)
120    }
121}
122
123impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
124    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
125where
126    F: PrimeField32,
127{
128    #[inline(always)]
129    fn metered_pre_compute_size(&self) -> usize {
130        size_of::<E2PreCompute<BaseAluPreCompute>>()
131    }
132
133    #[cfg(not(feature = "tco"))]
134    fn metered_pre_compute<Ctx>(
135        &self,
136        chip_idx: usize,
137        pc: u32,
138        inst: &Instruction<F>,
139        data: &mut [u8],
140    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
141    where
142        Ctx: MeteredExecutionCtxTrait,
143    {
144        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
145        data.chip_idx = chip_idx as u32;
146        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
147
148        dispatch!(execute_e2_handler, is_imm, inst.opcode, self.offset)
149    }
150
151    #[cfg(feature = "tco")]
152    fn metered_handler<Ctx>(
153        &self,
154        chip_idx: usize,
155        pc: u32,
156        inst: &Instruction<F>,
157        data: &mut [u8],
158    ) -> Result<Handler<F, Ctx>, StaticProgramError>
159    where
160        Ctx: MeteredExecutionCtxTrait,
161    {
162        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
163        data.chip_idx = chip_idx as u32;
164        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
165
166        dispatch!(execute_e2_handler, is_imm, inst.opcode, self.offset)
167    }
168}
169
170#[inline(always)]
171unsafe fn execute_e12_impl<
172    F: PrimeField32,
173    CTX: ExecutionCtxTrait,
174    const IS_IMM: bool,
175    OP: AluOp,
176>(
177    pre_compute: &BaseAluPreCompute,
178    instret: &mut u64,
179    pc: &mut u32,
180    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
181) {
182    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
183    let rs2 = if IS_IMM {
184        pre_compute.c.to_le_bytes()
185    } else {
186        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
187    };
188    let rs1 = u32::from_le_bytes(rs1);
189    let rs2 = u32::from_le_bytes(rs2);
190    let rd = <OP as AluOp>::compute(rs1, rs2);
191    let rd = rd.to_le_bytes();
192    exec_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
193    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
194    *instret += 1;
195}
196
197#[create_handler]
198#[inline(always)]
199unsafe fn execute_e1_impl<
200    F: PrimeField32,
201    CTX: ExecutionCtxTrait,
202    const IS_IMM: bool,
203    OP: AluOp,
204>(
205    pre_compute: &[u8],
206    instret: &mut u64,
207    pc: &mut u32,
208    _instret_end: u64,
209    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
210) {
211    let pre_compute: &BaseAluPreCompute = pre_compute.borrow();
212    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, instret, pc, exec_state);
213}
214
215#[create_handler]
216#[inline(always)]
217unsafe fn execute_e2_impl<
218    F: PrimeField32,
219    CTX: MeteredExecutionCtxTrait,
220    const IS_IMM: bool,
221    OP: AluOp,
222>(
223    pre_compute: &[u8],
224    instret: &mut u64,
225    pc: &mut u32,
226    _arg: u64,
227    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
228) {
229    let pre_compute: &E2PreCompute<BaseAluPreCompute> = pre_compute.borrow();
230    exec_state
231        .ctx
232        .on_height_change(pre_compute.chip_idx as usize, 1);
233    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, instret, pc, exec_state);
234}
235
/// A binary operation on two 32-bit words, selected at compile time through a marker type.
trait AluOp {
    /// Combines the two operands into the result word.
    fn compute(rs1: u32, rs2: u32) -> u32;
}

// Declares a unit marker type and its `AluOp` implementation from a closure-like body.
macro_rules! define_alu_op {
    ($(#[$attr:meta])* $name:ident => |$lhs:ident, $rhs:ident| $body:expr) => {
        $(#[$attr])*
        struct $name;

        impl AluOp for $name {
            #[inline(always)]
            fn compute($lhs: u32, $rhs: u32) -> u32 {
                $body
            }
        }
    };
}

define_alu_op! {
    /// Addition modulo 2^32.
    AddOp => |rs1, rs2| rs1.wrapping_add(rs2)
}

define_alu_op! {
    /// Subtraction modulo 2^32.
    SubOp => |rs1, rs2| rs1.wrapping_sub(rs2)
}

define_alu_op! {
    /// Bitwise exclusive OR.
    XorOp => |rs1, rs2| rs1 ^ rs2
}

define_alu_op! {
    /// Bitwise OR.
    OrOp => |rs1, rs2| rs1 | rs2
}

define_alu_op! {
    /// Bitwise AND.
    AndOp => |rs1, rs2| rs1 & rs2
}