1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode;
7use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
8use openvm_circuit_primitives_derive::AlignedBytesBorrow;
9use openvm_instructions::{
10 instruction::Instruction,
11 program::DEFAULT_PC_STEP,
12 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
13 LocalOpcode,
14};
15use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
16use openvm_rv32im_circuit::BranchLessThanExecutor;
17use openvm_rv32im_transpiler::BranchLessThanOpcode;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::{
21 common::{i256_lt, u256_lt},
22 Rv32BranchLessThan256Executor, INT256_NUM_LIMBS,
23};
24
/// Adapter executor wrapped by [`Rv32BranchLessThan256Executor`].
///
/// NOTE(review): the const parameters are presumably (number of operands read, limbs per
/// operand). The execution code in this file reads two `INT256_NUM_LIMBS`-byte operands —
/// confirm against `Rv32HeapBranchAdapterExecutor`.
type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;
26
27impl Rv32BranchLessThan256Executor {
28 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
29 Self(BranchLessThanExecutor::new(adapter, offset))
30 }
31}
32
/// Operands of one 256-bit branch instruction.
///
/// `pre_compute_impl` decodes them once, and `execute_e12_impl` reads them each time the
/// instruction executes.
///
/// The struct lives inside a raw per-instruction byte buffer, borrowed via
/// `AlignedBytesBorrow`. `#[repr(C)]` fixes its field layout. NOTE(review): the layout is
/// presumably relied on by that byte view — confirm before reordering fields.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLtPreCompute {
    /// Signed offset added to `pc` when the branch is taken. It is decoded from operand `c`.
    imm: isize,
    /// Address in `RV32_REGISTER_AS` of the 4-byte register holding the first operand's
    /// little-endian address in `RV32_MEMORY_AS` (operand `a`, stored as `u8`).
    a: u8,
    /// Address in `RV32_REGISTER_AS` of the 4-byte register holding the second operand's
    /// little-endian address in `RV32_MEMORY_AS` (operand `b`, stored as `u8`).
    b: u8,
}
40
/// Maps a decoded `BranchLessThanOpcode` to the matching monomorphized handler and wraps it
/// in `Ok`.
///
/// `$execute_impl` is a handler name such as `execute_e1_handler` or `execute_e2_handler`.
/// NOTE(review): these are presumably generated by `#[create_handler]` — confirm.
/// The two `_` generic arguments are inferred at the call site. The third argument selects
/// the comparison predicate.
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            // Strict less-than: `BltOp` uses `i256_lt`, `BltuOp` uses `u256_lt`.
            BranchLessThanOpcode::BLT => $execute_impl::<_, _, BltOp>,
            BranchLessThanOpcode::BLTU => $execute_impl::<_, _, BltuOp>,
            // Greater-or-equal: the negations of the two predicates above.
            BranchLessThanOpcode::BGE => $execute_impl::<_, _, BgeOp>,
            BranchLessThanOpcode::BGEU => $execute_impl::<_, _, BgeuOp>,
        })
    };
}
51
52impl<F: PrimeField32> Executor<F> for Rv32BranchLessThan256Executor {
53 fn pre_compute_size(&self) -> usize {
54 size_of::<BranchLtPreCompute>()
55 }
56
57 #[cfg(not(feature = "tco"))]
58 fn pre_compute<Ctx>(
59 &self,
60 pc: u32,
61 inst: &Instruction<F>,
62 data: &mut [u8],
63 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
64 where
65 Ctx: ExecutionCtxTrait,
66 {
67 let data: &mut BranchLtPreCompute = data.borrow_mut();
68 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
69 dispatch!(execute_e1_handler, local_opcode)
70 }
71
72 #[cfg(feature = "tco")]
73 fn handler<Ctx>(
74 &self,
75 pc: u32,
76 inst: &Instruction<F>,
77 data: &mut [u8],
78 ) -> Result<Handler<F, Ctx>, StaticProgramError>
79 where
80 Ctx: ExecutionCtxTrait,
81 {
82 let data: &mut BranchLtPreCompute = data.borrow_mut();
83 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
84 dispatch!(execute_e1_handler, local_opcode)
85 }
86}
87
88impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchLessThan256Executor {
89 fn metered_pre_compute_size(&self) -> usize {
90 size_of::<E2PreCompute<BranchLtPreCompute>>()
91 }
92
93 #[cfg(not(feature = "tco"))]
94 fn metered_pre_compute<Ctx>(
95 &self,
96 chip_idx: usize,
97 pc: u32,
98 inst: &Instruction<F>,
99 data: &mut [u8],
100 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
101 where
102 Ctx: MeteredExecutionCtxTrait,
103 {
104 let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
105 data.chip_idx = chip_idx as u32;
106 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
107 dispatch!(execute_e2_handler, local_opcode)
108 }
109
110 #[cfg(feature = "tco")]
111 fn metered_handler<Ctx>(
112 &self,
113 chip_idx: usize,
114 pc: u32,
115 inst: &Instruction<F>,
116 data: &mut [u8],
117 ) -> Result<Handler<F, Ctx>, StaticProgramError>
118 where
119 Ctx: MeteredExecutionCtxTrait,
120 {
121 let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
122 data.chip_idx = chip_idx as u32;
123 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
124 dispatch!(execute_e2_handler, local_opcode)
125 }
126}
127
128#[inline(always)]
129unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
130 pre_compute: &BranchLtPreCompute,
131 instret: &mut u64,
132 pc: &mut u32,
133 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
134) {
135 let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
136 let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
137 let rs1 =
138 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
139 let rs2 =
140 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
141 let cmp_result = OP::compute(rs1, rs2);
142 if cmp_result {
143 *pc = (*pc as isize + pre_compute.imm) as u32;
144 } else {
145 *pc = pc.wrapping_add(DEFAULT_PC_STEP);
146 }
147 *instret += 1;
148}
149
// E1 (plain execution) entry point for one 256-bit branch instruction.
//
// `pre_compute` is the per-instruction byte buffer that `Executor::pre_compute` or
// `Executor::handler` filled as a `BranchLtPreCompute`. It is viewed as that struct here and
// forwarded, together with `instret`, `pc` and `exec_state`, to `execute_e12_impl`.
// `_instret_end` is not used by this instruction.
//
// Safety: callers must pass the buffer prepared for this handler, plus whatever
// `execute_e12_impl` requires. NOTE(review): confirm what `AlignedBytesBorrow`'s `borrow`
// checks on the buffer.
//
// NOTE(review): `#[create_handler]` presumably derives the `execute_e1_handler` used by
// `dispatch!` from this function. Keep the signature in the shape that macro expects.
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // View the raw buffer as the decoded operands.
    let pre_compute: &BranchLtPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, instret, pc, exec_state);
}
162
// E2 (metered execution) entry point for one 256-bit branch instruction.
//
// `pre_compute` is the per-instruction byte buffer that `MeteredExecutor::metered_pre_compute`
// or `metered_handler` filled as an `E2PreCompute<BranchLtPreCompute>`. The function first
// calls `on_height_change(chip_idx, 1)` on the execution context. NOTE(review): this
// presumably accounts one more trace row for this chip — confirm. It then runs the shared
// branch logic. `_arg` is not used by this instruction.
//
// NOTE(review): `#[create_handler]` presumably derives the `execute_e2_handler` used by
// `dispatch!` from this function. Keep the signature in the shape that macro expects.
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // View the raw buffer as the chip index plus the decoded operands.
    let pre_compute: &E2PreCompute<BranchLtPreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, instret, pc, exec_state);
}
178
179impl Rv32BranchLessThan256Executor {
180 fn pre_compute_impl<F: PrimeField32>(
181 &self,
182 pc: u32,
183 inst: &Instruction<F>,
184 data: &mut BranchLtPreCompute,
185 ) -> Result<BranchLessThanOpcode, StaticProgramError> {
186 let Instruction {
187 opcode,
188 a,
189 b,
190 c,
191 d,
192 e,
193 ..
194 } = inst;
195 let c = c.as_canonical_u32();
196 let imm = if F::ORDER_U32 - c < c {
197 -((F::ORDER_U32 - c) as isize)
198 } else {
199 c as isize
200 };
201 let e_u32 = e.as_canonical_u32();
202 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
203 return Err(StaticProgramError::InvalidInstruction(pc));
204 }
205 *data = BranchLtPreCompute {
206 imm,
207 a: a.as_canonical_u32() as u8,
208 b: b.as_canonical_u32() as u8,
209 };
210 let local_opcode = BranchLessThanOpcode::from_usize(
211 opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET),
212 );
213 Ok(local_opcode)
214 }
215}
216
217trait BranchLessThanOp {
218 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool;
219}
220struct BltOp;
221struct BltuOp;
222struct BgeOp;
223struct BgeuOp;
224
225impl BranchLessThanOp for BltOp {
226 #[inline(always)]
227 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
228 i256_lt(rs1, rs2)
229 }
230}
231impl BranchLessThanOp for BltuOp {
232 #[inline(always)]
233 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
234 u256_lt(rs1, rs2)
235 }
236}
237impl BranchLessThanOp for BgeOp {
238 #[inline(always)]
239 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
240 !i256_lt(rs1, rs2)
241 }
242}
243impl BranchLessThanOp for BgeuOp {
244 #[inline(always)]
245 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
246 !u256_lt(rs1, rs2)
247 }
248}