use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

#[cfg(feature = "aot")]
use openvm_circuit::arch::aot::common::convert_x86_reg;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::LessThanOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::core::LessThanExecutor;
#[cfg(feature = "aot")]
use crate::less_than::execution::aot::common::Width;
#[allow(unused_imports)]
use crate::{adapters::imm_to_bytes, common::*};

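/// Pre-decoded operands for a SLT/SLTU instruction. `a` and `b` are the rd and rs1
/// register byte offsets; `c` holds either the decoded immediate (via `imm_to_bytes`)
/// or the rs2 register byte offset, depending on the instruction's `e` address space.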
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct LessThanPreCompute {
    c: u32,
    a: u8,
    b: u8,
}

impl<A, const LIMB_BITS: usize> LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS> {
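    /// Decodes `inst` into `data`, rejecting instructions whose address spaces are not
    /// register/immediate, and returns the `(is_imm, is_sltu)` pair used to select a
    /// monomorphized handler.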
    #[inline(always)]
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut LessThanPreCompute,
    ) -> Result<(bool, bool), StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS
            || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS)
        {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        let local_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let is_imm = e_u32 == RV32_IMM_AS;
        let c_u32 = c.as_canonical_u32();

        *data = LessThanPreCompute {
            c: if is_imm {
                u32::from_le_bytes(imm_to_bytes(c_u32))
            } else {
                c_u32
            },
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        Ok((is_imm, local_opcode == LessThanOpcode::SLTU))
    }
}

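/// Expands to the handler monomorphized for the given `(is_imm, is_sltu)` pair, so the
/// immediate-vs-register and signed-vs-unsigned branches are fixed at pre-compute time.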
macro_rules! dispatch {
    ($execute_impl:ident, $is_imm:ident, $is_sltu:ident) => {
        match ($is_imm, $is_sltu) {
            (true, true) => Ok($execute_impl::<_, _, true, true>),
            (true, false) => Ok($execute_impl::<_, _, true, false>),
            (false, true) => Ok($execute_impl::<_, _, false, true>),
            (false, false) => Ok($execute_impl::<_, _, false, false>),
        }
    };
}

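// Interpreted execution: decode once into `LessThanPreCompute`, then return the handler
// selected by `dispatch!`.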
impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    #[inline(always)]
    fn pre_compute_size(&self) -> usize {
        size_of::<LessThanPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    #[inline(always)]
    fn pre_compute<Ctx: ExecutionCtxTrait>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
        let pre_compute: &mut LessThanPreCompute = data.borrow_mut();
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?;
        dispatch!(execute_e1_handler, is_imm, is_sltu)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let pre_compute: &mut LessThanPreCompute = data.borrow_mut();
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?;
        dispatch!(execute_e1_handler, is_imm, is_sltu)
    }
}

#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

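    /// Emits x86 assembly for one SLT/SLTU instruction: rs1 is loaded into a scratch GPR via
    /// `xmm_to_gpr`, compared against either the sign-extended immediate or the rs2 value, and
    /// the 0/1 result is zero-extended and stored back to rd via `gpr_to_xmm`.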
    fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
        // Interpret the low 24 bits of a field element as a signed value (sign-extend from
        // bit 23), then narrow to i16.
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let mut asm_str = String::new();
        let a: i16 = to_i16(inst.a);
        let b: i16 = to_i16(inst.b);
        let c: i16 = to_i16(inst.c);
        let e: i16 = to_i16(inst.e);
        assert!(a % 4 == 0, "instruction.a must be a multiple of 4");
        assert!(b % 4 == 0, "instruction.b must be a multiple of 4");

        // SETL (set if less, signed) for SLT; SETB (set if below, unsigned) for SLTU.
        let mut asm_opcode = String::new();
        if inst.opcode == LessThanOpcode::SLT.global_opcode() {
            asm_opcode += "SETL";
        } else if inst.opcode == LessThanOpcode::SLTU.global_opcode() {
            asm_opcode += "SETB";
        }

        if e == 0 {
            // Immediate operand (address space 0): compare rs1 directly against the constant.
            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, REG_B_W, false);
            asm_str += delta_str_b;

            let reg_a = RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap_or(REG_A_W);

            let reg_a_w8l =
                convert_x86_reg(reg_a, Width::W8L).ok_or(AotError::InvalidInstruction)?;

            asm_str += &format!(" CMP {reg_b}, {c}\n");
            asm_str += &format!(" {asm_opcode} {reg_a_w8l}\n");
            asm_str += &format!(" MOVZX {reg_a}, {reg_a_w8l}\n");

            asm_str += &gpr_to_xmm(reg_a, (a / 4) as u8);
        } else {
            // Register operand: load rs2 into a second scratch GPR and compare register-register.
            let (reg_b, delta_str_b) = &xmm_to_gpr((b / 4) as u8, REG_B_W, false);
            asm_str += delta_str_b;

            let (reg_c, delta_str_c) = &xmm_to_gpr((c / 4) as u8, REG_C_W, false);
            asm_str += delta_str_c;

            let reg_a = RISCV_TO_X86_OVERRIDE_MAP[(a / 4) as usize].unwrap_or(REG_A_W);

            let reg_a_w8l =
                convert_x86_reg(reg_a, Width::W8L).ok_or(AotError::InvalidInstruction)?;

            asm_str += &format!(" CMP {reg_b}, {reg_c}\n");
            asm_str += &format!(" {asm_opcode} {reg_a_w8l}\n");
            asm_str += &format!(" MOVZX {reg_a}, {reg_a_w8l}\n");

            asm_str += &gpr_to_xmm(reg_a, (a / 4) as u8);
        }

        Ok(asm_str)
    }
}

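// Metered interpretation: same decode as above, plus the chip index so the handler can
// report a height change to the metering context.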
impl<F, A, const LIMB_BITS: usize> InterpreterMeteredExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<LessThanPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, is_imm, is_sltu)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let pre_compute: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
        pre_compute.chip_idx = chip_idx as u32;
        let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
        dispatch!(execute_e2_handler, is_imm, is_sltu)
    }
}

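// Metered AOT codegen: reuse the plain AOT assembly and append height-accounting code.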
#[cfg(feature = "aot")]
impl<F, A, const LIMB_BITS: usize> AotMeteredExecutor<F>
    for LessThanExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
    F: PrimeField32,
{
    fn is_aot_metered_supported(&self, _inst: &Instruction<F>) -> bool {
        true
    }

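    // Three register-space adapter-height updates are emitted below: the conditional one covers
    // the rs2 read (only when `e` is not the immediate address space); the other two presumably
    // account for the rs1 read and the rd write.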
    fn generate_x86_metered_asm(
        &self,
        inst: &Instruction<F>,
        pc: u32,
        chip_idx: usize,
        config: &SystemConfig,
    ) -> Result<String, AotError> {
        let mut asm_str = self.generate_x86_asm(inst, pc)?;
        asm_str += &update_height_change_asm(chip_idx, 1)?;
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        if inst.e.as_canonical_u32() != RV32_IMM_AS {
            asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        }
        asm_str += &update_adapter_heights_asm(config, RV32_REGISTER_AS)?;
        Ok(asm_str)
    }
}

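/// Shared execution body for both the plain and metered paths: reads rs1, reads rs2 (or uses
/// the pre-decoded immediate when `E_IS_IMM`), compares as `u32` or `i32` depending on
/// `IS_U32`, writes the 0/1 result to rd, and advances the pc by `DEFAULT_PC_STEP`.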
#[inline(always)]
unsafe fn execute_e12_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const E_IS_IMM: bool,
    const IS_U32: bool,
>(
    pre_compute: &LessThanPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2 = if E_IS_IMM {
        pre_compute.c.to_le_bytes()
    } else {
        exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c)
    };
    let cmp_result = if IS_U32 {
        u32::from_le_bytes(rs1) < u32::from_le_bytes(rs2)
    } else {
        i32::from_le_bytes(rs1) < i32::from_le_bytes(rs2)
    };
    let mut rd = [0u8; RV32_REGISTER_NUM_LIMBS];
    rd[0] = cmp_result as u8;
    exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}

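/// E1 (plain) entry point: reinterprets the raw pre-compute bytes as `LessThanPreCompute`
/// and delegates to `execute_e12_impl`. The `execute_e1_handler` referenced by `dispatch!`
/// is presumably generated from this function by `#[create_handler]`.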
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<
    F: PrimeField32,
    CTX: ExecutionCtxTrait,
    const E_IS_IMM: bool,
    const IS_U32: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &LessThanPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<LessThanPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, E_IS_IMM, IS_U32>(pre_compute, exec_state);
}

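/// E2 (metered) entry point: additionally reports a height change of 1 for this chip to the
/// metering context before running the shared body.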
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<
    F: PrimeField32,
    CTX: MeteredExecutionCtxTrait,
    const E_IS_IMM: bool,
    const IS_U32: bool,
>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<LessThanPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<LessThanPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, E_IS_IMM, IS_U32>(&pre_compute.data, exec_state);
}