use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::BaseAluOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

#[allow(unused_imports)]
use crate::{adapters::imm_to_bytes, common::*, BaseAluExecutor};

#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
pub(super) struct BaseAluPreCompute {
    c: u32,
    a: u8,
    b: u8,
}

impl<A, const LIMB_BITS: usize> BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
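    /// Validates the operand address spaces (`d` must be the register address space and
    /// `e` either the immediate or the register address space) and fills `data` with the
    /// decoded operands. Returns whether operand `c` is an immediate.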
    #[inline(always)]
    pub(super) fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BaseAluPreCompute,
    ) -> Result<bool, StaticProgramError> {
        let Instruction { a, b, c, d, e, .. } = inst;
        let e_u32 = e.as_canonical_u32();
        if (d.as_canonical_u32() != RV32_REGISTER_AS)
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();
        *data = BaseAluPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok(is_imm)
    }
}

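// Picks a monomorphized handler: the `IS_IMM` const generic and the `AluOp` type
// parameter are chosen from the immediate flag and the decoded `BaseAluOpcode`, so each
// (operand kind, opcode) pair dispatches to its own specialized function.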
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $opcode:expr, $offset:expr) => {
        Ok(
            match (
                $is_imm,
                BaseAluOpcode::from_usize($opcode.local_opcode_idx($offset)),
            ) {
                (true, BaseAluOpcode::ADD) => $execute_impl::<_, _, true, AddOp>,
                (false, BaseAluOpcode::ADD) => $execute_impl::<_, _, false, AddOp>,
                (true, BaseAluOpcode::SUB) => $execute_impl::<_, _, true, SubOp>,
                (false, BaseAluOpcode::SUB) => $execute_impl::<_, _, false, SubOp>,
                (true, BaseAluOpcode::XOR) => $execute_impl::<_, _, true, XorOp>,
                (false, BaseAluOpcode::XOR) => $execute_impl::<_, _, false, XorOp>,
                (true, BaseAluOpcode::OR) => $execute_impl::<_, _, true, OrOp>,
                (false, BaseAluOpcode::OR) => $execute_impl::<_, _, false, OrOp>,
                (true, BaseAluOpcode::AND) => $execute_impl::<_, _, true, AndOp>,
                (false, BaseAluOpcode::AND) => $execute_impl::<_, _, false, AndOp>,
            },
        )
    };
}

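// E1 (plain) interpreter integration: decode once into `BaseAluPreCompute`, then return a
// specialized function pointer (or, under the `tco` feature, a tail-call handler) that the
// interpreter invokes whenever this instruction executes.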
impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<BaseAluPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BaseAluPreCompute = data.borrow_mut();
        let is_imm = self.pre_compute_impl(pc, inst, data)?;

        dispatch!(execute_e1_handler, is_imm, inst.opcode, self.offset)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BaseAluPreCompute = data.borrow_mut();
        let is_imm = self.pre_compute_impl(pc, inst, data)?;

        dispatch!(execute_e1_handler, is_imm, inst.opcode, self.offset)
    }
}

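// E2 (metered) interpreter integration: same as the E1 path, except the pre-compute record
// is wrapped in `E2PreCompute`, which carries the chip index so the handler can report
// trace-height changes to the metering context.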
impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BaseAluPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;

        dispatch!(execute_e2_handler, is_imm, inst.opcode, self.offset)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;

        dispatch!(execute_e2_handler, is_imm, inst.opcode, self.offset)
    }
}

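// AOT execution: lowers the instruction to x86-64 assembly text. Judging by the helper
// names, guest registers live in xmm slots (indexed by byte offset / 4) and are staged
// through general-purpose registers with `xmm_to_gpr`/`gpr_to_xmm`, with some guest
// registers pinned to dedicated GPRs via `RISCV_TO_X86_OVERRIDE_MAP`.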
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, _instruction: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
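        // Reinterpret the low 24 bits of the field element as a two's-complement value
        // (shift left, then arithmetic shift right to sign-extend) and truncate to i16.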
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let mut asm_str = String::new();

        let a: i16 = to_i16(inst.a);
        let b: i16 = to_i16(inst.b);
        let c: i16 = to_i16(inst.c);
        let e: i16 = to_i16(inst.e);

        let str_reg_a = if RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].is_some() {
            RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap()
        } else {
            REG_A_W
        };

        let mut asm_opcode = String::new();
        if inst.opcode == BaseAluOpcode::ADD.global_opcode() {
            asm_opcode += "add";
        } else if inst.opcode == BaseAluOpcode::SUB.global_opcode() {
            asm_opcode += "sub";
        } else if inst.opcode == BaseAluOpcode::AND.global_opcode() {
            asm_opcode += "and";
        } else if inst.opcode == BaseAluOpcode::OR.global_opcode() {
            asm_opcode += "or";
        } else if inst.opcode == BaseAluOpcode::XOR.global_opcode() {
            asm_opcode += "xor";
        }

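        // Three codegen shapes: rs2 is an immediate (e == 0), rd aliases rs2 (a == c,
        // where rs2 is presumably loaded into the scratch GPR before rd's register gets
        // clobbered), or the general register-register case.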
        if e == 0 {
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr((b / 4) as u8, str_reg_a, a != b);
            asm_str += &delta_str_b;
            asm_str += &format!(" {asm_opcode} {gpr_reg_b}, {c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, (a / 4) as u8);
        } else if a == c {
            let (gpr_reg_c, delta_str_c) = xmm_to_gpr((c / 4) as u8, REG_C_W, true);
            asm_str += &delta_str_c;
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr((b / 4) as u8, str_reg_a, true);
            asm_str += &delta_str_b;
            asm_str += &format!(" {asm_opcode} {gpr_reg_b}, {gpr_reg_c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, (a / 4) as u8);
        } else {
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr((b / 4) as u8, str_reg_a, true);
            asm_str += &delta_str_b;
            let (gpr_reg_c, delta_str_c) = xmm_to_gpr((c / 4) as u8, REG_C_W, false);
            asm_str += &delta_str_c;
            asm_str += &format!(" {asm_opcode} {gpr_reg_b}, {gpr_reg_c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, (a / 4) as u8);
        }

        Ok(asm_str)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;
        asm_str += &update_height_change_asm(chip_idx, 1)?;
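        // Presumably one adapter-height bump per register access: the rs1 read, the rd
        // write, and, for the register-register form only, the rs2 read (mirroring the
        // memory accesses in `execute_e12_impl`).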
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        if inst.e.as_canonical_u32() != RV32_IMM_AS {
            asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }
        Ok(asm_str)
    }
}

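/// Shared E1/E2 execution body: reads rs1 (and rs2, unless `IS_IMM` selects the
/// pre-expanded immediate), applies `OP`, writes the result to rd, and advances the pc by
/// `DEFAULT_PC_STEP`.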
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: AluOp,
>(
    pre_compute: &BaseAluPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let rs1 = u32::from_le_bytes(rs1);
    let rs2 = u32::from_le_bytes(rs2);
    let rd = <OP as AluOp>::compute(rs1, rs2);
    let rd = rd.to_le_bytes();
    exec_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: AluOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BaseAluPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<BaseAluPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, exec_state);
}

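/// The metered (E2) variant records a trace-height change of 1 for this chip before
/// delegating to the shared execution body.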
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const IS_IMM: bool,
    OP: AluOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BaseAluPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<BaseAluPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, exec_state);
}

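/// Zero-sized selectors for the five base ALU operations. `compute` follows RV32 semantics
/// on 32-bit values, e.g. `AddOp::compute(u32::MAX, 1) == 0` (wrapping add).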
trait AluOp {
    fn compute(rs1: u32, rs2: u32) -> u32;
}
struct AddOp;
struct SubOp;
struct XorOp;
struct OrOp;
struct AndOp;
impl AluOp for AddOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1.wrapping_add(rs2)
    }
}
impl AluOp for SubOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1.wrapping_sub(rs2)
    }
}
impl AluOp for XorOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1 ^ rs2
    }
}
impl AluOp for OrOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1 | rs2
    }
}
impl AluOp for AndOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1 & rs2
    }
}