openvm_rv32im_circuit/divrem/
execution.rs1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction,
10 program::DEFAULT_PC_STEP,
11 riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
12 LocalOpcode,
13};
14use openvm_rv32im_transpiler::DivRemOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use super::core::DivRemExecutor;
18
19#[derive(AlignedBytesBorrow, Clone)]
20#[repr(C)]
21struct DivRemPreCompute {
22 a: u8,
23 b: u8,
24 c: u8,
25}
26
27impl<A, const LIMB_BITS: usize> DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
28 #[inline(always)]
29 fn pre_compute_impl<F: PrimeField32>(
30 &self,
31 pc: u32,
32 inst: &Instruction<F>,
33 data: &mut DivRemPreCompute,
34 ) -> Result<DivRemOpcode, StaticProgramError> {
35 let &Instruction {
36 opcode, a, b, c, d, ..
37 } = inst;
38 let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset));
39 if d.as_canonical_u32() != RV32_REGISTER_AS {
40 return Err(StaticProgramError::InvalidInstruction(pc));
41 }
42 let pre_compute: &mut DivRemPreCompute = data.borrow_mut();
43 *pre_compute = DivRemPreCompute {
44 a: a.as_canonical_u32() as u8,
45 b: b.as_canonical_u32() as u8,
46 c: c.as_canonical_u32() as u8,
47 };
48 Ok(local_opcode)
49 }
50}
51
52macro_rules! dispatch {
53 ($execute_impl:ident, $local_opcode:ident) => {
54 match $local_opcode {
55 DivRemOpcode::DIV => Ok($execute_impl::<_, _, DivOp>),
56 DivRemOpcode::DIVU => Ok($execute_impl::<_, _, DivuOp>),
57 DivRemOpcode::REM => Ok($execute_impl::<_, _, RemOp>),
58 DivRemOpcode::REMU => Ok($execute_impl::<_, _, RemuOp>),
59 }
60 };
61}
62
63impl<F, A, const LIMB_BITS: usize> Executor<F>
64 for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
65where
66 F: PrimeField32,
67{
68 #[inline(always)]
69 fn pre_compute_size(&self) -> usize {
70 size_of::<DivRemPreCompute>()
71 }
72
73 #[inline(always)]
74 fn pre_compute<Ctx: ExecutionCtxTrait>(
75 &self,
76 pc: u32,
77 inst: &Instruction<F>,
78 data: &mut [u8],
79 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
80 let data: &mut DivRemPreCompute = data.borrow_mut();
81 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
82 dispatch!(execute_e1_impl, local_opcode)
83 }
84
85 #[cfg(feature = "tco")]
86 fn handler<Ctx>(
87 &self,
88 pc: u32,
89 inst: &Instruction<F>,
90 data: &mut [u8],
91 ) -> Result<Handler<F, Ctx>, StaticProgramError>
92 where
93 Ctx: ExecutionCtxTrait,
94 {
95 let data: &mut DivRemPreCompute = data.borrow_mut();
96 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
97 dispatch!(execute_e1_tco_handler, local_opcode)
98 }
99}
100
101impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
102 for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
103where
104 F: PrimeField32,
105{
106 fn metered_pre_compute_size(&self) -> usize {
107 size_of::<E2PreCompute<DivRemPreCompute>>()
108 }
109
110 fn metered_pre_compute<Ctx>(
111 &self,
112 chip_idx: usize,
113 pc: u32,
114 inst: &Instruction<F>,
115 data: &mut [u8],
116 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
117 where
118 Ctx: MeteredExecutionCtxTrait,
119 {
120 let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
121 data.chip_idx = chip_idx as u32;
122 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
123 dispatch!(execute_e2_impl, local_opcode)
124 }
125
126 #[cfg(feature = "tco")]
127 fn metered_handler<Ctx>(
128 &self,
129 chip_idx: usize,
130 pc: u32,
131 inst: &Instruction<F>,
132 data: &mut [u8],
133 ) -> Result<Handler<F, Ctx>, StaticProgramError>
134 where
135 Ctx: MeteredExecutionCtxTrait,
136 {
137 let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
138 data.chip_idx = chip_idx as u32;
139 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
140 dispatch!(execute_e2_tco_handler, local_opcode)
141 }
142}
143
144unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
145 pre_compute: &DivRemPreCompute,
146 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
147) {
148 let rs1 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
149 let rs2 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
150 let result = <OP as DivRemOp>::compute(rs1, rs2);
151 vm_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &result);
152 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
153 vm_state.instret += 1;
154}
155
156#[create_tco_handler]
157unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
158 pre_compute: &[u8],
159 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
160) {
161 let pre_compute: &DivRemPreCompute = pre_compute.borrow();
162 execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
163}
164
165#[create_tco_handler]
166unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: DivRemOp>(
167 pre_compute: &[u8],
168 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
169) {
170 let pre_compute: &E2PreCompute<DivRemPreCompute> = pre_compute.borrow();
171 vm_state
172 .ctx
173 .on_height_change(pre_compute.chip_idx as usize, 1);
174 execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
175}
176
177trait DivRemOp {
178 fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
179}
180struct DivOp;
181struct DivuOp;
182struct RemOp;
183struct RemuOp;
184
185impl DivRemOp for DivOp {
186 #[inline(always)]
187 fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
188 let rs1_i32 = i32::from_le_bytes(rs1);
189 let rs2_i32 = i32::from_le_bytes(rs2);
190 match (rs1_i32, rs2_i32) {
191 (_, 0) => [u8::MAX; 4],
192 (i32::MIN, -1) => rs1,
193 _ => (rs1_i32 / rs2_i32).to_le_bytes(),
194 }
195 }
196}
197
198impl DivRemOp for DivuOp {
199 #[inline(always)]
200 fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
201 if rs2 == [0; 4] {
202 [u8::MAX; 4]
203 } else {
204 let rs1 = u32::from_le_bytes(rs1);
205 let rs2 = u32::from_le_bytes(rs2);
206 (rs1 / rs2).to_le_bytes()
207 }
208 }
209}
210
211impl DivRemOp for RemOp {
212 #[inline(always)]
213 fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
214 let rs1_i32 = i32::from_le_bytes(rs1);
215 let rs2_i32 = i32::from_le_bytes(rs2);
216 match (rs1_i32, rs2_i32) {
217 (_, 0) => rs1,
218 (i32::MIN, -1) => [0; 4],
219 _ => (rs1_i32 % rs2_i32).to_le_bytes(),
220 }
221 }
222}
223
224impl DivRemOp for RemuOp {
225 #[inline(always)]
226 fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
227 if rs2 == [0; 4] {
228 rs1
229 } else {
230 let rs1 = u32::from_le_bytes(rs1);
231 let rs2 = u32::from_le_bytes(rs2);
232 (rs1 % rs2).to_le_bytes()
233 }
234 }
235}