1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_bigint_transpiler::Rv32LessThan256Opcode;
7use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
8use openvm_circuit_primitives_derive::AlignedBytesBorrow;
9use openvm_instructions::{
10 instruction::Instruction,
11 program::DEFAULT_PC_STEP,
12 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
13 LocalOpcode,
14};
15use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
16use openvm_rv32im_circuit::LessThanExecutor;
17use openvm_rv32im_transpiler::LessThanOpcode;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::{common, Rv32LessThan256Executor, INT256_NUM_LIMBS};
21
22type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32LessThan256Executor {
25 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26 Self(LessThanExecutor::new(adapter, offset))
27 }
28}
29
30#[derive(AlignedBytesBorrow, Clone)]
31#[repr(C)]
32struct LessThanPreCompute {
33 a: u8,
34 b: u8,
35 c: u8,
36}
37
/// Selects the instantiation of `$handler` that matches a decoded
/// [`LessThanOpcode`] and wraps it in `Ok`.
///
/// The third const generic of the handler is the "unsigned" flag: `SLT` maps
/// to `false` and `SLTU` maps to `true`. The first two generic arguments are
/// left for the compiler to infer from the call site's return type.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        Ok(match $opcode {
            LessThanOpcode::SLT => $handler::<_, _, false>,
            LessThanOpcode::SLTU => $handler::<_, _, true>,
        })
    };
}
46
47impl<F: PrimeField32> Executor<F> for Rv32LessThan256Executor {
48 fn pre_compute_size(&self) -> usize {
49 size_of::<LessThanPreCompute>()
50 }
51
52 #[cfg(not(feature = "tco"))]
53 fn pre_compute<Ctx>(
54 &self,
55 pc: u32,
56 inst: &Instruction<F>,
57 data: &mut [u8],
58 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
59 where
60 Ctx: ExecutionCtxTrait,
61 {
62 let data: &mut LessThanPreCompute = data.borrow_mut();
63 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
64 dispatch!(execute_e1_handler, local_opcode)
65 }
66
67 #[cfg(feature = "tco")]
68 fn handler<Ctx>(
69 &self,
70 pc: u32,
71 inst: &Instruction<F>,
72 data: &mut [u8],
73 ) -> Result<Handler<F, Ctx>, StaticProgramError>
74 where
75 Ctx: ExecutionCtxTrait,
76 {
77 let data: &mut LessThanPreCompute = data.borrow_mut();
78 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
79 dispatch!(execute_e1_handler, local_opcode)
80 }
81}
82
83impl<F: PrimeField32> MeteredExecutor<F> for Rv32LessThan256Executor {
84 fn metered_pre_compute_size(&self) -> usize {
85 size_of::<E2PreCompute<LessThanPreCompute>>()
86 }
87
88 #[cfg(not(feature = "tco"))]
89 fn metered_pre_compute<Ctx>(
90 &self,
91 chip_idx: usize,
92 pc: u32,
93 inst: &Instruction<F>,
94 data: &mut [u8],
95 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
96 where
97 Ctx: MeteredExecutionCtxTrait,
98 {
99 let data: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
100 data.chip_idx = chip_idx as u32;
101 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
102 dispatch!(execute_e2_handler, local_opcode)
103 }
104
105 #[cfg(feature = "tco")]
106 fn metered_handler<Ctx>(
107 &self,
108 chip_idx: usize,
109 pc: u32,
110 inst: &Instruction<F>,
111 data: &mut [u8],
112 ) -> Result<Handler<F, Ctx>, StaticProgramError>
113 where
114 Ctx: MeteredExecutionCtxTrait,
115 {
116 let data: &mut E2PreCompute<LessThanPreCompute> = data.borrow_mut();
117 data.chip_idx = chip_idx as u32;
118 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
119 dispatch!(execute_e2_handler, local_opcode)
120 }
121}
122
123#[inline(always)]
124unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_U256: bool>(
125 pre_compute: &LessThanPreCompute,
126 instret: &mut u64,
127 pc: &mut u32,
128 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
129) {
130 let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
131 let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
132 let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
133 let rs1 =
134 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
135 let rs2 =
136 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
137 let cmp_result = if IS_U256 {
138 common::u256_lt(rs1, rs2)
139 } else {
140 common::i256_lt(rs1, rs2)
141 };
142 let mut rd = [0u8; INT256_NUM_LIMBS];
143 rd[0] = cmp_result as u8;
144 exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
145
146 *pc += DEFAULT_PC_STEP;
147 *instret += 1;
148}
149
150#[create_handler]
151#[inline(always)]
152unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_U256: bool>(
153 pre_compute: &[u8],
154 instret: &mut u64,
155 pc: &mut u32,
156 _instret_end: u64,
157 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
158) {
159 let pre_compute: &LessThanPreCompute = pre_compute.borrow();
160 execute_e12_impl::<F, CTX, IS_U256>(pre_compute, instret, pc, exec_state);
161}
162
163#[create_handler]
164#[inline(always)]
165unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_U256: bool>(
166 pre_compute: &[u8],
167 instret: &mut u64,
168 pc: &mut u32,
169 _arg: u64,
170 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
171) {
172 let pre_compute: &E2PreCompute<LessThanPreCompute> = pre_compute.borrow();
173 exec_state
174 .ctx
175 .on_height_change(pre_compute.chip_idx as usize, 1);
176 execute_e12_impl::<F, CTX, IS_U256>(&pre_compute.data, instret, pc, exec_state);
177}
178
179impl Rv32LessThan256Executor {
180 fn pre_compute_impl<F: PrimeField32>(
181 &self,
182 pc: u32,
183 inst: &Instruction<F>,
184 data: &mut LessThanPreCompute,
185 ) -> Result<LessThanOpcode, StaticProgramError> {
186 let Instruction {
187 opcode,
188 a,
189 b,
190 c,
191 d,
192 e,
193 ..
194 } = inst;
195 let e_u32 = e.as_canonical_u32();
196 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
197 return Err(StaticProgramError::InvalidInstruction(pc));
198 }
199 *data = LessThanPreCompute {
200 a: a.as_canonical_u32() as u8,
201 b: b.as_canonical_u32() as u8,
202 c: c.as_canonical_u32() as u8,
203 };
204 let local_opcode = LessThanOpcode::from_usize(
205 opcode.local_opcode_idx(Rv32LessThan256Opcode::CLASS_OFFSET),
206 );
207 Ok(local_opcode)
208 }
209}