openvm_rv32im_circuit/divrem/execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction,
10    program::DEFAULT_PC_STEP,
11    riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12    LocalOpcode,
13};
14use openvm_rv32im_transpiler::DivRemOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use super::core::DivRemExecutor;
18
/// Decoded operands cached per DIV/DIVU/REM/REMU instruction by `pre_compute_impl` and read
/// back by `execute_e12_impl`.
///
/// The executor obtains this record by borrowing a raw `[u8]` pre-compute buffer as
/// `&mut DivRemPreCompute` (via the `AlignedBytesBorrow` derive), so with `#[repr(C)]` the
/// field order presumably defines the byte layout — do not reorder fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct DivRemPreCompute {
    /// Instruction operand `a`: destination register pointer written by `vm_write`
    /// (stored via `as u8`).
    a: u8,
    /// Instruction operand `b`: pointer of the first source register (`rs1`).
    b: u8,
    /// Instruction operand `c`: pointer of the second source register (`rs2`).
    c: u8,
}
26
27impl<A, const LIMB_BITS: usize> DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28    #[inline(always)]
29    fn pre_compute_impl<F: PrimeField32>(
30        &self,
31        pc: u32,
32        inst: &Instruction<F>,
33        data: &mut DivRemPreCompute,
34    ) -> Result<DivRemOpcode, StaticProgramError> {
35        let &Instruction {
36            opcode, a, b, c, d, ..
37        } = inst;
38        let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset));
39        if d.as_canonical_u32() != RV32_REGISTER_AS {
40            return Err(StaticProgramError::InvalidInstruction(pc));
41        }
42        let pre_compute: &mut DivRemPreCompute = data.borrow_mut();
43        *pre_compute = DivRemPreCompute {
44            a: a.as_canonical_u32() as u8,
45            b: b.as_canonical_u32() as u8,
46            c: c.as_canonical_u32() as u8,
47        };
48        Ok(local_opcode)
49    }
50}
51
/// Maps a decoded [`DivRemOpcode`] to the matching monomorphized handler
/// `$handler::<_, _, Op>`. The result is wrapped in `Ok` so callers can return it directly
/// from their `Result`-returning pre-compute methods.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            // Signed variants.
            DivRemOpcode::DIV => Ok($handler::<_, _, DivOp>),
            DivRemOpcode::REM => Ok($handler::<_, _, RemOp>),
            // Unsigned variants.
            DivRemOpcode::DIVU => Ok($handler::<_, _, DivuOp>),
            DivRemOpcode::REMU => Ok($handler::<_, _, RemuOp>),
        }
    };
}
62
63impl<F, A, const LIMB_BITS: usize> Executor<F>
64    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
65where
66    F: PrimeField32,
67{
68    #[inline(always)]
69    fn pre_compute_size(&self) -> usize {
70        size_of::<DivRemPreCompute>()
71    }
72
73    #[cfg(not(feature = "tco"))]
74    #[inline(always)]
75    fn pre_compute<Ctx: ExecutionCtxTrait>(
76        &self,
77        pc: u32,
78        inst: &Instruction<F>,
79        data: &mut [u8],
80    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
81        let data: &mut DivRemPreCompute = data.borrow_mut();
82        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
83        dispatch!(execute_e1_handler, local_opcode)
84    }
85
86    #[cfg(feature = "tco")]
87    fn handler<Ctx>(
88        &self,
89        pc: u32,
90        inst: &Instruction<F>,
91        data: &mut [u8],
92    ) -> Result<Handler<F, Ctx>, StaticProgramError>
93    where
94        Ctx: ExecutionCtxTrait,
95    {
96        let data: &mut DivRemPreCompute = data.borrow_mut();
97        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
98        dispatch!(execute_e1_handler, local_opcode)
99    }
100}
101
102impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
103    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
104where
105    F: PrimeField32,
106{
107    fn metered_pre_compute_size(&self) -> usize {
108        size_of::<E2PreCompute<DivRemPreCompute>>()
109    }
110
111    #[cfg(not(feature = "tco"))]
112    fn metered_pre_compute<Ctx>(
113        &self,
114        chip_idx: usize,
115        pc: u32,
116        inst: &Instruction<F>,
117        data: &mut [u8],
118    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
119    where
120        Ctx: MeteredExecutionCtxTrait,
121    {
122        let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
123        data.chip_idx = chip_idx as u32;
124        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
125        dispatch!(execute_e2_handler, local_opcode)
126    }
127
128    #[cfg(feature = "tco")]
129    fn metered_handler<Ctx>(
130        &self,
131        chip_idx: usize,
132        pc: u32,
133        inst: &Instruction<F>,
134        data: &mut [u8],
135    ) -> Result<Handler<F, Ctx>, StaticProgramError>
136    where
137        Ctx: MeteredExecutionCtxTrait,
138    {
139        let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
140        data.chip_idx = chip_idx as u32;
141        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
142        dispatch!(execute_e2_handler, local_opcode)
143    }
144}
145
146#[inline(always)]
147unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
148    pre_compute: &DivRemPreCompute,
149    instret: &mut u64,
150    pc: &mut u32,
151    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
152) {
153    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
154    let rs2 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
155    let result = <OP as DivRemOp>::compute(rs1, rs2);
156    exec_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &result);
157    *pc = pc.wrapping_add(DEFAULT_PC_STEP);
158    *instret += 1;
159}
160
/// Plain (E1) execution entry point for DIV/DIVU/REM/REMU.
///
/// Reinterprets the raw pre-compute bytes as a [`DivRemPreCompute`] and delegates to
/// `execute_e12_impl`. `#[create_handler]` presumably generates the `execute_e1_handler`
/// that `dispatch!` selects — confirm against the macro.
///
/// `_instret_end` is not used by this instruction.
///
/// # Safety
/// NOTE(review): inherits the contract of `execute_e12_impl`; `pre_compute` is expected to
/// be a buffer populated for this instruction — confirm with the execution framework.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // View the bytes written during pre-compute as the typed operand record.
    let pre_compute: &DivRemPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, instret, pc, exec_state);
}
173
/// Metered (E2) execution entry point for DIV/DIVU/REM/REMU.
///
/// Reinterprets the raw pre-compute bytes as an `E2PreCompute<DivRemPreCompute>`. It then
/// calls `on_height_change(chip_idx, 1)` on the metering context before running the shared
/// `execute_e12_impl`; presumably that records one trace row per executed instruction —
/// confirm. `#[create_handler]` presumably generates the `execute_e2_handler` that
/// `dispatch!` selects.
///
/// `_arg` is not used by this instruction.
///
/// # Safety
/// NOTE(review): inherits the contract of `execute_e12_impl`; `pre_compute` is expected to
/// be a buffer populated by the metered pre-compute path — confirm.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: DivRemOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // View the bytes as the chip index plus the decoded operand record.
    let pre_compute: &E2PreCompute<DivRemPreCompute> = pre_compute.borrow();
    // Notify metering before executing, using the chip index stored at pre-compute time.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, instret, pc, exec_state);
}
189
/// Pure RV32IM division/remainder arithmetic on little-endian 4-byte register values.
///
/// Implementors take the `rs1` and `rs2` operand bytes and return the result bytes. They
/// handle a zero divisor and signed overflow (`i32::MIN` by `-1`) without panicking,
/// producing the values the RISC-V M extension specifies.
trait DivRemOp {
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
}

/// Signed division (`DIV`).
struct DivOp;
/// Unsigned division (`DIVU`).
struct DivuOp;
/// Signed remainder (`REM`).
struct RemOp;
/// Unsigned remainder (`REMU`).
struct RemuOp;

impl DivRemOp for DivOp {
    /// Quotient truncated toward zero. Division by zero yields `-1` (all bits set), and
    /// `i32::MIN / -1` yields `i32::MIN`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let dividend = i32::from_le_bytes(rs1);
        let divisor = i32::from_le_bytes(rs2);
        let quotient = if divisor == 0 {
            -1
        } else {
            // `wrapping_div` maps the single overflowing case `MIN / -1` back to `MIN`.
            dividend.wrapping_div(divisor)
        };
        quotient.to_le_bytes()
    }
}

impl DivRemOp for DivuOp {
    /// Unsigned quotient. Division by zero yields `u32::MAX`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let dividend = u32::from_le_bytes(rs1);
        let divisor = u32::from_le_bytes(rs2);
        dividend
            .checked_div(divisor)
            .unwrap_or(u32::MAX)
            .to_le_bytes()
    }
}

impl DivRemOp for RemOp {
    /// Remainder carrying the dividend's sign. Division by zero returns the dividend, and
    /// `i32::MIN % -1` yields `0`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let divisor = i32::from_le_bytes(rs2);
        if divisor == 0 {
            return rs1;
        }
        // `wrapping_rem` returns 0 for the overflowing `MIN % -1` case.
        i32::from_le_bytes(rs1).wrapping_rem(divisor).to_le_bytes()
    }
}

impl DivRemOp for RemuOp {
    /// Unsigned remainder. Division by zero returns the dividend unchanged.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let dividend = u32::from_le_bytes(rs1);
        let divisor = u32::from_le_bytes(rs2);
        dividend
            .checked_rem(divisor)
            .unwrap_or(dividend)
            .to_le_bytes()
    }
}