// openvm_rv32im_circuit/mul/execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12    LocalOpcode,
13};
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::MultiplicationExecutor;
18
19#[derive(AlignedBytesBorrow, Clone)]
20#[repr(C)]
21struct MultiPreCompute {
22    a: u8,
23    b: u8,
24    c: u8,
25}
26
27impl<A, const LIMB_BITS: usize> MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28    fn pre_compute_impl<F: PrimeField32>(
29        &self,
30        pc: u32,
31        inst: &Instruction<F>,
32        data: &mut MultiPreCompute,
33    ) -> Result<(), StaticProgramError> {
34        assert_eq!(
35            MulOpcode::from_usize(inst.opcode.local_opcode_idx(self.offset)),
36            MulOpcode::MUL
37        );
38        if inst.d.as_canonical_u32() != RV32_REGISTER_AS {
39            return Err(StaticProgramError::InvalidInstruction(pc));
40        }
41
42        *data = MultiPreCompute {
43            a: inst.a.as_canonical_u32() as u8,
44            b: inst.b.as_canonical_u32() as u8,
45            c: inst.c.as_canonical_u32() as u8,
46        };
47        Ok(())
48    }
49}
50
51impl<F, A, const LIMB_BITS: usize> Executor<F>
52    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
53where
54    F: PrimeField32,
55{
56    fn pre_compute_size(&self) -> usize {
57        size_of::<MultiPreCompute>()
58    }
59    #[cfg(not(feature = "tco"))]
60    fn pre_compute<Ctx>(
61        &self,
62        pc: u32,
63        inst: &Instruction<F>,
64        data: &mut [u8],
65    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
66    where
67        Ctx: ExecutionCtxTrait,
68    {
69        let pre_compute: &mut MultiPreCompute = data.borrow_mut();
70        self.pre_compute_impl(pc, inst, pre_compute)?;
71        Ok(execute_e1_impl)
72    }
73
74    #[cfg(feature = "tco")]
75    fn handler<Ctx>(
76        &self,
77        pc: u32,
78        inst: &Instruction<F>,
79        data: &mut [u8],
80    ) -> Result<Handler<F, Ctx>, StaticProgramError>
81    where
82        Ctx: ExecutionCtxTrait,
83    {
84        let pre_compute: &mut MultiPreCompute = data.borrow_mut();
85        self.pre_compute_impl(pc, inst, pre_compute)?;
86        Ok(execute_e1_handler)
87    }
88}
89
90impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
91    for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
92where
93    F: PrimeField32,
94{
95    fn metered_pre_compute_size(&self) -> usize {
96        size_of::<E2PreCompute<MultiPreCompute>>()
97    }
98
99    #[cfg(not(feature = "tco"))]
100    fn metered_pre_compute<Ctx>(
101        &self,
102        chip_idx: usize,
103        pc: u32,
104        inst: &Instruction<F>,
105        data: &mut [u8],
106    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
107    where
108        Ctx: MeteredExecutionCtxTrait,
109    {
110        let pre_compute: &mut E2PreCompute<MultiPreCompute> = data.borrow_mut();
111        pre_compute.chip_idx = chip_idx as u32;
112        self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
113        Ok(execute_e2_impl)
114    }
115
116    #[cfg(feature = "tco")]
117    fn metered_handler<Ctx>(
118        &self,
119        chip_idx: usize,
120        pc: u32,
121        inst: &Instruction<F>,
122        data: &mut [u8],
123    ) -> Result<Handler<F, Ctx>, StaticProgramError>
124    where
125        Ctx: MeteredExecutionCtxTrait,
126    {
127        let pre_compute: &mut E2PreCompute<MultiPreCompute> = data.borrow_mut();
128        pre_compute.chip_idx = chip_idx as u32;
129        self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
130        Ok(execute_e2_handler)
131    }
132}
133
134#[inline(always)]
135unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
136    pre_compute: &MultiPreCompute,
137    instret: &mut u64,
138    pc: &mut u32,
139    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
140) {
141    let rs1: [u8; RV32_REGISTER_NUM_LIMBS] =
142        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
143    let rs2: [u8; RV32_REGISTER_NUM_LIMBS] =
144        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32);
145    let rs1 = u32::from_le_bytes(rs1);
146    let rs2 = u32::from_le_bytes(rs2);
147    let rd = rs1.wrapping_mul(rs2);
148    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd.to_le_bytes());
149
150    *pc += DEFAULT_PC_STEP;
151    *instret += 1;
152}
153
154#[create_handler]
155#[inline(always)]
156unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
157    pre_compute: &[u8],
158    instret: &mut u64,
159    pc: &mut u32,
160    _instret_end: u64,
161    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
162) {
163    let pre_compute: &MultiPreCompute = pre_compute.borrow();
164    execute_e12_impl(pre_compute, instret, pc, exec_state);
165}
166
167#[create_handler]
168#[inline(always)]
169unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
170    pre_compute: &[u8],
171    instret: &mut u64,
172    pc: &mut u32,
173    _arg: u64,
174    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
175) {
176    let pre_compute: &E2PreCompute<MultiPreCompute> = pre_compute.borrow();
177    exec_state
178        .ctx
179        .on_height_change(pre_compute.chip_idx as usize, 1);
180    execute_e12_impl(&pre_compute.data, instret, pc, exec_state);
181}