openvm_rv32im_circuit/branch_eq/execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
10};
11use openvm_rv32im_transpiler::BranchEqualOpcode;
12use openvm_stark_backend::p3_field::PrimeField32;
13
14use super::BranchEqualExecutor;
15
16#[derive(AlignedBytesBorrow, Clone)]
17#[repr(C)]
18struct BranchEqualPreCompute {
19    imm: isize,
20    a: u8,
21    b: u8,
22}
23
24impl<A, const NUM_LIMBS: usize> BranchEqualExecutor<A, NUM_LIMBS> {
25    /// Return `is_bne`, true if the local opcode is BNE.
26    #[inline(always)]
27    fn pre_compute_impl<F: PrimeField32>(
28        &self,
29        pc: u32,
30        inst: &Instruction<F>,
31        data: &mut BranchEqualPreCompute,
32    ) -> Result<bool, StaticProgramError> {
33        let data: &mut BranchEqualPreCompute = data.borrow_mut();
34        let &Instruction {
35            opcode, a, b, c, d, ..
36        } = inst;
37        let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));
38        let c = c.as_canonical_u32();
39        let imm = if F::ORDER_U32 - c < c {
40            -((F::ORDER_U32 - c) as isize)
41        } else {
42            c as isize
43        };
44        if d.as_canonical_u32() != RV32_REGISTER_AS {
45            return Err(StaticProgramError::InvalidInstruction(pc));
46        }
47        *data = BranchEqualPreCompute {
48            imm,
49            a: a.as_canonical_u32() as u8,
50            b: b.as_canonical_u32() as u8,
51        };
52        Ok(local_opcode == BranchEqualOpcode::BNE)
53    }
54}
55
/// Returns `Ok` with the monomorphisation of `$execute_impl` that matches the runtime flag
/// `$is_bne`.
///
/// `$execute_impl` takes two inferred type parameters and a trailing `const bool`. The `true`
/// instantiation is selected for `BNE` and the `false` one for `BEQ`. The surrounding return
/// type drives inference of the `_` parameters and the coercion to a function pointer.
macro_rules! dispatch {
    ($execute_impl:ident, $is_bne:ident) => {
        match $is_bne {
            true => Ok($execute_impl::<_, _, true>),
            false => Ok($execute_impl::<_, _, false>),
        }
    };
}
65
66impl<F, A, const NUM_LIMBS: usize> Executor<F> for BranchEqualExecutor<A, NUM_LIMBS>
67where
68    F: PrimeField32,
69{
70    #[inline(always)]
71    fn pre_compute_size(&self) -> usize {
72        size_of::<BranchEqualPreCompute>()
73    }
74
75    #[inline(always)]
76    fn pre_compute<Ctx: ExecutionCtxTrait>(
77        &self,
78        pc: u32,
79        inst: &Instruction<F>,
80        data: &mut [u8],
81    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
82        let data: &mut BranchEqualPreCompute = data.borrow_mut();
83        let is_bne = self.pre_compute_impl(pc, inst, data)?;
84        dispatch!(execute_e1_impl, is_bne)
85    }
86
87    #[cfg(feature = "tco")]
88    fn handler<Ctx>(
89        &self,
90        pc: u32,
91        inst: &Instruction<F>,
92        data: &mut [u8],
93    ) -> Result<Handler<F, Ctx>, StaticProgramError>
94    where
95        Ctx: ExecutionCtxTrait,
96    {
97        let data: &mut BranchEqualPreCompute = data.borrow_mut();
98        let is_bne = self.pre_compute_impl(pc, inst, data)?;
99        dispatch!(execute_e1_tco_handler, is_bne)
100    }
101}
102
103impl<F, A, const NUM_LIMBS: usize> MeteredExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
104where
105    F: PrimeField32,
106{
107    fn metered_pre_compute_size(&self) -> usize {
108        size_of::<E2PreCompute<BranchEqualPreCompute>>()
109    }
110
111    fn metered_pre_compute<Ctx>(
112        &self,
113        chip_idx: usize,
114        pc: u32,
115        inst: &Instruction<F>,
116        data: &mut [u8],
117    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
118    where
119        Ctx: MeteredExecutionCtxTrait,
120    {
121        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
122        data.chip_idx = chip_idx as u32;
123        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
124        dispatch!(execute_e2_impl, is_bne)
125    }
126
127    #[cfg(feature = "tco")]
128    fn metered_handler<Ctx>(
129        &self,
130        chip_idx: usize,
131        pc: u32,
132        inst: &Instruction<F>,
133        data: &mut [u8],
134    ) -> Result<Handler<F, Ctx>, StaticProgramError>
135    where
136        Ctx: MeteredExecutionCtxTrait,
137    {
138        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
139        data.chip_idx = chip_idx as u32;
140        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
141        dispatch!(execute_e2_tco_handler, is_bne)
142    }
143}
144
145#[inline(always)]
146unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
147    pre_compute: &BranchEqualPreCompute,
148    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
149) {
150    let rs1 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
151    let rs2 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
152    if (rs1 == rs2) ^ IS_NE {
153        vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
154    } else {
155        vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
156    }
157    vm_state.instret += 1;
158}
159
160#[create_tco_handler]
161unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
162    pre_compute: &[u8],
163    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
164) {
165    let pre_compute: &BranchEqualPreCompute = pre_compute.borrow();
166    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, vm_state);
167}
168
169#[create_tco_handler]
170unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
171    pre_compute: &[u8],
172    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
173) {
174    let pre_compute: &E2PreCompute<BranchEqualPreCompute> = pre_compute.borrow();
175    vm_state
176        .ctx
177        .on_height_change(pre_compute.chip_idx as usize, 1);
178    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, vm_state);
179}