openvm_rv32im_circuit/base_alu/execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::BaseAluOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

#[allow(unused_imports)]
use crate::{adapters::imm_to_bytes, common::*, BaseAluExecutor};

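/// Pre-computed operand data for one base ALU instruction. `a` and `b` are register
/// pointers; `c` holds either the decoded immediate (via `imm_to_bytes`) or the register
/// pointer of the second operand, depending on the instruction's `e` address space.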
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
pub(super) struct BaseAluPreCompute {
    c: u32,
    a: u8,
    b: u8,
}

impl<A, const LIMB_BITS: usize> BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
    /// Returns `is_imm`, which is `true` if `e` is `RV32_IMM_AS`.
    #[inline(always)]
    pub(super) fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BaseAluPreCompute,
    ) -> Result<bool, StaticProgramError> {
        let Instruction { a, b, c, d, e, .. } = inst;
        let e_u32 = e.as_canonical_u32();
        if (d.as_canonical_u32() != RV32_REGISTER_AS)
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();
        *data = BaseAluPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok(is_imm)
    }
}

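// Selects a handler monomorphized over `IS_IMM` and the concrete `AluOp`, so the returned
// function pointer carries no runtime branching on the operand kind or opcode.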
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $opcode:expr, $offset:expr) => {
        Ok(
            match (
                $is_imm,
                BaseAluOpcode::from_usize($opcode.local_opcode_idx($offset)),
            ) {
                (true, BaseAluOpcode::ADD) => $execute_impl::<_, _, true, AddOp>,
                (false, BaseAluOpcode::ADD) => $execute_impl::<_, _, false, AddOp>,
                (true, BaseAluOpcode::SUB) => $execute_impl::<_, _, true, SubOp>,
                (false, BaseAluOpcode::SUB) => $execute_impl::<_, _, false, SubOp>,
                (true, BaseAluOpcode::XOR) => $execute_impl::<_, _, true, XorOp>,
                (false, BaseAluOpcode::XOR) => $execute_impl::<_, _, false, XorOp>,
                (true, BaseAluOpcode::OR) => $execute_impl::<_, _, true, OrOp>,
                (false, BaseAluOpcode::OR) => $execute_impl::<_, _, false, OrOp>,
                (true, BaseAluOpcode::AND) => $execute_impl::<_, _, true, AndOp>,
                (false, BaseAluOpcode::AND) => $execute_impl::<_, _, false, AndOp>,
            },
        )
    };
}

impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<BaseAluPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BaseAluPreCompute = data.borrow_mut();
        let is_imm = self.pre_compute_impl(pc, inst, data)?;

        dispatch!(execute_e1_impl, is_imm, inst.opcode, self.offset)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BaseAluPreCompute = data.borrow_mut();
        let is_imm = self.pre_compute_impl(pc, inst, data)?;

        dispatch!(execute_e1_handler, is_imm, inst.opcode, self.offset)
    }
}

impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BaseAluPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;

        dispatch!(execute_e2_impl, is_imm, inst.opcode, self.offset)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?;

        dispatch!(execute_e2_handler, is_imm, inst.opcode, self.offset)
    }
}

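// AOT path: translate the base ALU instruction directly into x86 assembly. Guest register
// limbs are shuttled between their xmm-backed storage and general-purpose registers via
// the `xmm_to_gpr` / `gpr_to_xmm` helpers.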
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, _instruction: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
        // Sign-extend the low 24 bits of the field element and truncate to i16.
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let mut asm_str = String::new();

        let a: i16 = to_i16(inst.a);
        let b: i16 = to_i16(inst.b);
        let c: i16 = to_i16(inst.c);
        let e: i16 = to_i16(inst.e);

        // Use the dedicated x86 register for this guest register if one is mapped,
        // otherwise fall back to REG_A_W.
        let str_reg_a = if RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].is_some() {
            RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap()
        } else {
            REG_A_W
        };

        let mut asm_opcode = String::new();
        if inst.opcode == BaseAluOpcode::ADD.global_opcode() {
            asm_opcode += "add";
        } else if inst.opcode == BaseAluOpcode::SUB.global_opcode() {
            asm_opcode += "sub";
        } else if inst.opcode == BaseAluOpcode::AND.global_opcode() {
            asm_opcode += "and";
        } else if inst.opcode == BaseAluOpcode::OR.global_opcode() {
            asm_opcode += "or";
        } else if inst.opcode == BaseAluOpcode::XOR.global_opcode() {
            asm_opcode += "xor";
        }

        if e == 0 {
            // Immediate operand: [a:4]_1 = [b:4]_1 <op> c
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr((b / 4) as u8, str_reg_a, a != b);
            asm_str += &delta_str_b;
            asm_str += &format!("   {asm_opcode} {gpr_reg_b}, {c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, (a / 4) as u8);
        } else if a == c {
            let (gpr_reg_c, delta_str_c) = xmm_to_gpr((c / 4) as u8, REG_C_W, true);
            asm_str += &delta_str_c;
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr((b / 4) as u8, str_reg_a, true);
            asm_str += &delta_str_b;
            asm_str += &format!("   {asm_opcode} {gpr_reg_b}, {gpr_reg_c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, (a / 4) as u8);
        } else {
            let (gpr_reg_b, delta_str_b) = xmm_to_gpr((b / 4) as u8, str_reg_a, true);
            asm_str += &delta_str_b; // data is now in gpr_reg_b
            let (gpr_reg_c, delta_str_c) = xmm_to_gpr((c / 4) as u8, REG_C_W, false); // data is now in gpr_reg_c
            asm_str += &delta_str_c; // append this asm as well, since the helper may touch other registers
            asm_str += &format!("   {asm_opcode} {gpr_reg_b}, {gpr_reg_c}\n");
            asm_str += &gpr_to_xmm(&gpr_reg_b, (a / 4) as u8);
        }

        Ok(asm_str)
    }
}

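// Metered AOT path: emit the same instruction assembly, followed by bookkeeping that
// records the trace-height changes for this chip and its memory adapter accesses.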
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for BaseAluExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;
        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // read [b:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        // write [a:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        if inst.e.as_canonical_u32() != RV32_IMM_AS {
            // read [c:4]_1
            asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }
        Ok(asm_str)
    }
}

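// Shared execution kernel for all base ALU variants: reads rs1 (and rs2 when it is a
// register operand), applies the operation, writes rd, and advances the pc by one step.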
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: AluOp,
>(
    pre_compute: &BaseAluPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let rs1 = u32::from_le_bytes(rs1);
    let rs2 = u32::from_le_bytes(rs2);
    let rd = <OP as AluOp>::compute(rs1, rs2);
    let rd = rd.to_le_bytes();
    exec_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &rd);
    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

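// E1 (plain execution) entry point: reinterprets the raw pre-compute bytes and runs the
// shared kernel.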
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const IS_IMM: bool,
    OP: AluOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BaseAluPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<BaseAluPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, IS_IMM, OP>(pre_compute, exec_state);
}

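// E2 (metered execution) entry point: additionally records a trace-height change of 1 for
// this chip before running the shared kernel.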
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const IS_IMM: bool,
    OP: AluOp,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BaseAluPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<BaseAluPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_IMM, OP>(&pre_compute.data, exec_state);
}

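// Zero-sized marker types implementing `AluOp`, used to monomorphize the execute kernel
// over the five base ALU operations.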
trait AluOp {
    fn compute(rs1: u32, rs2: u32) -> u32;
}
struct AddOp;
struct SubOp;
struct XorOp;
struct OrOp;
struct AndOp;
impl AluOp for AddOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1.wrapping_add(rs2)
    }
}
impl AluOp for SubOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1.wrapping_sub(rs2)
    }
}
impl AluOp for XorOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1 ^ rs2
    }
}
impl AluOp for OrOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1 | rs2
    }
}
impl AluOp for AndOp {
    #[inline(always)]
    fn compute(rs1: u32, rs2: u32) -> u32 {
        rs1 & rs2
    }
}