openvm_rv32im_circuit/branch_lt/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
10};
11use openvm_rv32im_transpiler::BranchLessThanOpcode;
12use openvm_stark_backend::p3_field::PrimeField32;
13
14use super::core::BranchLessThanExecutor;
15
/// Pre-decoded operands for a branch-less-than instruction, cached once at
/// pre-compute time so the execution loop avoids re-decoding the instruction.
/// `#[repr(C)]` + `AlignedBytesBorrow` allow this struct to be reinterpreted
/// from a raw `&[u8]` pre-compute buffer.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLePreCompute {
    // Sign-decoded branch offset, added to the PC when the branch is taken.
    imm: isize,
    // Register index of the first operand (rs1), from instruction field `a`.
    a: u8,
    // Register index of the second operand (rs2), from instruction field `b`.
    b: u8,
}
23
// Maps a runtime `BranchLessThanOpcode` value to a monomorphized executor
// function: `$execute_impl` is instantiated with the zero-sized marker type
// whose `BranchLessThanOp` impl performs the matching comparison. Used so the
// comparison is a static call (no per-instruction branch on the opcode) in
// the hot execution path.
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        match $local_opcode {
            BranchLessThanOpcode::BLT => Ok($execute_impl::<_, _, BltOp>),
            BranchLessThanOpcode::BLTU => Ok($execute_impl::<_, _, BltuOp>),
            BranchLessThanOpcode::BGE => Ok($execute_impl::<_, _, BgeOp>),
            BranchLessThanOpcode::BGEU => Ok($execute_impl::<_, _, BgeuOp>),
        }
    };
}
34
35impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize>
36    BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
37{
38    #[inline(always)]
39    fn pre_compute_impl<F: PrimeField32>(
40        &self,
41        pc: u32,
42        inst: &Instruction<F>,
43        data: &mut BranchLePreCompute,
44    ) -> Result<BranchLessThanOpcode, StaticProgramError> {
45        let &Instruction {
46            opcode, a, b, c, d, ..
47        } = inst;
48        let local_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset));
49        let c = c.as_canonical_u32();
50        let imm = if F::ORDER_U32 - c < c {
51            -((F::ORDER_U32 - c) as isize)
52        } else {
53            c as isize
54        };
55        if d.as_canonical_u32() != RV32_REGISTER_AS {
56            return Err(StaticProgramError::InvalidInstruction(pc));
57        }
58        *data = BranchLePreCompute {
59            imm,
60            a: a.as_canonical_u32() as u8,
61            b: b.as_canonical_u32() as u8,
62        };
63        Ok(local_opcode)
64    }
65}
66
67impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> Executor<F>
68    for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
69where
70    F: PrimeField32,
71{
72    #[inline(always)]
73    fn pre_compute_size(&self) -> usize {
74        size_of::<BranchLePreCompute>()
75    }
76
77    #[inline(always)]
78    fn pre_compute<Ctx: ExecutionCtxTrait>(
79        &self,
80        pc: u32,
81        inst: &Instruction<F>,
82        data: &mut [u8],
83    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
84        let data: &mut BranchLePreCompute = data.borrow_mut();
85        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
86        dispatch!(execute_e1_impl, local_opcode)
87    }
88
89    #[cfg(feature = "tco")]
90    fn handler<Ctx>(
91        &self,
92        pc: u32,
93        inst: &Instruction<F>,
94        data: &mut [u8],
95    ) -> Result<Handler<F, Ctx>, StaticProgramError>
96    where
97        Ctx: ExecutionCtxTrait,
98    {
99        let data: &mut BranchLePreCompute = data.borrow_mut();
100        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
101        dispatch!(execute_e1_tco_handler, local_opcode)
102    }
103}
104
105impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MeteredExecutor<F>
106    for BranchLessThanExecutor<A, NUM_LIMBS, LIMB_BITS>
107where
108    F: PrimeField32,
109{
110    fn metered_pre_compute_size(&self) -> usize {
111        size_of::<E2PreCompute<BranchLePreCompute>>()
112    }
113
114    fn metered_pre_compute<Ctx>(
115        &self,
116        chip_idx: usize,
117        pc: u32,
118        inst: &Instruction<F>,
119        data: &mut [u8],
120    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
121    where
122        Ctx: MeteredExecutionCtxTrait,
123    {
124        let data: &mut E2PreCompute<BranchLePreCompute> = data.borrow_mut();
125        data.chip_idx = chip_idx as u32;
126        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
127        dispatch!(execute_e2_impl, local_opcode)
128    }
129
130    #[cfg(feature = "tco")]
131    fn metered_handler<Ctx>(
132        &self,
133        chip_idx: usize,
134        pc: u32,
135        inst: &Instruction<F>,
136        data: &mut [u8],
137    ) -> Result<Handler<F, Ctx>, StaticProgramError>
138    where
139        Ctx: MeteredExecutionCtxTrait,
140    {
141        let data: &mut E2PreCompute<BranchLePreCompute> = data.borrow_mut();
142        data.chip_idx = chip_idx as u32;
143        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
144        dispatch!(execute_e2_tco_handler, local_opcode)
145    }
146}
147
148#[inline(always)]
149unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
150    pre_compute: &BranchLePreCompute,
151    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
152) {
153    let rs1 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
154    let rs2 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
155    let jmp = <OP as BranchLessThanOp>::compute(rs1, rs2);
156    if jmp {
157        vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
158    } else {
159        vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
160    };
161    vm_state.instret += 1;
162}
163
164#[create_tco_handler]
165unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
166    pre_compute: &[u8],
167    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
168) {
169    let pre_compute: &BranchLePreCompute = pre_compute.borrow();
170    execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
171}
172
173#[create_tco_handler]
174unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
175    pre_compute: &[u8],
176    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
177) {
178    let pre_compute: &E2PreCompute<BranchLePreCompute> = pre_compute.borrow();
179    vm_state
180        .ctx
181        .on_height_change(pre_compute.chip_idx as usize, 1);
182    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
183}
184
/// Comparison strategy for a branch instruction, selected statically by the
/// `dispatch!` macro so the hot path pays no per-instruction opcode branch.
trait BranchLessThanOp {
    /// Returns `true` when the branch should be taken, given the two
    /// register operands as little-endian 4-byte words.
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool;
}

/// Marker for BLT: signed `rs1 < rs2`.
struct BltOp;
/// Marker for BLTU: unsigned `rs1 < rs2`.
struct BltuOp;
/// Marker for BGE: signed `rs1 >= rs2`.
struct BgeOp;
/// Marker for BGEU: unsigned `rs1 >= rs2`.
struct BgeuOp;

impl BranchLessThanOp for BltOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        i32::from_le_bytes(rs1) < i32::from_le_bytes(rs2)
    }
}

impl BranchLessThanOp for BltuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        u32::from_le_bytes(rs1) < u32::from_le_bytes(rs2)
    }
}

impl BranchLessThanOp for BgeOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        i32::from_le_bytes(rs1) >= i32::from_le_bytes(rs2)
    }
}

impl BranchLessThanOp for BgeuOp {
    #[inline(always)]
    fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool {
        u32::from_le_bytes(rs1) >= u32::from_le_bytes(rs2)
    }
}