1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_bigint_transpiler::Rv32LessThan256Opcode;
7use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
8use openvm_circuit_primitives_derive::AlignedBytesBorrow;
9use openvm_instructions::{
10 instruction::Instruction,
11 program::DEFAULT_PC_STEP,
12 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
13 LocalOpcode,
14};
15use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
16use openvm_rv32im_circuit::LessThanExecutor;
17use openvm_rv32im_transpiler::LessThanOpcode;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::{common, Rv32LessThan256Executor, INT256_NUM_LIMBS};
21
/// Adapter executor type accepted by [`Rv32LessThan256Executor::new`].
///
/// Instantiates `Rv32HeapAdapterExecutor` with the const arguments `2`,
/// `INT256_NUM_LIMBS` and `INT256_NUM_LIMBS`.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32LessThan256Executor {
25 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26 Self(LessThanExecutor::new(adapter, offset))
27 }
28}
29
/// Operands of a 256-bit less-than instruction, cached when the instruction is pre-computed.
///
/// Each field is the matching instruction operand truncated to `u8` (see `pre_compute_impl`).
/// At execution time the fields are used as pointers into the register address space
/// (`RV32_REGISTER_AS`). Instances are obtained by borrowing pre-compute byte buffers through
/// `Borrow`/`BorrowMut`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct LessThanPreCompute {
    /// Register whose 4-byte value is the memory address the result is written to.
    a: u8,
    /// Register whose 4-byte value is the memory address of the first (left) operand.
    b: u8,
    /// Register whose 4-byte value is the memory address of the second (right) operand.
    c: u8,
}
37
/// Picks the monomorphized execution function for a decoded [`LessThanOpcode`].
///
/// `$handler` is the identifier of a function generic over three arguments; the first two are
/// left to inference and the third is a `bool` const chosen by the opcode:
/// `SLTU` selects `true` and `SLT` selects `false` (this matches the `IS_U256` flag of the
/// `execute_*` functions in this file). The selected function is wrapped in `Ok`, so the macro
/// is meant to be the tail expression of a function returning `Result`.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        Ok(match $opcode {
            LessThanOpcode::SLTU => $handler::<_, _, true>,
            LessThanOpcode::SLT => $handler::<_, _, false>,
        })
    };
}
46
47impl<F: PrimeField32> InterpreterExecutor<F> for Rv32LessThan256Executor {
48 fn pre_compute_size(&self) -> usize {
49 size_of::<LessThanPreCompute>()
50 }
51
52 #[cfg(not(feature = "tco"))]
53 fn pre_compute<Ctx>(
54 &self,
55 pc: u32,
56 inst: &Instruction<F>,
57 data: &mut [u8],
58 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
59 where
60 Ctx: ExecutionCtxTrait,
61 {
62 let data: &mut LessThanPreCompute = data.borrow_mut();
63 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
64 dispatch!(execute_e1_handler, local_opcode)
65 }
66
67 #[cfg(feature = "tco")]
68 fn handler<Ctx>(
69 &self,
70 pc: u32,
71 inst: &Instruction<F>,
72 data: &mut [u8],
73 ) -> Result<Handler<F, Ctx>, StaticProgramError>
74 where
75 Ctx: ExecutionCtxTrait,
76 {
77 let data: &mut LessThanPreCompute = data.borrow_mut();
78 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
79 dispatch!(execute_e1_handler, local_opcode)
80 }
81}
82
// Compiled only with the `aot` feature. The impl body is empty, so no `AotExecutor` items are
// overridden for this executor.
#[cfg(feature = "aot")]
impl<F: PrimeField32> AotExecutor<F> for Rv32LessThan256Executor {}
85
86impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32LessThan256Executor {
87 fn metered_pre_compute_size(&self) -> usize {
88 size_of::<E2PreCompute<LessThanPreCompute>>()
89 }
90
91 #[cfg(not(feature = "tco"))]
92 fn metered_pre_compute<Ctx>(
93 &self,
94 chip_idx: usize,
95 pc: u32,
96 inst: &Instruction<F>,
97 data: &mut [u8],
98 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
99 where
100 Ctx: MeteredExecutionCtxTrait,
101 {
102 let data: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
103 data.chip_idx = chip_idx as u32;
104 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
105 dispatch!(execute_e2_handler, local_opcode)
106 }
107
108 #[cfg(feature = "tco")]
109 fn metered_handler<Ctx>(
110 &self,
111 chip_idx: usize,
112 pc: u32,
113 inst: &Instruction<F>,
114 data: &mut [u8],
115 ) -> Result<Handler<F, Ctx>, StaticProgramError>
116 where
117 Ctx: MeteredExecutionCtxTrait,
118 {
119 let data: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
120 data.chip_idx = chip_idx as u32;
121 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
122 dispatch!(execute_e2_handler, local_opcode)
123 }
124}
125
// Compiled only with the `aot` feature. The impl body is empty, so no `AotMeteredExecutor`
// items are overridden for this executor.
#[cfg(feature = "aot")]
impl<F: PrimeField32> AotMeteredExecutor<F> for Rv32LessThan256Executor {}
128
129#[inline(always)]
130unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_U256: bool>(
131 pre_compute: &LessThanPreCompute,
132 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
133) {
134 let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
135 let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
136 let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
137 let rs1 =
138 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
139 let rs2 =
140 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
141 let cmp_result = if IS_U256 {
142 common::u256_lt(rs1, rs2)
143 } else {
144 common::i256_lt(rs1, rs2)
145 };
146 let mut rd = [0u8; INT256_NUM_LIMBS];
147 rd[0] = cmp_result as u8;
148 exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
149
150 let pc = exec_state.pc();
151 exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
152}
153
154#[create_handler]
155#[inline(always)]
156unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_U256: bool>(
157 pre_compute: *const u8,
158 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
159) {
160 let pre_compute: &LessThanPreCompute =
161 std::slice::from_raw_parts(pre_compute, size_of::<LessThanPreCompute>()).borrow();
162 execute_e12_impl::<F, CTX, IS_U256>(pre_compute, exec_state);
163}
164
165#[create_handler]
166#[inline(always)]
167unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_U256: bool>(
168 pre_compute: *const u8,
169 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
170) {
171 let pre_compute: &E2PreCompute<LessThanPreCompute> =
172 std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<LessThanPreCompute>>())
173 .borrow();
174 exec_state
175 .ctx
176 .on_height_change(pre_compute.chip_idx as usize, 1);
177 execute_e12_impl::<F, CTX, IS_U256>(&pre_compute.data, exec_state);
178}
179
180impl Rv32LessThan256Executor {
181 fn pre_compute_impl<F: PrimeField32>(
182 &self,
183 pc: u32,
184 inst: &Instruction<F>,
185 data: &mut LessThanPreCompute,
186 ) -> Result<LessThanOpcode, StaticProgramError> {
187 let Instruction {
188 opcode,
189 a,
190 b,
191 c,
192 d,
193 e,
194 ..
195 } = inst;
196 let e_u32 = e.as_canonical_u32();
197 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
198 return Err(StaticProgramError::InvalidInstruction(pc));
199 }
200 *data = LessThanPreCompute {
201 a: a.as_canonical_u32() as u8,
202 b: b.as_canonical_u32() as u8,
203 c: c.as_canonical_u32() as u8,
204 };
205 let local_opcode = LessThanOpcode::from_usize(
206 opcode.local_opcode_idx(Rv32LessThan256Opcode::CLASS_OFFSET),
207 );
208 Ok(local_opcode)
209 }
210}