openvm_rv32im_circuit/less_than/execution.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

#[cfg(feature = "aot")]
use openvm_circuit::arch::aot::common::convert_x86_reg;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::LessThanOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::core::LessThanExecutor;
#[cfg(feature = "aot")]
use crate::less_than::execution::aot::common::Width;
#[allow(unused_imports)]
use crate::{adapters::imm_to_bytes, common::*};

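/// Operands decoded once at pre-compute time: `a` and `b` are register byte offsets,
/// and `c` holds either the decoded immediate or the rs2 register byte offset.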
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct LessThanPreCompute {
    c: u32,
    a: u8,
    b: u8,
}

impl<A, const LIMB_BITS: usize> LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
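    /// Validates the operands, fills `data`, and returns `(is_imm, is_sltu)` so the
    /// caller can select the matching monomorphized handler.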
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut LessThanPreCompute,
    ) -> Result<(bool, bool), StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let local_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();

        *data = LessThanPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok((is_imm, local_opcode == LessThanOpcode::SLTU))
    }
}

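// Picks the execute function monomorphized over (immediate operand?, unsigned compare?).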
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $is_sltu:ident) => {
        match ($is_imm, $is_sltu) {
            (true, true) => Ok($execute_impl::<_, _, true, true>),
            (true, false) => Ok($execute_impl::<_, _, true, false>),
            (false, true) => Ok($execute_impl::<_, _, false, true>),
            (false, false) => Ok($execute_impl::<_, _, false, false>),
        }
    };
}

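// Plain (unmetered) interpreter execution.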
impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<LessThanPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    #[inline(always)]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut LessThanPreCompute = data.borrow_mut();
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?;
        dispatch!(execute_e1_handler, is_imm, is_sltu)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut LessThanPreCompute = data.borrow_mut();
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?;
        dispatch!(execute_e1_handler, is_imm, is_sltu)
    }
}

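// AOT execution: emit x86-64 assembly for this instruction rather than interpreting it.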
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }
    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
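        // Decode a field-element operand: take its low 24 bits as a signed value and
        // truncate to i16 (register byte offsets and RV32 immediates both fit).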
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let mut asm_str = String::new();
        let a: i16 = to_i16(inst.a);
        let b: i16 = to_i16(inst.b);
        let c: i16 = to_i16(inst.c);
        let e: i16 = to_i16(inst.e);
        assert!(a % 4 == 0, "instruction.a must be a multiple of 4");
        assert!(b % 4 == 0, "instruction.b must be a multiple of 4");

        let mut asm_opcode = String::new();
        if inst.opcode == LessThanOpcode::SLT.global_opcode() {
            asm_opcode += "SETL";
        } else if inst.opcode == LessThanOpcode::SLTU.global_opcode() {
            asm_opcode += "SETB";
        }

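        // e == RV32_IMM_AS (0): `c` is an immediate operand; otherwise `c` names an rs2 register.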
        if e == 0 {
            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, REG_B_W, false);
            asm_str += delta_str_b;

            let reg_a = if RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].is_some() {
                RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap()
            } else {
                REG_A_W
            };

            let reg_a_w8l =
                convert_x86_reg(reg_a, Width::W8L).ok_or(AotError::InvalidInstruction)?;

            asm_str += &format!("   CMP {reg_b}, {c}\n");
            asm_str += &format!("   {asm_opcode} {reg_a_w8l}\n");
            asm_str += &format!("   MOVZX {reg_a}, {reg_a_w8l}\n");

            asm_str += &gpr_to_xmm(reg_a, (a / 4) as u8);
        } else {
            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, REG_B_W, false);
            asm_str += delta_str_b;

            let (reg_c, delta_str_c) = &xmm_to_gpr((c / 4) as u8, REG_C_W, false);
            asm_str += delta_str_c;

            let reg_a = if RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].is_some() {
                RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap()
            } else {
                REG_A_W
            };

            let reg_a_w8l =
                convert_x86_reg(reg_a, Width::W8L).ok_or(AotError::InvalidInstruction)?;

            asm_str += &format!("   CMP {reg_b}, {reg_c}\n");
            asm_str += &format!("   {asm_opcode} {reg_a_w8l}\n");
            asm_str += &format!("   MOVZX {reg_a}, {reg_a_w8l}\n");

            asm_str += &gpr_to_xmm(reg_a, (a / 4) as u8);
        }

        // No branch is emitted; execution falls through to the next instruction.
        Ok(asm_str)
    }
}

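// Metered interpreter execution: also records the chip index so trace-height changes can be attributed.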
impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<LessThanPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, is_imm, is_sltu)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, is_imm, is_sltu)
    }
}
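
// Metered AOT execution: the base assembly plus chip and memory-adapter height bookkeeping.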
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;
        asm_str += &update_height_change_asm(chip_idx, 1)?;
        // read [b:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        if inst.e.as_canonical_u32() != RV32_IMM_AS {
            // read [c:4]_e
            asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }
        // write [a:4]_1
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        Ok(asm_str)
    }
}

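/// Shared E1/E2 body: reads rs1 and rs2 (or uses the immediate), does a signed or
/// unsigned less-than comparison, writes the 0/1 result to rd, and advances the pc.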
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const E_IS_IMM: bool,
    const IS_U32: bool,
>(
    pre_compute: &LessThanPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if E_IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let cmp_result = if IS_U32 {
        u32::from_le_bytes(rs1) < u32::from_le_bytes(rs2)
    } else {
        i32::from_le_bytes(rs1) < i32::from_le_bytes(rs2)
    };
    let mut rd = [0u8; RV32_REGISTER_NUM_LIMBS];
    rd[0] = cmp_result as u8;
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

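/// E1 entry point: reinterprets the raw pre-compute bytes and runs the shared body.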
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const E_IS_IMM: bool,
    const IS_U32: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &LessThanPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<LessThanPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, E_IS_IMM, IS_U32>(pre_compute, exec_state);
}

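/// E2 entry point: like E1, but first reports a height change of 1 for this chip.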
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const E_IS_IMM: bool,
    const IS_U32: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<LessThanPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<LessThanPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, E_IS_IMM, IS_U32>(&pre_compute.data, exec_state);
}