// Path: openvm_rv32im_circuit/base_alu/execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12    LocalOpcode,
13};
14use openvm_rv32im_transpiler::BaseAluOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{adapters::imm_to_bytes, BaseAluExecutor};
18
/// Pre-decoded operands of a base ALU instruction, stored in the raw
/// pre-compute byte buffer.
///
/// `#[repr(C)]` fixes the field layout and `AlignedBytesBorrow` derives the
/// `Borrow`/`BorrowMut` impls used to reinterpret `&[u8]` buffers as this
/// struct, so the field order here is part of the serialized format — do not
/// reorder fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
pub(super) struct BaseAluPreCompute {
    // Second operand: either the immediate expanded to 4 LE bytes via
    // `imm_to_bytes` (when `e` is RV32_IMM_AS) or the rs2 register pointer.
    c: u32,
    // Destination register (rd) pointer, truncated to u8.
    a: u8,
    // First source register (rs1) pointer, truncated to u8.
    b: u8,
}
26
27impl<A, const LIMB_BITS: usize> BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28    /// Return `is_imm`, true if `e` is RV32_IMM_AS.
29    #[inline(always)]
30    pub(super) fn pre_compute_impl<F: PrimeField32>(
31        &self,
32        pc: u32,
33        inst: &Instruction<F>,
34        data: &mut BaseAluPreCompute,
35    ) -> Result<bool, StaticProgramError> {
36        let Instruction { a, b, c, d, e, .. } = inst;
37        let e_u32 = e.as_canonical_u32();
38        if (d.as_canonical_u32() != RV32_REGISTER_AS)
39            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
40        {
41            return Err(StaticProgramError::InvalidInstruction(pc));
42        }
43        let is_imm = e_u32 == RV32_IMM_AS;
44        let c_u32 = c.as_canonical_u32();
45        *data = BaseAluPreCompute {
46            c: if is_imm {
47                u32::from_le_bytes(imm_to_bytes(c_u32))
48            } else {
49                c_u32
50            },
51            a: a.as_canonical_u32() as u8,
52            b: b.as_canonical_u32() as u8,
53        };
54        Ok(is_imm)
55    }
56}
57
// Selects the monomorphized execute function for a (opcode, is_imm) pair.
// `$execute_impl` is instantiated with the const `IS_IMM` flag and the
// zero-sized `AluOp` type matching the decoded `BaseAluOpcode`.
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $opcode:expr, $offset:expr) => {
        Ok(
            match (
                BaseAluOpcode::from_usize($opcode.local_opcode_idx($offset)),
                $is_imm,
            ) {
                (BaseAluOpcode::ADD, true) => $execute_impl::<_, _, true, AddOp>,
                (BaseAluOpcode::ADD, false) => $execute_impl::<_, _, false, AddOp>,
                (BaseAluOpcode::SUB, true) => $execute_impl::<_, _, true, SubOp>,
                (BaseAluOpcode::SUB, false) => $execute_impl::<_, _, false, SubOp>,
                (BaseAluOpcode::XOR, true) => $execute_impl::<_, _, true, XorOp>,
                (BaseAluOpcode::XOR, false) => $execute_impl::<_, _, false, XorOp>,
                (BaseAluOpcode::OR, true) => $execute_impl::<_, _, true, OrOp>,
                (BaseAluOpcode::OR, false) => $execute_impl::<_, _, false, OrOp>,
                (BaseAluOpcode::AND, true) => $execute_impl::<_, _, true, AndOp>,
                (BaseAluOpcode::AND, false) => $execute_impl::<_, _, false, AndOp>,
            },
        )
    };
}
79
80impl<F, A, const LIMB_BITS: usize> Executor<F>
81    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
82where
83    F: PrimeField32,
84{
85    #[inline(always)]
86    fn pre_compute_size(&self) -> usize {
87        size_of::<BaseAluPreCompute>()
88    }
89
90    fn pre_compute<Ctx>(
91        &self,
92        pc: u32,
93        inst: &Instruction<F>,
94        data: &mut [u8],
95    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
96    where
97        Ctx: ExecutionCtxTrait,
98    {
99        let data: &mut BaseAluPreCompute = data.borrow_mut();
100        let is_imm = self.pre_compute_impl(pc, inst, data)?;
101
102        dispatch!(execute_e1_impl, is_imm, inst.opcode, self.offset)
103    }
104
105    #[cfg(feature = "tco")]
106    fn handler<Ctx>(
107        &self,
108        pc: u32,
109        inst: &Instruction<F>,
110        data: &mut [u8],
111    ) -> Result<Handler<F, Ctx>, StaticProgramError>
112    where
113        Ctx: ExecutionCtxTrait,
114    {
115        let data: &mut BaseAluPreCompute = data.borrow_mut();
116        let is_imm = self.pre_compute_impl(pc, inst, data)?;
117
118        dispatch!(execute_e1_tco_handler, is_imm, inst.opcode, self.offset)
119    }
120}
121
122impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
123    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
124where
125    F: PrimeField32,
126{
127    #[inline(always)]
128    fn metered_pre_compute_size(&self) -> usize {
129        size_of::<E2PreCompute<BaseAluPreCompute>>()
130    }
131
132    fn metered_pre_compute<Ctx>(
133        &self,
134        chip_idx: usize,
135        pc: u32,
136        inst: &Instruction<F>,
137        data: &mut [u8],
138    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
139    where
140        Ctx: MeteredExecutionCtxTrait,
141    {
142        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
143        data.chip_idx = chip_idx as u32;
144        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
145
146        dispatch!(execute_e2_impl, is_imm, inst.opcode, self.offset)
147    }
148
149    #[cfg(feature = "tco")]
150    fn metered_handler<Ctx>(
151        &self,
152        chip_idx: usize,
153        pc: u32,
154        inst: &Instruction<F>,
155        data: &mut [u8],
156    ) -> Result<Handler<F, Ctx>, StaticProgramError>
157    where
158        Ctx: MeteredExecutionCtxTrait,
159    {
160        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
161        data.chip_idx = chip_idx as u32;
162        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
163
164        dispatch!(execute_e2_tco_handler, is_imm, inst.opcode, self.offset)
165    }
166}
167
168#[inline(always)]
169unsafe fn execute_e12_impl<
170    F: PrimeField32,
171    CTX: ExecutionCtxTrait,
172    const IS_IMM: bool,
173    OP: AluOp,
174>(
175    pre_compute: &BaseAluPreCompute,
176    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
177) {
178    let rs1 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
179    let rs2 = if IS_IMM {
180        pre_compute.c.to_le_bytes()
181    } else {
182        vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
183    };
184    let rs1 = u32::from_le_bytes(rs1);
185    let rs2 = u32::from_le_bytes(rs2);
186    let rd = <OP as AluOp>::compute(rs1, rs2);
187    let rd = rd.to_le_bytes();
188    vm_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
189    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
190    vm_state.instret += 1;
191}
192
193#[create_tco_handler]
194#[inline(always)]
195unsafe fn execute_e1_impl<
196    F: PrimeField32,
197    CTX: ExecutionCtxTrait,
198    const IS_IMM: bool,
199    OP: AluOp,
200>(
201    pre_compute: &[u8],
202    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
203) {
204    let pre_compute: &BaseAluPreCompute = pre_compute.borrow();
205    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, vm_state);
206}
207
208#[create_tco_handler]
209#[inline(always)]
210unsafe fn execute_e2_impl<
211    F: PrimeField32,
212    CTX: MeteredExecutionCtxTrait,
213    const IS_IMM: bool,
214    OP: AluOp,
215>(
216    pre_compute: &[u8],
217    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
218) {
219    let pre_compute: &E2PreCompute<BaseAluPreCompute> = pre_compute.borrow();
220    vm_state
221        .ctx
222        .on_height_change(pre_compute.chip_idx as usize, 1);
223    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, vm_state);
224}
225
/// Static dispatch point for the five base ALU operations. Each op is a
/// zero-sized marker type so [`execute_e12_impl`] monomorphizes per opcode.
trait AluOp {
    fn compute(rs1: u32, rs2: u32) -> u32;
}

/// Marker for wrapping 32-bit addition.
struct AddOp;
/// Marker for wrapping 32-bit subtraction.
struct SubOp;
/// Marker for bitwise XOR.
struct XorOp;
/// Marker for bitwise OR.
struct OrOp;
/// Marker for bitwise AND.
struct AndOp;

impl AluOp for AddOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        // RISC-V ADD/ADDI wrap on overflow.
        lhs.wrapping_add(rhs)
    }
}
impl AluOp for SubOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        // RISC-V SUB wraps on underflow.
        lhs.wrapping_sub(rhs)
    }
}
impl AluOp for XorOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs ^ rhs
    }
}
impl AluOp for OrOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs | rhs
    }
}
impl AluOp for AndOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs & rhs
    }
}