1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction,
10 program::DEFAULT_PC_STEP,
11 riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12 LocalOpcode,
13};
14use openvm_rv32im_transpiler::BaseAluOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{adapters::imm_to_bytes, BaseAluExecutor};
18
/// Pre-decoded operands of one base ALU instruction, filled in by
/// `pre_compute_impl` and read back by the `execute_*` functions in this file.
///
/// The layout is `#[repr(C)]` and the type is borrowed directly from a raw
/// byte buffer, so field order and types are part of the contract.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
pub(super) struct BaseAluPreCompute {
    /// Second source operand. In immediate mode this is the immediate,
    /// already expanded through `imm_to_bytes` and packed little-endian;
    /// otherwise it is the pointer read from the register address space.
    c: u32,
    /// Pointer in the register address space that receives the result.
    a: u8,
    /// Pointer in the register address space of the first source operand.
    b: u8,
}
26
27impl<A, const LIMB_BITS: usize> BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28 #[inline(always)]
30 pub(super) fn pre_compute_impl<F: PrimeField32>(
31 &self,
32 pc: u32,
33 inst: &Instruction<F>,
34 data: &mut BaseAluPreCompute,
35 ) -> Result<bool, StaticProgramError> {
36 let Instruction { a, b, c, d, e, .. } = inst;
37 let e_u32 = e.as_canonical_u32();
38 if (d.as_canonical_u32() != RV32_REGISTER_AS)
39 || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
40 {
41 return Err(StaticProgramError::InvalidInstruction(pc));
42 }
43 let is_imm = e_u32 == RV32_IMM_AS;
44 let c_u32 = c.as_canonical_u32();
45 *data = BaseAluPreCompute {
46 c: if is_imm {
47 u32::from_le_bytes(imm_to_bytes(c_u32))
48 } else {
49 c_u32
50 },
51 a: a.as_canonical_u32() as u8,
52 b: b.as_canonical_u32() as u8,
53 };
54 Ok(is_imm)
55 }
56}
57
/// Picks the monomorphized executor function for a base ALU instruction.
///
/// - `$execute_impl`: name of a function generic over
///   `<_, _, const IS_IMM: bool, OP>`; the first two parameters are inferred.
/// - `$is_imm`: identifier of a `bool` telling whether the second operand is
///   an immediate.
/// - `$opcode` / `$offset`: the instruction opcode and the executor's opcode
///   offset, combined via `local_opcode_idx` to obtain the `BaseAluOpcode`.
///
/// Expands to `Ok(function)`, where the function is instantiated with the
/// matching `IS_IMM` constant and `AluOp` marker type.
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $opcode:expr, $offset:expr) => {{
        let local_opcode = BaseAluOpcode::from_usize($opcode.local_opcode_idx($offset));
        Ok(match (local_opcode, $is_imm) {
            (BaseAluOpcode::ADD, true) => $execute_impl::<_, _, true, AddOp>,
            (BaseAluOpcode::ADD, false) => $execute_impl::<_, _, false, AddOp>,
            (BaseAluOpcode::SUB, true) => $execute_impl::<_, _, true, SubOp>,
            (BaseAluOpcode::SUB, false) => $execute_impl::<_, _, false, SubOp>,
            (BaseAluOpcode::XOR, true) => $execute_impl::<_, _, true, XorOp>,
            (BaseAluOpcode::XOR, false) => $execute_impl::<_, _, false, XorOp>,
            (BaseAluOpcode::OR, true) => $execute_impl::<_, _, true, OrOp>,
            (BaseAluOpcode::OR, false) => $execute_impl::<_, _, false, OrOp>,
            (BaseAluOpcode::AND, true) => $execute_impl::<_, _, true, AndOp>,
            (BaseAluOpcode::AND, false) => $execute_impl::<_, _, false, AndOp>,
        })
    }};
}
79
80impl<F, A, const LIMB_BITS: usize> Executor<F>
81 for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
82where
83 F: PrimeField32,
84{
85 #[inline(always)]
86 fn pre_compute_size(&self) -> usize {
87 size_of::<BaseAluPreCompute>()
88 }
89
90 fn pre_compute<Ctx>(
91 &self,
92 pc: u32,
93 inst: &Instruction<F>,
94 data: &mut [u8],
95 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
96 where
97 Ctx: ExecutionCtxTrait,
98 {
99 let data: &mut BaseAluPreCompute = data.borrow_mut();
100 let is_imm = self.pre_compute_impl(pc, inst, data)?;
101
102 dispatch!(execute_e1_impl, is_imm, inst.opcode, self.offset)
103 }
104
105 #[cfg(feature = "tco")]
106 fn handler<Ctx>(
107 &self,
108 pc: u32,
109 inst: &Instruction<F>,
110 data: &mut [u8],
111 ) -> Result<Handler<F, Ctx>, StaticProgramError>
112 where
113 Ctx: ExecutionCtxTrait,
114 {
115 let data: &mut BaseAluPreCompute = data.borrow_mut();
116 let is_imm = self.pre_compute_impl(pc, inst, data)?;
117
118 dispatch!(execute_e1_tco_handler, is_imm, inst.opcode, self.offset)
119 }
120}
121
122impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
123 for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
124where
125 F: PrimeField32,
126{
127 #[inline(always)]
128 fn metered_pre_compute_size(&self) -> usize {
129 size_of::<E2PreCompute<BaseAluPreCompute>>()
130 }
131
132 fn metered_pre_compute<Ctx>(
133 &self,
134 chip_idx: usize,
135 pc: u32,
136 inst: &Instruction<F>,
137 data: &mut [u8],
138 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
139 where
140 Ctx: MeteredExecutionCtxTrait,
141 {
142 let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
143 data.chip_idx = chip_idx as u32;
144 let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
145
146 dispatch!(execute_e2_impl, is_imm, inst.opcode, self.offset)
147 }
148
149 #[cfg(feature = "tco")]
150 fn metered_handler<Ctx>(
151 &self,
152 chip_idx: usize,
153 pc: u32,
154 inst: &Instruction<F>,
155 data: &mut [u8],
156 ) -> Result<Handler<F, Ctx>, StaticProgramError>
157 where
158 Ctx: MeteredExecutionCtxTrait,
159 {
160 let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
161 data.chip_idx = chip_idx as u32;
162 let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
163
164 dispatch!(execute_e2_tco_handler, is_imm, inst.opcode, self.offset)
165 }
166}
167
168#[inline(always)]
169unsafe fn execute_e12_impl<
170 F: PrimeField32,
171 CTX: ExecutionCtxTrait,
172 const IS_IMM: bool,
173 OP: AluOp,
174>(
175 pre_compute: &BaseAluPreCompute,
176 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
177) {
178 let rs1 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
179 let rs2 = if IS_IMM {
180 pre_compute.c.to_le_bytes()
181 } else {
182 vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
183 };
184 let rs1 = u32::from_le_bytes(rs1);
185 let rs2 = u32::from_le_bytes(rs2);
186 let rd = <OP as AluOp>::compute(rs1, rs2);
187 let rd = rd.to_le_bytes();
188 vm_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
189 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
190 vm_state.instret += 1;
191}
192
193#[create_tco_handler]
194#[inline(always)]
195unsafe fn execute_e1_impl<
196 F: PrimeField32,
197 CTX: ExecutionCtxTrait,
198 const IS_IMM: bool,
199 OP: AluOp,
200>(
201 pre_compute: &[u8],
202 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
203) {
204 let pre_compute: &BaseAluPreCompute = pre_compute.borrow();
205 execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, vm_state);
206}
207
208#[create_tco_handler]
209#[inline(always)]
210unsafe fn execute_e2_impl<
211 F: PrimeField32,
212 CTX: MeteredExecutionCtxTrait,
213 const IS_IMM: bool,
214 OP: AluOp,
215>(
216 pre_compute: &[u8],
217 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
218) {
219 let pre_compute: &E2PreCompute<BaseAluPreCompute> = pre_compute.borrow();
220 vm_state
221 .ctx
222 .on_height_change(pre_compute.chip_idx as usize, 1);
223 execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, vm_state);
224}
225
/// A binary operation on two 32-bit words, selected at compile time through
/// a zero-sized marker type.
trait AluOp {
    /// Combines the two source operands into the result word.
    fn compute(rs1: u32, rs2: u32) -> u32;
}

/// Declares a unit marker struct and implements [`AluOp`] for it using the
/// given closure-like `|lhs, rhs| expression`.
macro_rules! alu_op {
    ($(#[$attr:meta])* $name:ident, |$lhs:ident, $rhs:ident| $body:expr) => {
        $(#[$attr])*
        struct $name;

        impl AluOp for $name {
            #[inline(always)]
            fn compute($lhs: u32, $rhs: u32) -> u32 {
                $body
            }
        }
    };
}

alu_op! {
    /// Addition, wrapping around on overflow.
    AddOp, |rs1, rs2| rs1.wrapping_add(rs2)
}

alu_op! {
    /// Subtraction, wrapping around on underflow.
    SubOp, |rs1, rs2| rs1.wrapping_sub(rs2)
}

alu_op! {
    /// Bitwise exclusive OR.
    XorOp, |rs1, rs2| rs1 ^ rs2
}

alu_op! {
    /// Bitwise OR.
    OrOp, |rs1, rs2| rs1 | rs2
}

alu_op! {
    /// Bitwise AND.
    AndOp, |rs1, rs2| rs1 & rs2
}