1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction,
10 program::DEFAULT_PC_STEP,
11 riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12 LocalOpcode,
13};
14use openvm_rv32im_transpiler::BaseAluOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{adapters::imm_to_bytes, BaseAluExecutor};
18
/// Decoded operands of one base ALU instruction, as written by `pre_compute_impl` and read back
/// by the execute handlers in this file.
///
/// The struct is `#[repr(C)]` and derives `AlignedBytesBorrow`; the handlers obtain it by
/// borrowing from a raw `[u8]` pre-compute buffer.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
pub(super) struct BaseAluPreCompute {
    /// Second operand. When the instruction uses `RV32_IMM_AS` this holds the four bytes produced
    /// by `imm_to_bytes`, packed little-endian; otherwise it is the raw `c` operand, which the
    /// handlers use as a pointer into `RV32_REGISTER_AS`.
    c: u32,
    /// Pointer in `RV32_REGISTER_AS` that receives the result (instruction operand `a`, narrowed
    /// to `u8`).
    a: u8,
    /// Pointer in `RV32_REGISTER_AS` the first operand is read from (instruction operand `b`,
    /// narrowed to `u8`).
    b: u8,
}
26
27impl<A, const LIMB_BITS: usize> BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28 #[inline(always)]
30 pub(super) fn pre_compute_impl<F: PrimeField32>(
31 &self,
32 pc: u32,
33 inst: &Instruction<F>,
34 data: &mut BaseAluPreCompute,
35 ) -> Result<bool, StaticProgramError> {
36 let Instruction { a, b, c, d, e, .. } = inst;
37 let e_u32 = e.as_canonical_u32();
38 if (d.as_canonical_u32() != RV32_REGISTER_AS)
39 || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
40 {
41 return Err(StaticProgramError::InvalidInstruction(pc));
42 }
43 let is_imm = e_u32 == RV32_IMM_AS;
44 let c_u32 = c.as_canonical_u32();
45 *data = BaseAluPreCompute {
46 c: if is_imm {
47 u32::from_le_bytes(imm_to_bytes(c_u32))
48 } else {
49 c_u32
50 },
51 a: a.as_canonical_u32() as u8,
52 b: b.as_canonical_u32() as u8,
53 };
54 Ok(is_imm)
55 }
56}
57
/// Selects the instantiation of a generic execute handler that matches an instruction.
///
/// Arguments:
/// - `$execute_impl`: identifier of the generic handler; it is instantiated as
///   `$execute_impl::<_, _, IS_IMM, OP>`, leaving the first two generic arguments to inference.
/// - `$is_imm`: identifier of a `bool` saying whether the second operand is an immediate.
/// - `$opcode`: expression providing `local_opcode_idx`.
/// - `$offset`: expression passed to `local_opcode_idx`; the resulting index is converted with
///   `BaseAluOpcode::from_usize`.
///
/// Expands to `Ok(handler)`, where `handler` is chosen by matching on `($is_imm, opcode)` over
/// `ADD`, `SUB`, `XOR`, `OR` and `AND`.
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $opcode:expr, $offset:expr) => {
        Ok(
            match (
                $is_imm,
                BaseAluOpcode::from_usize($opcode.local_opcode_idx($offset)),
            ) {
                (true, BaseAluOpcode::ADD) => $execute_impl::<_, _, true, AddOp>,
                (false, BaseAluOpcode::ADD) => $execute_impl::<_, _, false, AddOp>,
                (true, BaseAluOpcode::SUB) => $execute_impl::<_, _, true, SubOp>,
                (false, BaseAluOpcode::SUB) => $execute_impl::<_, _, false, SubOp>,
                (true, BaseAluOpcode::XOR) => $execute_impl::<_, _, true, XorOp>,
                (false, BaseAluOpcode::XOR) => $execute_impl::<_, _, false, XorOp>,
                (true, BaseAluOpcode::OR) => $execute_impl::<_, _, true, OrOp>,
                (false, BaseAluOpcode::OR) => $execute_impl::<_, _, false, OrOp>,
                (true, BaseAluOpcode::AND) => $execute_impl::<_, _, true, AndOp>,
                (false, BaseAluOpcode::AND) => $execute_impl::<_, _, false, AndOp>,
            },
        )
    };
}
79
80impl<F, A, const LIMB_BITS: usize> Executor<F>
81 for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
82where
83 F: PrimeField32,
84{
85 #[inline(always)]
86 fn pre_compute_size(&self) -> usize {
87 size_of::<BaseAluPreCompute>()
88 }
89
90 #[cfg(not(feature = "tco"))]
91 fn pre_compute<Ctx>(
92 &self,
93 pc: u32,
94 inst: &Instruction<F>,
95 data: &mut [u8],
96 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
97 where
98 Ctx: ExecutionCtxTrait,
99 {
100 let data: &mut BaseAluPreCompute = data.borrow_mut();
101 let is_imm = self.pre_compute_impl(pc, inst, data)?;
102
103 dispatch!(execute_e1_handler, is_imm, inst.opcode, self.offset)
104 }
105
106 #[cfg(feature = "tco")]
107 fn handler<Ctx>(
108 &self,
109 pc: u32,
110 inst: &Instruction<F>,
111 data: &mut [u8],
112 ) -> Result<Handler<F, Ctx>, StaticProgramError>
113 where
114 Ctx: ExecutionCtxTrait,
115 {
116 let data: &mut BaseAluPreCompute = data.borrow_mut();
117 let is_imm = self.pre_compute_impl(pc, inst, data)?;
118
119 dispatch!(execute_e1_handler, is_imm, inst.opcode, self.offset)
120 }
121}
122
123impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
124 for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
125where
126 F: PrimeField32,
127{
128 #[inline(always)]
129 fn metered_pre_compute_size(&self) -> usize {
130 size_of::<E2PreCompute<BaseAluPreCompute>>()
131 }
132
133 #[cfg(not(feature = "tco"))]
134 fn metered_pre_compute<Ctx>(
135 &self,
136 chip_idx: usize,
137 pc: u32,
138 inst: &Instruction<F>,
139 data: &mut [u8],
140 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
141 where
142 Ctx: MeteredExecutionCtxTrait,
143 {
144 let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
145 data.chip_idx = chip_idx as u32;
146 let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
147
148 dispatch!(execute_e2_handler, is_imm, inst.opcode, self.offset)
149 }
150
151 #[cfg(feature = "tco")]
152 fn metered_handler<Ctx>(
153 &self,
154 chip_idx: usize,
155 pc: u32,
156 inst: &Instruction<F>,
157 data: &mut [u8],
158 ) -> Result<Handler<F, Ctx>, StaticProgramError>
159 where
160 Ctx: MeteredExecutionCtxTrait,
161 {
162 let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
163 data.chip_idx = chip_idx as u32;
164 let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;
165
166 dispatch!(execute_e2_handler, is_imm, inst.opcode, self.offset)
167 }
168}
169
170#[inline(always)]
171unsafe fn execute_e12_impl<
172 F: PrimeField32,
173 CTX: ExecutionCtxTrait,
174 const IS_IMM: bool,
175 OP: AluOp,
176>(
177 pre_compute: &BaseAluPreCompute,
178 instret: &mut u64,
179 pc: &mut u32,
180 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
181) {
182 let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
183 let rs2 = if IS_IMM {
184 pre_compute.c.to_le_bytes()
185 } else {
186 exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
187 };
188 let rs1 = u32::from_le_bytes(rs1);
189 let rs2 = u32::from_le_bytes(rs2);
190 let rd = <OP as AluOp>::compute(rs1, rs2);
191 let rd = rd.to_le_bytes();
192 exec_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
193 *pc = pc.wrapping_add(DEFAULT_PC_STEP);
194 *instret += 1;
195}
196
197#[create_handler]
198#[inline(always)]
199unsafe fn execute_e1_impl<
200 F: PrimeField32,
201 CTX: ExecutionCtxTrait,
202 const IS_IMM: bool,
203 OP: AluOp,
204>(
205 pre_compute: &[u8],
206 instret: &mut u64,
207 pc: &mut u32,
208 _instret_end: u64,
209 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
210) {
211 let pre_compute: &BaseAluPreCompute = pre_compute.borrow();
212 execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, instret, pc, exec_state);
213}
214
215#[create_handler]
216#[inline(always)]
217unsafe fn execute_e2_impl<
218 F: PrimeField32,
219 CTX: MeteredExecutionCtxTrait,
220 const IS_IMM: bool,
221 OP: AluOp,
222>(
223 pre_compute: &[u8],
224 instret: &mut u64,
225 pc: &mut u32,
226 _arg: u64,
227 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
228) {
229 let pre_compute: &E2PreCompute<BaseAluPreCompute> = pre_compute.borrow();
230 exec_state
231 .ctx
232 .on_height_change(pre_compute.chip_idx as usize, 1);
233 execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, instret, pc, exec_state);
234}
235
/// A binary 32-bit ALU operation selected at compile time through a marker type.
trait AluOp {
    /// Combines the two operand words into the result word.
    fn compute(rs1: u32, rs2: u32) -> u32;
}

/// Marker for addition that wraps on overflow.
struct AddOp;
impl AluOp for AddOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs.wrapping_add(rhs)
    }
}

/// Marker for subtraction that wraps on underflow.
struct SubOp;
impl AluOp for SubOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs.wrapping_sub(rhs)
    }
}

/// Marker for bitwise exclusive OR.
struct XorOp;
impl AluOp for XorOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs ^ rhs
    }
}

/// Marker for bitwise OR.
struct OrOp;
impl AluOp for OrOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs | rhs
    }
}

/// Marker for bitwise AND.
struct AndOp;
impl AluOp for AndOp {
    #[inline(always)]
    fn compute(lhs: u32, rhs: u32) -> u32 {
        lhs & rhs
    }
}