use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
use openvm_rv32im_circuit::BranchLessThanExecutor;
use openvm_rv32im_transpiler::BranchLessThanOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use crate::{
    common::{i256_lt, u256_lt},
    Rv32BranchLessThan256Executor, INT256_NUM_LIMBS,
};

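/// Branch adapter for two heap operands of `INT256_NUM_LIMBS` bytes each: the
/// instruction's register operands hold pointers that are dereferenced in memory.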
type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;

impl Rv32BranchLessThan256Executor {
    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
        Self(BranchLessThanExecutor::new(adapter, offset))
    }
}

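/// Instruction data decoded once by `pre_compute_impl` and reused on every
/// execution: `imm` is the sign-decoded branch offset, and `a`/`b` are the
/// register-file offsets of the two pointer registers.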
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLtPreCompute {
    imm: isize,
    a: u8,
    b: u8,
}

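// Monomorphizes the chosen handler over the comparison operation, so the opcode
// match is resolved once at pre-compute time rather than on every executed
// instruction.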
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            BranchLessThanOpcode::BLT => $execute_impl::<_, _, BltOp>,
            BranchLessThanOpcode::BLTU => $execute_impl::<_, _, BltuOp>,
            BranchLessThanOpcode::BGE => $execute_impl::<_, _, BgeOp>,
            BranchLessThanOpcode::BGEU => $execute_impl::<_, _, BgeuOp>,
        })
    };
}

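// E1 path: plain (unmetered) execution. With the `tco` feature enabled, a
// `Handler` is returned instead of an `ExecuteFunc` so the interpreter can
// chain handlers via tail calls.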
impl<F: PrimeField32> InterpreterExecutor<F> for Rv32BranchLessThan256Executor {
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchLtPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotExecutor<F> for Rv32BranchLessThan256Executor {}

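// E2 path: metered execution. The pre-compute buffer is prefixed with the chip
// index (`E2PreCompute`) so the handler can report trace-height growth for this
// chip.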
impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32BranchLessThan256Executor {
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchLtPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotMeteredExecutor<F> for Rv32BranchLessThan256Executor {}

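// Shared body of the E1 and E2 handlers: read the two pointer registers, load
// the two 256-bit operands they point to from memory, compare them, and either
// take the branch (pc += imm) or fall through (pc += DEFAULT_PC_STEP).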
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &BranchLtPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let mut pc = exec_state.pc();
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let cmp_result = OP::compute(rs1, rs2);
    if cmp_result {
        pc = (pc as isize + pre_compute.imm) as u32;
    } else {
        pc = pc.wrapping_add(DEFAULT_PC_STEP);
    }
    exec_state.set_pc(pc);
}

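// E1 wrapper: reinterpret the raw pre-compute bytes as `BranchLtPreCompute` and
// run the shared body.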
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchLtPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<BranchLtPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}

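// E2 wrapper: additionally records one row of trace-height growth for this chip
// before running the shared body.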
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchLtPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<BranchLtPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}

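// Decodes the instruction into `BranchLtPreCompute`. The immediate `c` is a
// field element encoding a signed offset: values above the field midpoint
// represent the negative offset `c - ORDER`. Operands must live in the register
// address space and be dereferenced in the memory address space, otherwise the
// instruction is rejected.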
impl Rv32BranchLessThan256Executor {
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BranchLtPreCompute,
    ) -> Result<BranchLessThanOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let c = c.as_canonical_u32();
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = BranchLtPreCompute {
            imm,
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        let local_opcode = BranchLessThanOpcode::from_usize(
            opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET),
        );
        Ok(local_opcode)
    }
}

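// Compile-time selection of the comparison: each zero-sized type below picks a
// signed or unsigned 256-bit less-than, with BGE/BGEU as the complements of
// BLT/BLTU.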
trait BranchLessThanOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool;
}
struct BltOp;
struct BltuOp;
struct BgeOp;
struct BgeuOp;
impl BranchLessThanOp for BltOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BltuOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        u256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeuOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !u256_lt(rs1, rs2)
    }
}