openvm_rv32im_circuit/branch_lt/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
10};
11use openvm_rv32im_transpiler::BranchLessThanOpcode;
12use openvm_stark_backend::p3_field::PrimeField32;
13
14use super::core::BranchLessThanExecutor;
15
/// Operands of one branch-less-than instruction, decoded once and stored in
/// the pre-compute byte buffer.
///
/// The struct is obtained from `[u8]` buffers via `borrow()` / `borrow_mut()`
/// in this file (presumably through the `AlignedBytesBorrow` derive — confirm
/// against the derive's documentation), so the `#[repr(C)]` field order must
/// not be changed.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLePreCompute {
    /// Signed offset added to `pc` when the branch is taken.
    imm: isize,
    /// Instruction operand `a` truncated to `u8`; used as the address of the
    /// first comparison operand in the register address space.
    a: u8,
    /// Instruction operand `b` truncated to `u8`; used as the address of the
    /// second comparison operand in the register address space.
    b: u8,
}
23
/// Picks the instantiation of `$handler` that matches a decoded opcode.
///
/// Expands to an exhaustive `match` over `BranchLessThanOpcode`; every arm
/// evaluates to `Ok(..)` wrapping `$handler` with its third generic argument
/// fixed to the comparison marker type for that opcode. The first two generic
/// arguments are left for inference at the expansion site.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            // Signed comparisons.
            BranchLessThanOpcode::BLT => Ok($handler::<_, _, BltOp>),
            BranchLessThanOpcode::BGE => Ok($handler::<_, _, BgeOp>),
            // Unsigned comparisons.
            BranchLessThanOpcode::BLTU => Ok($handler::<_, _, BltuOp>),
            BranchLessThanOpcode::BGEU => Ok($handler::<_, _, BgeuOp>),
        }
    };
}
34
35impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize>
36    BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
37{
38    #[inline(always)]
39    fn pre_compute_impl<F: PrimeField32>(
40        &self,
41        pc: u32,
42        inst: &Instruction<F>,
43        data: &mut BranchLePreCompute,
44    ) -> Result<BranchLessThanOpcode, StaticProgramError> {
45        let &Instruction {
46            opcode, a, b, c, d, ..
47        } = inst;
48        let local_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset));
49        let c = c.as_canonical_u32();
50        let imm = if F::ORDER_U32 - c < c {
51            -((F::ORDER_U32 - c) as isize)
52        } else {
53            c as isize
54        };
55        if d.as_canonical_u32() != RV32_REGISTER_AS {
56            return Err(StaticProgramError::InvalidInstruction(pc));
57        }
58        *data = BranchLePreCompute {
59            imm,
60            a: a.as_canonical_u32() as u8,
61            b: b.as_canonical_u32() as u8,
62        };
63        Ok(local_opcode)
64    }
65}
66
67impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> Executor<F>
68    for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
69where
70    F: PrimeField32,
71{
72    #[inline(always)]
73    fn pre_compute_size(&self) -> usize {
74        size_of::<BranchLePreCompute>()
75    }
76
77    #[inline(always)]
78    #[cfg(not(feature = "tco"))]
79    fn pre_compute<Ctx: ExecutionCtxTrait>(
80        &self,
81        pc: u32,
82        inst: &Instruction<F>,
83        data: &mut [u8],
84    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
85        let data: &mut BranchLePreCompute = data.borrow_mut();
86        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
87        dispatch!(execute_e1_handler, local_opcode)
88    }
89
90    #[cfg(feature = "tco")]
91    fn handler<Ctx>(
92        &self,
93        pc: u32,
94        inst: &Instruction<F>,
95        data: &mut [u8],
96    ) -> Result<Handler<F, Ctx>, StaticProgramError>
97    where
98        Ctx: ExecutionCtxTrait,
99    {
100        let data: &mut BranchLePreCompute = data.borrow_mut();
101        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
102        dispatch!(execute_e1_handler, local_opcode)
103    }
104}
105
106impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MeteredExecutor<F>
107    for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
108where
109    F: PrimeField32,
110{
111    fn metered_pre_compute_size(&self) -> usize {
112        size_of::<E2PreCompute<BranchLePreCompute>>()
113    }
114
115    #[cfg(not(feature = "tco"))]
116    fn metered_pre_compute<Ctx>(
117        &self,
118        chip_idx: usize,
119        pc: u32,
120        inst: &Instruction<F>,
121        data: &mut [u8],
122    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
123    where
124        Ctx: MeteredExecutionCtxTrait,
125    {
126        let data: &mut E2PreCompute<BranchLePreCompute> = data.borrow_mut();
127        data.chip_idx = chip_idx as u32;
128        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
129        dispatch!(execute_e2_handler, local_opcode)
130    }
131
132    #[cfg(feature = "tco")]
133    fn metered_handler<Ctx>(
134        &self,
135        chip_idx: usize,
136        pc: u32,
137        inst: &Instruction<F>,
138        data: &mut [u8],
139    ) -> Result<Handler<F, Ctx>, StaticProgramError>
140    where
141        Ctx: MeteredExecutionCtxTrait,
142    {
143        let data: &mut E2PreCompute<BranchLePreCompute> = data.borrow_mut();
144        data.chip_idx = chip_idx as u32;
145        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
146        dispatch!(execute_e2_handler, local_opcode)
147    }
148}
149
150#[inline(always)]
151unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
152    pre_compute: &BranchLePreCompute,
153    instret: &mut u64,
154    pc: &mut u32,
155    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
156) {
157    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
158    let rs2 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
159    let jmp = <OP as BranchLessThanOp>::compute(rs1, rs2);
160    if jmp {
161        *pc = (*pc as isize + pre_compute.imm) as u32;
162    } else {
163        *pc = pc.wrapping_add(DEFAULT_PC_STEP);
164    };
165    *instret += 1;
166}
167
168#[create_handler]
169#[inline(always)]
170unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
171    pre_compute: &[u8],
172    instret: &mut u64,
173    pc: &mut u32,
174    _instret_end: u64,
175    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
176) {
177    let pre_compute: &BranchLePreCompute = pre_compute.borrow();
178    execute_e12_impl::<F, CTX, OP>(pre_compute, instret, pc, exec_state);
179}
180
181#[create_handler]
182#[inline(always)]
183unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
184    pre_compute: &[u8],
185    instret: &mut u64,
186    pc: &mut u32,
187    _arg: u64,
188    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
189) {
190    let pre_compute: &E2PreCompute<BranchLePreCompute> = pre_compute.borrow();
191    exec_state
192        .ctx
193        .on_height_change(pre_compute.chip_idx as usize, 1);
194    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, instret, pc, exec_state);
195}
196
/// Compile-time selector for the comparison a branch instruction performs.
trait BranchLessThanOp {
    /// Returns `true` when the branch should be taken for the two operands,
    /// each given as four little-endian bytes.
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool;
}

/// Marker for signed "less than".
struct BltOp;
/// Marker for unsigned "less than".
struct BltuOp;
/// Marker for signed "greater than or equal".
struct BgeOp;
/// Marker for unsigned "greater than or equal".
struct BgeuOp;

impl BranchLessThanOp for BltOp {
    /// Signed comparison: taken when `rs1 < rs2` as `i32`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        i32::from_le_bytes(rs1) < i32::from_le_bytes(rs2)
    }
}

impl BranchLessThanOp for BltuOp {
    /// Unsigned comparison: taken when `rs1 < rs2` as `u32`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        u32::from_le_bytes(rs1) < u32::from_le_bytes(rs2)
    }
}

impl BranchLessThanOp for BgeOp {
    /// Signed comparison: taken when `rs1 >= rs2` as `i32`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        i32::from_le_bytes(rs1) >= i32::from_le_bytes(rs2)
    }
}

impl BranchLessThanOp for BgeuOp {
    /// Unsigned comparison: taken when `rs1 >= rs2` as `u32`.
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        u32::from_le_bytes(rs1) >= u32::from_le_bytes(rs2)
    }
}