openvm_rv32im_circuit/divrem/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6#[cfg(feature = "aot")]
7use openvm_circuit::arch::aot::common::REG_A_W;
8use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
9use openvm_circuit_primitives_derive::AlignedBytesBorrow;
10use openvm_instructions::{
11    instruction::Instruction,
12    program::DEFAULT_PC_STEP,
13    riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
14    LocalOpcode,
15};
16use openvm_rv32im_transpiler::DivRemOpcode;
17use openvm_stark_backend::p3_field::PrimeField32;
18
19use super::core::DivRemExecutor;
20#[cfg(feature = "aot")]
21use crate::common::{gpr_to_xmm, xmm_to_gpr};
22
/// Operands of one DIV/DIVU/REM/REMU instruction, decoded once by `pre_compute_impl`.
///
/// Each field is an offset in the register address space, stored as `u8`, that the execute
/// functions pass directly to `vm_read` / `vm_write`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct DivRemPreCompute {
    /// Destination register `rd` (written with the result).
    a: u8,
    /// First source register `rs1` (the dividend).
    b: u8,
    /// Second source register `rs2` (the divisor).
    c: u8,
}
30
31impl<A, const LIMB_BITS: usize> DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
32    #[inline(always)]
33    fn pre_compute_impl<F: PrimeField32>(
34        &self,
35        pc: u32,
36        inst: &Instruction<F>,
37        data: &mut DivRemPreCompute,
38    ) -> Result<DivRemOpcode, StaticProgramError> {
39        let &Instruction {
40            opcode, a, b, c, d, ..
41        } = inst;
42        let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset));
43        if d.as_canonical_u32() != RV32_REGISTER_AS {
44            return Err(StaticProgramError::InvalidInstruction(pc));
45        }
46        let pre_compute: &mut DivRemPreCompute = data.borrow_mut();
47        *pre_compute = DivRemPreCompute {
48            a: a.as_canonical_u32() as u8,
49            b: b.as_canonical_u32() as u8,
50            c: c.as_canonical_u32() as u8,
51        };
52        Ok(local_opcode)
53    }
54}
55
/// Expands to `Ok(f)`, where `f` is the handler function `$handler` instantiated with the
/// `DivRemOp` marker type selected by the `DivRemOpcode` value bound to `$opcode`.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            // Signed variants.
            DivRemOpcode::DIV => Ok($handler::<_, _, DivOp>),
            DivRemOpcode::REM => Ok($handler::<_, _, RemOp>),
            // Unsigned variants.
            DivRemOpcode::DIVU => Ok($handler::<_, _, DivuOp>),
            DivRemOpcode::REMU => Ok($handler::<_, _, RemuOp>),
        }
    };
}
66
67impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
68    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
69where
70    F: PrimeField32,
71{
72    #[inline(always)]
73    fn pre_compute_size(&self) -> usize {
74        size_of::<DivRemPreCompute>()
75    }
76
77    #[cfg(not(feature = "tco"))]
78    #[inline(always)]
79    fn pre_compute<Ctx: ExecutionCtxTrait>(
80        &self,
81        pc: u32,
82        inst: &Instruction<F>,
83        data: &mut [u8],
84    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
85        let data: &mut DivRemPreCompute = data.borrow_mut();
86        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
87        dispatch!(execute_e1_handler, local_opcode)
88    }
89
90    #[cfg(feature = "tco")]
91    fn handler<Ctx>(
92        &self,
93        pc: u32,
94        inst: &Instruction<F>,
95        data: &mut [u8],
96    ) -> Result<Handler<F, Ctx>, StaticProgramError>
97    where
98        Ctx: ExecutionCtxTrait,
99    {
100        let data: &mut DivRemPreCompute = data.borrow_mut();
101        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
102        dispatch!(execute_e1_handler, local_opcode)
103    }
104}
105
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    /// Emits x86-64 (Intel syntax) assembly for one DIV/DIVU/REM/REMU instruction at `pc`.
    ///
    /// `rs1` is loaded into `eax` because x86 `div`/`idiv` divide `edx:eax`, leaving the
    /// quotient in `eax` and the remainder in `edx`. `rs2` is loaded into the register that
    /// `xmm_to_gpr` returns (requested as `REG_A_W`). The RISC-V result is left in `edx` and
    /// stored to `rd`.
    ///
    /// x86 faults (`#DE`) on a zero divisor and on `i32::MIN / -1`, while RISC-V defines both
    /// results, so those cases are branched around before dividing:
    /// - zero divisor: quotient = all ones, remainder = `rs1`;
    /// - signed overflow: quotient = `rs1`, remainder = 0.
    ///
    /// Local labels embed `pc` and the opcode name so they are unique per instruction.
    ///
    /// # Errors
    /// Returns [`AotError::InvalidInstruction`] if `d` is not the register address space.
    fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
        let &Instruction {
            opcode, a, b, c, d, ..
        } = inst;
        let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        if d.as_canonical_u32() != RV32_REGISTER_AS {
            return Err(AotError::InvalidInstruction);
        }

        // Operands are byte offsets into the register address space; /4 gives indices.
        let a_reg = a.as_canonical_u32() / 4;
        let b_reg = b.as_canonical_u32() / 4;
        let c_reg = c.as_canonical_u32() / 4;

        let mut asm_str = String::new();
        // Dividend: forced into eax, as div/idiv require.
        let (_, delta_str_b) = &xmm_to_gpr(b_reg as u8, "eax", true);
        asm_str += delta_str_b;
        // Divisor: whichever register `xmm_to_gpr` hands back.
        // NOTE(review): the sequences below assume `reg_c` is neither eax nor edx, since both
        // are overwritten before/by the divide — confirm against `xmm_to_gpr` and `REG_A_W`.
        let (reg_c, delta_str_c) = &xmm_to_gpr(c_reg as u8, REG_A_W, false);
        asm_str += delta_str_c;

        let op_name = match local_opcode {
            DivRemOpcode::DIV => "div",
            DivRemOpcode::DIVU => "divu",
            DivRemOpcode::REM => "rem",
            DivRemOpcode::REMU => "remu",
        };
        let label_prefix = format!(".asm_divrem_{pc}_{op_name}");
        let label = |suffix: &str| format!("{label_prefix}__{suffix}");
        let done_label = label("done");
        let zero_label = label("divisor_zero");
        let overflow_label = label("overflow");
        let normal_label = label("normal");

        // Every variant first diverts a zero divisor to its own fix-up path.
        asm_str += &format!("   test {reg_c}, {reg_c}\n");
        asm_str += &format!("   je {zero_label}\n");

        match local_opcode {
            DivRemOpcode::DIV => {
                // i32::MIN / -1 would fault in idiv; RISC-V returns rs1 (still in eax).
                asm_str += "   cmp eax, 0x80000000\n";
                asm_str += &format!("   jne {normal_label}\n");
                asm_str += &format!("   cmp {reg_c}, -1\n");
                asm_str += &format!("   je {overflow_label}\n");

                asm_str += &format!("{normal_label}:\n");
                // Sign-extend eax into edx:eax; afterwards eax = quotient, edx = remainder.
                asm_str += "   cdq\n";
                asm_str += &format!("   idiv {reg_c}\n");
                asm_str += "   mov edx, eax\n";
                asm_str += &format!("   jmp {done_label}\n");

                asm_str += &format!("{zero_label}:\n");
                asm_str += "   mov edx, -1\n";
                asm_str += &format!("   jmp {done_label}\n");

                asm_str += &format!("{overflow_label}:\n");
                asm_str += "   mov edx, eax\n";
            }
            DivRemOpcode::DIVU => {
                // Zero-extend the dividend: edx:eax = 0:eax.
                asm_str += "   mov edx, 0\n";
                // eax = quotient, edx = remainder.
                asm_str += &format!("   div {reg_c}\n");
                asm_str += "   mov edx, eax\n";
                asm_str += &format!("   jmp {done_label}\n");

                asm_str += &format!("{zero_label}:\n");
                asm_str += "   mov edx, -1\n";
            }
            DivRemOpcode::REM => {
                // i32::MIN % -1 would fault in idiv; RISC-V returns 0.
                asm_str += "   cmp eax, 0x80000000\n";
                asm_str += &format!("   jne {normal_label}\n");
                asm_str += &format!("   cmp {reg_c}, -1\n");
                asm_str += &format!("   jne {normal_label}\n");
                asm_str += "   mov edx, 0\n";
                asm_str += &format!("   jmp {done_label}\n");

                asm_str += &format!("{normal_label}:\n");
                // Sign-extend eax into edx:eax; afterwards edx = remainder.
                asm_str += "   cdq\n";
                asm_str += &format!("   idiv {reg_c}\n");
                asm_str += &format!("   jmp {done_label}\n");

                asm_str += &format!("{zero_label}:\n");
                asm_str += "   mov edx, eax\n";
            }
            DivRemOpcode::REMU => {
                // Zero-extend the dividend: edx:eax = 0:eax; afterwards edx = remainder.
                asm_str += "   mov edx, 0\n";
                asm_str += &format!("   div {reg_c}\n");
                asm_str += &format!("   jmp {done_label}\n");

                asm_str += &format!("{zero_label}:\n");
                asm_str += "   mov edx, eax\n";
            }
        }
        asm_str += &format!("{done_label}:\n");
        asm_str += &gpr_to_xmm("edx", a_reg as u8);

        Ok(asm_str)
    }

    /// Every DIV/DIVU/REM/REMU instruction can be compiled ahead of time.
    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
}
227
228impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
229    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
230where
231    F: PrimeField32,
232{
233    fn metered_pre_compute_size(&self) -> usize {
234        size_of::<E2PreCompute<DivRemPreCompute>>()
235    }
236
237    #[cfg(not(feature = "tco"))]
238    fn metered_pre_compute<Ctx>(
239        &self,
240        chip_idx: usize,
241        pc: u32,
242        inst: &Instruction<F>,
243        data: &mut [u8],
244    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
245    where
246        Ctx: MeteredExecutionCtxTrait,
247    {
248        let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
249        data.chip_idx = chip_idx as u32;
250        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
251        dispatch!(execute_e2_handler, local_opcode)
252    }
253
254    #[cfg(feature = "tco")]
255    fn metered_handler<Ctx>(
256        &self,
257        chip_idx: usize,
258        pc: u32,
259        inst: &Instruction<F>,
260        data: &mut [u8],
261    ) -> Result<Handler<F, Ctx>, StaticProgramError>
262    where
263        Ctx: MeteredExecutionCtxTrait,
264    {
265        let data: &mut E2PreCompute<DivRemPreCompute> = data.borrow_mut();
266        data.chip_idx = chip_idx as u32;
267        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
268        dispatch!(execute_e2_handler, local_opcode)
269    }
270}
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for DivRemExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    /// Every DIV/DIVU/REM/REMU instruction can be compiled with metering.
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    /// Emits the unmetered instruction body, then the bookkeeping for one unit of height on
    /// `chip_idx` and three register-address-space adapter accesses.
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        use crate::common::{update_adapter_heights_asm, update_height_change_asm};

        let mut out = self.generate_x86_asm(inst, pc)?;
        out += &update_height_change_asm(chip_idx, 1)?;

        // Adapter accesses, in order: read [b:4]_1, read [c:4]_1, write [a:4]_1.
        for _access in 0..3 {
            out += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }

        Ok(out)
    }
}
302
303#[inline(always)]
304unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
305    pre_compute: &DivRemPreCompute,
306    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
307) {
308    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
309    let rs2 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
310    let result = <OP as DivRemOp>::compute(rs1, rs2);
311    exec_state.vm_write::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32, &result);
312    let pc = exec_state.pc();
313    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
314}
315
// Unmetered entry point for one DIV/DIVU/REM/REMU instruction: views the raw pre-compute
// bytes as a `DivRemPreCompute` and runs the shared body.
// NOTE(review): `#[create_handler]` presumably generates the `execute_e1_handler` that
// `dispatch!` selects — confirm against the macro.
//
// Safety: `pre_compute` must be valid for reads of `size_of::<DivRemPreCompute>()` bytes
// (the length handed to `from_raw_parts` below), presumably as written by `pre_compute` /
// `handler`.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: DivRemOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Byte slice -> typed view, presumably via the `AlignedBytesBorrow` derive.
    let pre_compute: &DivRemPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<DivRemPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}
326
// Metered entry point for one DIV/DIVU/REM/REMU instruction: views the raw pre-compute bytes
// as an `E2PreCompute<DivRemPreCompute>`, reports a height change of 1 for its `chip_idx` to
// the metering context, then runs the shared body.
// NOTE(review): `#[create_handler]` presumably generates the `execute_e2_handler` that
// `dispatch!` selects — confirm against the macro.
//
// Safety: `pre_compute` must be valid for reads of
// `size_of::<E2PreCompute<DivRemPreCompute>>()` bytes (the length handed to
// `from_raw_parts` below), presumably as written by `metered_pre_compute` / `metered_handler`.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: DivRemOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Byte slice -> typed view, presumably via the `AlignedBytesBorrow` derive.
    let pre_compute: &E2PreCompute<DivRemPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<DivRemPreCompute>>())
            .borrow();
    // Account one unit of height for this chip before executing (presumably one trace row).
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}
341
/// 32-bit division semantics on little-endian register bytes.
///
/// Every implementation returns a defined value for division by zero and (for the signed
/// variants) for `i32::MIN` divided by `-1`, rather than panicking.
trait DivRemOp {
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4];
}

/// Signed quotient (`DIV`).
struct DivOp;
/// Unsigned quotient (`DIVU`).
struct DivuOp;
/// Signed remainder (`REM`).
struct RemOp;
/// Unsigned remainder (`REMU`).
struct RemuOp;

impl DivRemOp for DivOp {
    /// `rs1 / rs2` truncated toward zero; all ones if `rs2 == 0`; `rs1` for `i32::MIN / -1`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let divisor = i32::from_le_bytes(rs2);
        if divisor == 0 {
            return [u8::MAX; 4];
        }
        // `wrapping_div` maps `i32::MIN / -1` back to `i32::MIN`, i.e. `rs1` unchanged.
        i32::from_le_bytes(rs1).wrapping_div(divisor).to_le_bytes()
    }
}

impl DivRemOp for DivuOp {
    /// Unsigned `rs1 / rs2`; all ones if `rs2 == 0`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        u32::from_le_bytes(rs1)
            .checked_div(u32::from_le_bytes(rs2))
            .map_or([u8::MAX; 4], u32::to_le_bytes)
    }
}

impl DivRemOp for RemOp {
    /// Signed `rs1 % rs2` (sign follows `rs1`); `rs1` if `rs2 == 0`; `0` for `i32::MIN % -1`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        let divisor = i32::from_le_bytes(rs2);
        if divisor == 0 {
            return rs1;
        }
        // `wrapping_rem` yields 0 for `i32::MIN % -1` instead of overflowing.
        i32::from_le_bytes(rs1).wrapping_rem(divisor).to_le_bytes()
    }
}

impl DivRemOp for RemuOp {
    /// Unsigned `rs1 % rs2`; `rs1` if `rs2 == 0`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] {
        u32::from_le_bytes(rs1)
            .checked_rem(u32::from_le_bytes(rs2))
            .map_or(rs1, u32::to_le_bytes)
    }
}