openvm_rv32im_circuit/branch_lt/
execution.rs1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
10};
11use openvm_rv32im_transpiler::BranchLessThanOpcode;
12use openvm_stark_backend::p3_field::PrimeField32;
13
14use super::core::BranchLessThanExecutor;
15
16#[derive(AlignedBytesBorrow, Clone)]
17#[repr(C)]
18struct BranchLePreCompute {
19 imm: isize,
20 a: u8,
21 b: u8,
22}
23
/// Picks the execute function specialised for a decoded branch opcode.
///
/// `$handler` names a function taking three generic arguments; the first two
/// are left to inference and the third is set to the marker type matching
/// `$opcode`. The selected function is wrapped in `Ok` so the expansion can be
/// used directly as the tail expression of a fallible pre-compute method.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            // Signed comparisons.
            BranchLessThanOpcode::BLT => Ok($handler::<_, _, BltOp>),
            BranchLessThanOpcode::BGE => Ok($handler::<_, _, BgeOp>),
            // Unsigned comparisons.
            BranchLessThanOpcode::BLTU => Ok($handler::<_, _, BltuOp>),
            BranchLessThanOpcode::BGEU => Ok($handler::<_, _, BgeuOp>),
        }
    };
}
34
35impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize>
36 BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
37{
38 #[inline(always)]
39 fn pre_compute_impl<F: PrimeField32>(
40 &self,
41 pc: u32,
42 inst: &Instruction<F>,
43 data: &mut BranchLePreCompute,
44 ) -> Result<BranchLessThanOpcode, StaticProgramError> {
45 let &Instruction {
46 opcode, a, b, c, d, ..
47 } = inst;
48 let local_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset));
49 let c = c.as_canonical_u32();
50 let imm = if F::ORDER_U32 - c < c {
51 -((F::ORDER_U32 - c) as isize)
52 } else {
53 c as isize
54 };
55 if d.as_canonical_u32() != RV32_REGISTER_AS {
56 return Err(StaticProgramError::InvalidInstruction(pc));
57 }
58 *data = BranchLePreCompute {
59 imm,
60 a: a.as_canonical_u32() as u8,
61 b: b.as_canonical_u32() as u8,
62 };
63 Ok(local_opcode)
64 }
65}
66
67impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> Executor<F>
68 for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
69where
70 F: PrimeField32,
71{
72 #[inline(always)]
73 fn pre_compute_size(&self) -> usize {
74 size_of::<BranchLePreCompute>()
75 }
76
77 #[inline(always)]
78 fn pre_compute<Ctx: ExecutionCtxTrait>(
79 &self,
80 pc: u32,
81 inst: &Instruction<F>,
82 data: &mut [u8],
83 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
84 let data: &mut BranchLePreCompute = data.borrow_mut();
85 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
86 dispatch!(execute_e1_impl, local_opcode)
87 }
88
89 #[cfg(feature = "tco")]
90 fn handler<Ctx>(
91 &self,
92 pc: u32,
93 inst: &Instruction<F>,
94 data: &mut [u8],
95 ) -> Result<Handler<F, Ctx>, StaticProgramError>
96 where
97 Ctx: ExecutionCtxTrait,
98 {
99 let data: &mut BranchLePreCompute = data.borrow_mut();
100 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
101 dispatch!(execute_e1_tco_handler, local_opcode)
102 }
103}
104
105impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MeteredExecutor<F>
106 for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
107where
108 F: PrimeField32,
109{
110 fn metered_pre_compute_size(&self) -> usize {
111 size_of::<E2PreCompute<BranchLePreCompute>>()
112 }
113
114 fn metered_pre_compute<Ctx>(
115 &self,
116 chip_idx: usize,
117 pc: u32,
118 inst: &Instruction<F>,
119 data: &mut [u8],
120 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
121 where
122 Ctx: MeteredExecutionCtxTrait,
123 {
124 let data: &mut E2PreCompute<BranchLePreCompute> = data.borrow_mut();
125 data.chip_idx = chip_idx as u32;
126 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
127 dispatch!(execute_e2_impl, local_opcode)
128 }
129
130 #[cfg(feature = "tco")]
131 fn metered_handler<Ctx>(
132 &self,
133 chip_idx: usize,
134 pc: u32,
135 inst: &Instruction<F>,
136 data: &mut [u8],
137 ) -> Result<Handler<F, Ctx>, StaticProgramError>
138 where
139 Ctx: MeteredExecutionCtxTrait,
140 {
141 let data: &mut E2PreCompute<BranchLePreCompute> = data.borrow_mut();
142 data.chip_idx = chip_idx as u32;
143 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
144 dispatch!(execute_e2_tco_handler, local_opcode)
145 }
146}
147
148#[inline(always)]
149unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
150 pre_compute: &BranchLePreCompute,
151 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
152) {
153 let rs1 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
154 let rs2 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
155 let jmp = <OP as BranchLessThanOp>::compute(rs1, rs2);
156 if jmp {
157 vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
158 } else {
159 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
160 };
161 vm_state.instret += 1;
162}
163
164#[create_tco_handler]
165unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
166 pre_compute: &[u8],
167 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
168) {
169 let pre_compute: &BranchLePreCompute = pre_compute.borrow();
170 execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
171}
172
173#[create_tco_handler]
174unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
175 pre_compute: &[u8],
176 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
177) {
178 let pre_compute: &E2PreCompute<BranchLePreCompute> = pre_compute.borrow();
179 vm_state
180 .ctx
181 .on_height_change(pre_compute.chip_idx as usize, 1);
182 execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
183}
184
/// Compile-time selector for the comparison a branch instruction performs.
///
/// Both operands are four bytes that are decoded as little-endian 32-bit
/// integers before being compared.
trait BranchLessThanOp {
    /// Returns `true` when the branch condition holds for `rs1` and `rs2`.
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool;
}

/// Marker for `BLT`: signed `rs1 < rs2`.
struct BltOp;
/// Marker for `BLTU`: unsigned `rs1 < rs2`.
struct BltuOp;
/// Marker for `BGE`: signed `rs1 >= rs2`.
struct BgeOp;
/// Marker for `BGEU`: unsigned `rs1 >= rs2`.
struct BgeuOp;

impl BranchLessThanOp for BltOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        let (lhs, rhs) = (i32::from_le_bytes(rs1), i32::from_le_bytes(rs2));
        lhs < rhs
    }
}

impl BranchLessThanOp for BltuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        let (lhs, rhs) = (u32::from_le_bytes(rs1), u32::from_le_bytes(rs2));
        lhs < rhs
    }
}

impl BranchLessThanOp for BgeOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        let (lhs, rhs) = (i32::from_le_bytes(rs1), i32::from_le_bytes(rs2));
        lhs >= rhs
    }
}

impl BranchLessThanOp for BgeuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        let (lhs, rhs) = (u32::from_le_bytes(rs1), u32::from_le_bytes(rs2));
        lhs >= rhs
    }
}