openvm_rv32im_circuit/branch_eq/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
10};
11use openvm_rv32im_transpiler::BranchEqualOpcode;
12use openvm_stark_backend::p3_field::PrimeField32;
13
14use super::BranchEqualExecutor;
15
16#[derive(AlignedBytesBorrow, Clone)]
17#[repr(C)]
18struct BranchEqualPreCompute {
19    imm: isize,
20    a: u8,
21    b: u8,
22}
23
24impl<A, const NUM_LIMBS: usize> BranchEqualExecutor<A, NUM_LIMBS> {
25    /// Return `is_bne`, true if the local opcode is BNE.
26    #[inline(always)]
27    fn pre_compute_impl<F: PrimeField32>(
28        &self,
29        pc: u32,
30        inst: &Instruction<F>,
31        data: &mut BranchEqualPreCompute,
32    ) -> Result<bool, StaticProgramError> {
33        let data: &mut BranchEqualPreCompute = data.borrow_mut();
34        let &Instruction {
35            opcode, a, b, c, d, ..
36        } = inst;
37        let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));
38        let c = c.as_canonical_u32();
39        let imm = if F::ORDER_U32 - c < c {
40            -((F::ORDER_U32 - c) as isize)
41        } else {
42            c as isize
43        };
44        if d.as_canonical_u32() != RV32_REGISTER_AS {
45            return Err(StaticProgramError::InvalidInstruction(pc));
46        }
47        *data = BranchEqualPreCompute {
48            imm,
49            a: a.as_canonical_u32() as u8,
50            b: b.as_canonical_u32() as u8,
51        };
52        Ok(local_opcode == BranchEqualOpcode::BNE)
53    }
54}
55
/// Expands to `Ok(handler::<_, _, IS_NE>)`, turning the runtime boolean into the
/// handler's third (const) generic argument.
///
/// The first two generic arguments are left to inference, so the expansion must
/// appear where the expected function type is known (e.g. as a function's tail
/// expression).
macro_rules! dispatch {
    ($handler:ident, $branch_on_ne:ident) => {
        match $branch_on_ne {
            true => Ok($handler::<_, _, true>),
            false => Ok($handler::<_, _, false>),
        }
    };
}
65
66impl<F, A, const NUM_LIMBS: usize> Executor<F> for BranchEqualExecutor<A, NUM_LIMBS>
67where
68    F: PrimeField32,
69{
70    #[inline(always)]
71    fn pre_compute_size(&self) -> usize {
72        size_of::<BranchEqualPreCompute>()
73    }
74
75    #[cfg(not(feature = "tco"))]
76    #[inline(always)]
77    fn pre_compute<Ctx: ExecutionCtxTrait>(
78        &self,
79        pc: u32,
80        inst: &Instruction<F>,
81        data: &mut [u8],
82    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
83        let data: &mut BranchEqualPreCompute = data.borrow_mut();
84        let is_bne = self.pre_compute_impl(pc, inst, data)?;
85        dispatch!(execute_e1_handler, is_bne)
86    }
87
88    #[cfg(feature = "tco")]
89    fn handler<Ctx>(
90        &self,
91        pc: u32,
92        inst: &Instruction<F>,
93        data: &mut [u8],
94    ) -> Result<Handler<F, Ctx>, StaticProgramError>
95    where
96        Ctx: ExecutionCtxTrait,
97    {
98        let data: &mut BranchEqualPreCompute = data.borrow_mut();
99        let is_bne = self.pre_compute_impl(pc, inst, data)?;
100        dispatch!(execute_e1_handler, is_bne)
101    }
102}
103
104impl<F, A, const NUM_LIMBS: usize> MeteredExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
105where
106    F: PrimeField32,
107{
108    fn metered_pre_compute_size(&self) -> usize {
109        size_of::<E2PreCompute<BranchEqualPreCompute>>()
110    }
111
112    #[cfg(not(feature = "tco"))]
113    fn metered_pre_compute<Ctx>(
114        &self,
115        chip_idx: usize,
116        pc: u32,
117        inst: &Instruction<F>,
118        data: &mut [u8],
119    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
120    where
121        Ctx: MeteredExecutionCtxTrait,
122    {
123        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
124        data.chip_idx = chip_idx as u32;
125        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
126        dispatch!(execute_e2_handler, is_bne)
127    }
128
129    #[cfg(feature = "tco")]
130    fn metered_handler<Ctx>(
131        &self,
132        chip_idx: usize,
133        pc: u32,
134        inst: &Instruction<F>,
135        data: &mut [u8],
136    ) -> Result<Handler<F, Ctx>, StaticProgramError>
137    where
138        Ctx: MeteredExecutionCtxTrait,
139    {
140        let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
141        data.chip_idx = chip_idx as u32;
142        let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
143        dispatch!(execute_e2_handler, is_bne)
144    }
145}
146
147#[inline(always)]
148unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
149    pre_compute: &BranchEqualPreCompute,
150    instret: &mut u64,
151    pc: &mut u32,
152    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
153) {
154    let rs1 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
155    let rs2 = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
156    if (rs1 == rs2) ^ IS_NE {
157        *pc = (*pc as isize + pre_compute.imm) as u32;
158    } else {
159        *pc = pc.wrapping_add(DEFAULT_PC_STEP);
160    }
161    *instret += 1;
162}
163
164#[create_handler]
165#[inline(always)]
166unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
167    pre_compute: &[u8],
168    instret: &mut u64,
169    pc: &mut u32,
170    _instret_end: u64,
171    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
172) {
173    let pre_compute: &BranchEqualPreCompute = pre_compute.borrow();
174    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, instret, pc, exec_state);
175}
176
177#[create_handler]
178#[inline(always)]
179unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
180    pre_compute: &[u8],
181    instret: &mut u64,
182    pc: &mut u32,
183    _arg: u64,
184    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
185) {
186    let pre_compute: &E2PreCompute<BranchEqualPreCompute> = pre_compute.borrow();
187    exec_state
188        .ctx
189        .on_height_change(pre_compute.chip_idx as usize, 1);
190    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, instret, pc, exec_state);
191}