1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7 instruction::Instruction,
8 program::DEFAULT_PC_STEP,
9 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10 LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
13use openvm_rv32im_circuit::BranchLessThanExecutor;
14use openvm_rv32im_transpiler::BranchLessThanOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18 common::{i256_lt, u256_lt},
19 Rv32BranchLessThan256Executor, INT256_NUM_LIMBS,
20};
21
22type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;
23
24impl Rv32BranchLessThan256Executor {
25 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26 Self(BranchLessThanExecutor::new(adapter, offset))
27 }
28}
29
/// Pre-decoded operands for one 256-bit branch-compare instruction.
///
/// Filled once by `pre_compute_impl` and reinterpreted from raw bytes
/// (via `AlignedBytesBorrow`) on every execution, so the field-element
/// decoding cost is paid only once. `repr(C)` fixes the layout for that
/// byte-level reinterpretation.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLtPreCompute {
    // Signed branch offset added to the current pc when the comparison holds.
    imm: isize,
    // Register index whose value is the pointer to the first 256-bit operand.
    a: u8,
    // Register index whose value is the pointer to the second 256-bit operand.
    b: u8,
}
37
/// Maps the decoded [`BranchLessThanOpcode`] to a monomorphized execute
/// function, selecting the comparison operator type (`BltOp`, `BltuOp`,
/// `BgeOp`, `BgeuOp`) at compile time so the per-instruction hot path has
/// no dynamic dispatch.
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            BranchLessThanOpcode::BLT => $execute_impl::<_, _, BltOp>,
            BranchLessThanOpcode::BLTU => $execute_impl::<_, _, BltuOp>,
            BranchLessThanOpcode::BGE => $execute_impl::<_, _, BgeOp>,
            BranchLessThanOpcode::BGEU => $execute_impl::<_, _, BgeuOp>,
        })
    };
}
48
impl<F: PrimeField32> Executor<F> for Rv32BranchLessThan256Executor {
    /// Number of pre-compute bytes this executor needs per instruction.
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchLtPreCompute>()
    }

    /// Decodes `inst` into the caller-provided `data` buffer and returns the
    /// execute function specialized for the decoded comparison opcode.
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        // Reinterpret the raw byte buffer as the typed pre-compute struct.
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_impl, local_opcode)
    }

    /// Tail-call-optimized variant of [`Self::pre_compute`]; same decoding,
    /// but returns a TCO handler instead of a plain execute function.
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_tco_handler, local_opcode)
    }
}
83
impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchLessThan256Executor {
    /// Pre-compute size for metered execution: the plain pre-compute wrapped
    /// in [`E2PreCompute`], which additionally carries the chip index.
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchLtPreCompute>>()
    }

    /// Decodes `inst` into `data`, records `chip_idx` for height metering,
    /// and returns the metered execute function for the decoded opcode.
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_impl, local_opcode)
    }

    /// Tail-call-optimized variant of [`Self::metered_pre_compute`].
    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_tco_handler, local_opcode)
    }
}
122
123#[inline(always)]
124unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
125 pre_compute: &BranchLtPreCompute,
126 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
127) {
128 let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
129 let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
130 let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
131 let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
132 let cmp_result = OP::compute(rs1, rs2);
133 if cmp_result {
134 vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
135 } else {
136 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
137 }
138 vm_state.instret += 1;
139}
140
/// E1 (unmetered) entry point: reinterprets the raw pre-compute bytes as
/// [`BranchLtPreCompute`] and forwards to the shared implementation.
#[create_tco_handler]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchLtPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
}
149
/// E2 (metered) entry point: reports a height increment of 1 for this chip
/// via the metering context, then runs the shared implementation.
#[create_tco_handler]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchLtPreCompute> = pre_compute.borrow();
    // One executed instruction contributes one unit of height to this chip.
    vm_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
}
161
162impl Rv32BranchLessThan256Executor {
163 fn pre_compute_impl<F: PrimeField32>(
164 &self,
165 pc: u32,
166 inst: &Instruction<F>,
167 data: &mut BranchLtPreCompute,
168 ) -> Result<BranchLessThanOpcode, StaticProgramError> {
169 let Instruction {
170 opcode,
171 a,
172 b,
173 c,
174 d,
175 e,
176 ..
177 } = inst;
178 let c = c.as_canonical_u32();
179 let imm = if F::ORDER_U32 - c < c {
180 -((F::ORDER_U32 - c) as isize)
181 } else {
182 c as isize
183 };
184 let e_u32 = e.as_canonical_u32();
185 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
186 return Err(StaticProgramError::InvalidInstruction(pc));
187 }
188 *data = BranchLtPreCompute {
189 imm,
190 a: a.as_canonical_u32() as u8,
191 b: b.as_canonical_u32() as u8,
192 };
193 let local_opcode = BranchLessThanOpcode::from_usize(
194 opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET),
195 );
196 Ok(local_opcode)
197 }
198}
199
/// Comparison strategy selected at compile time by the `dispatch!` macro.
trait BranchLessThanOp {
    /// Returns whether the branch is taken for the two 256-bit operands,
    /// given as little-endian byte limbs.
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool;
}
/// Signed less-than (BLT).
struct BltOp;
/// Unsigned less-than (BLTU).
struct BltuOp;
/// Signed greater-or-equal (BGE).
struct BgeOp;
/// Unsigned greater-or-equal (BGEU).
struct BgeuOp;
207
impl BranchLessThanOp for BltOp {
    // BLT: taken when rs1 < rs2 as signed 256-bit integers.
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BltuOp {
    // BLTU: taken when rs1 < rs2 as unsigned 256-bit integers.
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        u256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeOp {
    // BGE: taken when NOT (rs1 < rs2) signed, i.e. rs1 >= rs2.
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeuOp {
    // BGEU: taken when NOT (rs1 < rs2) unsigned, i.e. rs1 >= rs2.
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !u256_lt(rs1, rs2)
    }
}