1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6#[cfg(feature = "aot")]
7use openvm_circuit::arch::aot::common::REG_A_W;
8use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
9use openvm_circuit_primitives_derive::AlignedBytesBorrow;
10use openvm_instructions::{
11 instruction::Instruction,
12 program::DEFAULT_PC_STEP,
13 riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
14 LocalOpcode,
15};
16use openvm_rv32im_transpiler::DivRemOpcode;
17use openvm_stark_backend::p3_field::PrimeField32;
18
19use super::core::DivRemExecutor;
20#[cfg(feature = "aot")]
21use crate::common::{gpr_to_xmm, xmm_to_gpr};
22
/// Operands of one DIV/DIVU/REM/REMU instruction, decoded once by
/// `pre_compute_impl` and read back by `execute_e12_impl` on every execution.
///
/// Each field is stored as a `u8` and used by `execute_e12_impl` as a pointer
/// into the register address space (`RV32_REGISTER_AS`).
///
/// `#[repr(C)]` pins the field layout: handlers such as `execute_e1_impl`
/// reinterpret a raw byte buffer as this struct.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct DivRemPreCompute {
    /// Destination register pointer (`rd`); the result is written here.
    a: u8,
    /// First source register pointer (`rs1`, the dividend).
    b: u8,
    /// Second source register pointer (`rs2`, the divisor).
    c: u8,
}
30
31impl<A, const LIMB_BITS: usize> DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
32 #[inline(always)]
33 fn pre_compute_impl<F: PrimeField32>(
34 &self,
35 pc: u32,
36 inst: &Instruction<F>,
37 data: &mut DivRemPreCompute,
38 ) -> Result<DivRemOpcode, StaticProgramError> {
39 let &Instruction {
40 opcode, a, b, c, d, ..
41 } = inst;
42 let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset));
43 if d.as_canonical_u32() != RV32_REGISTER_AS {
44 return Err(StaticProgramError::InvalidInstruction(pc));
45 }
46 let pre_compute: &mut DivRemPreCompute = data.borrow_mut();
47 *pre_compute = DivRemPreCompute {
48 a: a.as_canonical_u32() as u8,
49 b: b.as_canonical_u32() as u8,
50 c: c.as_canonical_u32() as u8,
51 };
52 Ok(local_opcode)
53 }
54}
55
/// Turns a decoded `DivRemOpcode` into `Ok(handler)`, where `handler` is the given
/// generic handler function instantiated with the matching `DivRemOp` type
/// (`DivOp`, `DivuOp`, `RemOp` or `RemuOp`).
///
/// The match has no wildcard arm, so a new opcode variant fails to compile here
/// until it is given a handler.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            DivRemOpcode::DIV => Ok($handler::<_, _, DivOp>),
            DivRemOpcode::DIVU => Ok($handler::<_, _, DivuOp>),
            DivRemOpcode::REM => Ok($handler::<_, _, RemOp>),
            DivRemOpcode::REMU => Ok($handler::<_, _, RemuOp>),
        }
    };
}
66
67impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
68 for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
69where
70 F: PrimeField32,
71{
72 #[inline(always)]
73 fn pre_compute_size(&self) -> usize {
74 size_of::<DivRemPreCompute>()
75 }
76
77 #[cfg(not(feature = "tco"))]
78 #[inline(always)]
79 fn pre_compute<Ctx: ExecutionCtxTrait>(
80 &self,
81 pc: u32,
82 inst: &Instruction<F>,
83 data: &mut [u8],
84 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
85 let data: &mut DivRemPreCompute = data.borrow_mut();
86 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
87 dispatch!(execute_e1_handler, local_opcode)
88 }
89
90 #[cfg(feature = "tco")]
91 fn handler<Ctx>(
92 &self,
93 pc: u32,
94 inst: &Instruction<F>,
95 data: &mut [u8],
96 ) -> Result<Handler<F, Ctx>, StaticProgramError>
97 where
98 Ctx: ExecutionCtxTrait,
99 {
100 let data: &mut DivRemPreCompute = data.borrow_mut();
101 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
102 dispatch!(execute_e1_handler, local_opcode)
103 }
104}
105
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    /// Emits x86-64 assembly text for one DIV/DIVU/REM/REMU instruction at `pc`.
    ///
    /// The generated code mirrors the interpreter's `DivRemOp::compute` results:
    /// a zero divisor yields `-1` for DIV/DIVU and the dividend for REM/REMU, and
    /// the signed-overflow case `0x80000000 / -1` yields the dividend for DIV and
    /// `0` for REM. Both special cases are branched around explicitly because the
    /// x86 `div`/`idiv` instructions raise a divide error on them.
    ///
    /// # Errors
    ///
    /// Returns `AotError::InvalidInstruction` if operand `d` is not the register
    /// address space.
    fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
        let &Instruction {
            opcode, a, b, c, d, ..
        } = inst;
        let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        if d.as_canonical_u32() != RV32_REGISTER_AS {
            return Err(AotError::InvalidInstruction);
        }

        let mut asm_str = String::new();
        // Operands are byte offsets into the register address space; dividing by 4
        // gives a register index (assumes 4-byte registers — confirm against
        // `xmm_to_gpr`/`gpr_to_xmm`).
        let a_reg = a.as_canonical_u32() / 4;
        let b_reg = b.as_canonical_u32() / 4;
        let c_reg = c.as_canonical_u32() / 4;

        // Load rs1 into eax, the implicit dividend register of `div`/`idiv`.
        // NOTE(review): the meaning of the trailing bool flag is defined by
        // `xmm_to_gpr` and is not visible here.
        let (_, delta_str_b) = &xmm_to_gpr(b_reg as u8, "eax", true);
        asm_str += delta_str_b;
        // Load rs2; `reg_c` names the GPR holding the divisor (REG_A_W is the
        // requested register). NOTE(review): the sequences below overwrite eax and
        // edx (`cdq`, `div`/`idiv`), so this is only correct if `reg_c` is neither
        // of them — confirm against `xmm_to_gpr`.
        let (reg_c, delta_str_c) = &xmm_to_gpr(c_reg as u8, REG_A_W, false);
        asm_str += delta_str_c;
        // edx is the high half of the edx:eax dividend: zero it for the unsigned
        // `div` paths (the signed paths overwrite it with `cdq`).
        asm_str += " mov edx, 0\n";

        // Embedding pc and the opcode name keeps the jump labels of every emitted
        // instruction unique.
        let label_prefix = format!(
            ".asm_divrem_{}_{}",
            pc,
            match local_opcode {
                DivRemOpcode::DIV => "div",
                DivRemOpcode::DIVU => "divu",
                DivRemOpcode::REM => "rem",
                DivRemOpcode::REMU => "remu",
            }
        );
        let done_label = format!("{label_prefix}__done");

        let zero_label = format!("{label_prefix}__divisor_zero");
        let overflow_label = format!("{label_prefix}__overflow");
        let normal_label = format!("{label_prefix}__normal");
        // Every arm leaves the RISC-V result in edx before reaching `done_label`.
        match local_opcode {
            DivRemOpcode::DIV => {
                // Divisor == 0 -> result is -1.
                asm_str += &format!(" test {reg_c}, {reg_c}\n");
                asm_str += &format!(" je {zero_label}\n");
                // 0x80000000 / -1 would overflow `idiv`; route it to the overflow path.
                asm_str += " cmp eax, 0x80000000\n";
                asm_str += &format!(" jne {normal_label}\n");
                asm_str += &format!(" cmp {reg_c}, -1\n");
                asm_str += &format!(" jne {normal_label}\n");
                asm_str += &format!(" jmp {overflow_label}\n");

                asm_str += &format!("{normal_label}:\n");
                // Sign-extend eax into edx:eax, then signed divide; quotient -> eax.
                asm_str += " cdq\n";
                asm_str += &format!(" idiv {reg_c}\n");
                asm_str += " mov edx, eax\n";
                asm_str += &format!(" jmp {done_label}\n");

                asm_str += &format!("{zero_label}:\n");
                asm_str += " mov edx, -1\n";
                asm_str += &format!(" jmp {done_label}\n");

                // Overflow: result is the dividend (still in eax); falls through to done.
                asm_str += &format!("{overflow_label}:\n");
                asm_str += " mov edx, eax\n";
            }
            DivRemOpcode::DIVU => {
                asm_str += &format!(" test {reg_c}, {reg_c}\n");
                asm_str += &format!(" je {zero_label}\n");
                // Unsigned divide of edx:eax (edx == 0); quotient -> eax.
                asm_str += &format!(" div {reg_c}\n");
                asm_str += " mov edx, eax\n";
                asm_str += &format!(" jmp {done_label}\n");

                // Divisor == 0 -> result is 0xFFFFFFFF; falls through to done.
                asm_str += &format!("{zero_label}:\n");
                asm_str += " mov edx, -1\n";
            }
            DivRemOpcode::REM => {
                asm_str += &format!(" test {reg_c}, {reg_c}\n");
                asm_str += &format!(" je {zero_label}\n");
                // 0x80000000 % -1 would overflow `idiv`; its RISC-V result is 0.
                asm_str += " cmp eax, 0x80000000\n";
                asm_str += &format!(" jne {normal_label}\n");
                asm_str += &format!(" cmp {reg_c}, -1\n");
                asm_str += &format!(" jne {normal_label}\n");
                asm_str += " mov edx, 0\n";
                asm_str += &format!(" jmp {done_label}\n");

                asm_str += &format!("{normal_label}:\n");
                // Sign-extend, signed divide; the remainder is already in edx.
                asm_str += " cdq\n";
                asm_str += &format!(" idiv {reg_c}\n");
                asm_str += &format!(" jmp {done_label}\n");

                // Divisor == 0 -> result is the dividend; falls through to done.
                asm_str += &format!("{zero_label}:\n");
                asm_str += " mov edx, eax\n";
            }
            DivRemOpcode::REMU => {
                asm_str += &format!(" test {reg_c}, {reg_c}\n");
                asm_str += &format!(" je {zero_label}\n");
                // Unsigned divide of edx:eax (edx == 0); the remainder lands in edx.
                asm_str += &format!(" div {reg_c}\n");
                asm_str += &format!(" jmp {done_label}\n");

                // Divisor == 0 -> result is the dividend; falls through to done.
                asm_str += &format!("{zero_label}:\n");
                asm_str += " mov edx, eax\n";
            }
        }
        asm_str += &format!("{done_label}:\n");
        // Write the result in edx back to rd via `gpr_to_xmm`.
        asm_str += &gpr_to_xmm("edx", a_reg as u8);

        Ok(asm_str)
    }

    /// All four DIV/REM-family opcodes have a code path in `generate_x86_asm`.
    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
}
227
228impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
229 for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
230where
231 F: PrimeField32,
232{
233 fn metered_pre_compute_size(&self) -> usize {
234 size_of::<E2PreCompute<DivRemPreCompute>>()
235 }
236
237 #[cfg(not(feature = "tco"))]
238 fn metered_pre_compute<Ctx>(
239 &self,
240 chip_idx: usize,
241 pc: u32,
242 inst: &Instruction<F>,
243 data: &mut [u8],
244 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
245 where
246 Ctx: MeteredExecutionCtxTrait,
247 {
248 let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
249 data.chip_idx = chip_idx as u32;
250 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
251 dispatch!(execute_e2_handler, local_opcode)
252 }
253
254 #[cfg(feature = "tco")]
255 fn metered_handler<Ctx>(
256 &self,
257 chip_idx: usize,
258 pc: u32,
259 inst: &Instruction<F>,
260 data: &mut [u8],
261 ) -> Result<Handler<F, Ctx>, StaticProgramError>
262 where
263 Ctx: MeteredExecutionCtxTrait,
264 {
265 let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
266 data.chip_idx = chip_idx as u32;
267 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
268 dispatch!(execute_e2_handler, local_opcode)
269 }
270}
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    /// Metered AOT compilation is available for every DIV/REM-family opcode.
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    /// Emits the unmetered instruction body followed by the metering bookkeeping.
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        use crate::common::{update_adapter_heights_asm, update_height_change_asm};

        let mut asm = self.generate_x86_asm(inst, pc)?;
        // Height change of 1 for this chip, like `on_height_change(chip_idx, 1)`
        // in the interpreter's metered handler.
        asm.push_str(&update_height_change_asm(chip_idx, 1)?);
        // The instruction makes three register-address-space accesses (rs1 read,
        // rs2 read, rd write); presumably one adapter-height update each.
        for _ in 0..3 {
            asm.push_str(&update_adapter_heights_asm(config, RV32_REGISTER_AS)?);
        }
        Ok(asm)
    }
}
302
303#[inline(always)]
304unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
305 pre_compute: &DivRemPreCompute,
306 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
307) {
308 let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
309 let rs2 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
310 let result = <OP as DivRemOp>::compute(rs1, rs2);
311 exec_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &result);
312 let pc = exec_state.pc();
313 exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
314}
315
/// Non-metered (E1) execution of one DIV/REM-family instruction.
///
/// Reinterprets the bytes at `pre_compute` as a `DivRemPreCompute` and delegates to
/// `execute_e12_impl`. `dispatch!` refers to an `execute_e1_handler`, which is
/// presumably generated from this function by `#[create_handler]` — TODO confirm
/// against the macro.
///
/// # Safety
///
/// `pre_compute` must be non-null and valid for reads of
/// `size_of::<DivRemPreCompute>()` bytes, as `std::slice::from_raw_parts`
/// requires. The requirements of `execute_e12_impl` also apply.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // View the raw pointer as a byte slice, then borrow it as the typed struct.
    let pre_compute: &DivRemPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<DivRemPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}
326
/// Metered (E2) execution of one DIV/REM-family instruction.
///
/// Reinterprets the bytes at `pre_compute` as an `E2PreCompute<DivRemPreCompute>`
/// (chip index + decoded operands). It then calls `ctx.on_height_change(chip_idx, 1)`
/// and runs the shared `execute_e12_impl`. `dispatch!` refers to an
/// `execute_e2_handler`, presumably generated from this function by
/// `#[create_handler]` — TODO confirm.
///
/// # Safety
///
/// `pre_compute` must be non-null and valid for reads of
/// `size_of::<E2PreCompute<DivRemPreCompute>>()` bytes, as
/// `std::slice::from_raw_parts` requires. The requirements of `execute_e12_impl`
/// also apply.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: DivRemOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<DivRemPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<DivRemPreCompute>>())
            .borrow();
    // Metering happens before execution: a height change of 1 for this chip.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}
341
/// Pure arithmetic for one RV32M division-family opcode, on little-endian
/// register bytes.
trait DivRemOp {
    /// Computes the `rd` bytes from the `rs1` (dividend) and `rs2` (divisor) bytes.
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
}
/// Signed division (`DIV`).
struct DivOp;
/// Unsigned division (`DIVU`).
struct DivuOp;
/// Signed remainder (`REM`).
struct RemOp;
/// Unsigned remainder (`REMU`).
struct RemuOp;

impl DivRemOp for DivOp {
    /// Signed quotient, truncated toward zero.
    ///
    /// A zero divisor gives `-1` (all bits set). The overflowing `i32::MIN / -1`
    /// gives the dividend, which is exactly what `wrapping_div` returns.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let dividend = i32::from_le_bytes(rs1);
        let quotient = match i32::from_le_bytes(rs2) {
            0 => -1,
            divisor => dividend.wrapping_div(divisor),
        };
        quotient.to_le_bytes()
    }
}

impl DivRemOp for DivuOp {
    /// Unsigned quotient; a zero divisor gives `u32::MAX`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let dividend = u32::from_le_bytes(rs1);
        let divisor = u32::from_le_bytes(rs2);
        dividend
            .checked_div(divisor)
            .unwrap_or(u32::MAX)
            .to_le_bytes()
    }
}

impl DivRemOp for RemOp {
    /// Signed remainder, with the sign of the dividend.
    ///
    /// A zero divisor gives the dividend. The overflowing `i32::MIN % -1` gives `0`,
    /// which is exactly what `wrapping_rem` returns.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let dividend = i32::from_le_bytes(rs1);
        match i32::from_le_bytes(rs2) {
            0 => rs1,
            divisor => dividend.wrapping_rem(divisor).to_le_bytes(),
        }
    }
}

impl DivRemOp for RemuOp {
    /// Unsigned remainder; a zero divisor gives the dividend.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let dividend = u32::from_le_bytes(rs1);
        dividend
            .checked_rem(u32::from_le_bytes(rs2))
            .map_or(rs1, u32::to_le_bytes)
    }
}
400}