openvm_rv32im_circuit/branch_eq/
execution.rs1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
7use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8use openvm_instructions::{
9 instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode,
10};
11use openvm_rv32im_transpiler::BranchEqualOpcode;
12use openvm_stark_backend::p3_field::PrimeField32;
13
14use super::BranchEqualExecutor;
15
/// Decoded operands of a `BEQ`/`BNE` instruction.
///
/// Filled in by `pre_compute_impl` and consumed by `execute_e12_impl`. The layout is
/// `#[repr(C)]` because the value lives inside a raw `[u8]` pre-compute buffer and is
/// reinterpreted via `AlignedBytesBorrow`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchEqualPreCompute {
    /// Signed pc offset applied when the branch is taken.
    imm: isize,
    /// Offset of the first compared register, read from `RV32_REGISTER_AS`.
    a: u8,
    /// Offset of the second compared register, read from `RV32_REGISTER_AS`.
    b: u8,
}
23
24impl<A, const NUM_LIMBS: usize> BranchEqualExecutor<A, NUM_LIMBS> {
25 #[inline(always)]
27 fn pre_compute_impl<F: PrimeField32>(
28 &self,
29 pc: u32,
30 inst: &Instruction<F>,
31 data: &mut BranchEqualPreCompute,
32 ) -> Result<bool, StaticProgramError> {
33 let data: &mut BranchEqualPreCompute = data.borrow_mut();
34 let &Instruction {
35 opcode, a, b, c, d, ..
36 } = inst;
37 let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset));
38 let c = c.as_canonical_u32();
39 let imm = if F::ORDER_U32 - c < c {
40 -((F::ORDER_U32 - c) as isize)
41 } else {
42 c as isize
43 };
44 if d.as_canonical_u32() != RV32_REGISTER_AS {
45 return Err(StaticProgramError::InvalidInstruction(pc));
46 }
47 *data = BranchEqualPreCompute {
48 imm,
49 a: a.as_canonical_u32() as u8,
50 b: b.as_canonical_u32() as u8,
51 };
52 Ok(local_opcode == BranchEqualOpcode::BNE)
53 }
54}
55
// Picks the `true` or `false` const-generic instantiation of `$execute_impl` based on the
// runtime boolean `$is_bne`, and wraps it in `Ok`. The first two generic arguments are left
// to inference from the caller's expected return type.
macro_rules! dispatch {
    ($execute_impl:ident, $is_bne:ident) => {
        match $is_bne {
            true => Ok($execute_impl::<_, _, true>),
            false => Ok($execute_impl::<_, _, false>),
        }
    };
}
65
66impl<F, A, const NUM_LIMBS: usize> Executor<F> for BranchEqualExecutor<A, NUM_LIMBS>
67where
68 F: PrimeField32,
69{
70 #[inline(always)]
71 fn pre_compute_size(&self) -> usize {
72 size_of::<BranchEqualPreCompute>()
73 }
74
75 #[inline(always)]
76 fn pre_compute<Ctx: ExecutionCtxTrait>(
77 &self,
78 pc: u32,
79 inst: &Instruction<F>,
80 data: &mut [u8],
81 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
82 let data: &mut BranchEqualPreCompute = data.borrow_mut();
83 let is_bne = self.pre_compute_impl(pc, inst, data)?;
84 dispatch!(execute_e1_impl, is_bne)
85 }
86
87 #[cfg(feature = "tco")]
88 fn handler<Ctx>(
89 &self,
90 pc: u32,
91 inst: &Instruction<F>,
92 data: &mut [u8],
93 ) -> Result<Handler<F, Ctx>, StaticProgramError>
94 where
95 Ctx: ExecutionCtxTrait,
96 {
97 let data: &mut BranchEqualPreCompute = data.borrow_mut();
98 let is_bne = self.pre_compute_impl(pc, inst, data)?;
99 dispatch!(execute_e1_tco_handler, is_bne)
100 }
101}
102
103impl<F, A, const NUM_LIMBS: usize> MeteredExecutor<F> for BranchEqualExecutor<A, NUM_LIMBS>
104where
105 F: PrimeField32,
106{
107 fn metered_pre_compute_size(&self) -> usize {
108 size_of::<E2PreCompute<BranchEqualPreCompute>>()
109 }
110
111 fn metered_pre_compute<Ctx>(
112 &self,
113 chip_idx: usize,
114 pc: u32,
115 inst: &Instruction<F>,
116 data: &mut [u8],
117 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
118 where
119 Ctx: MeteredExecutionCtxTrait,
120 {
121 let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
122 data.chip_idx = chip_idx as u32;
123 let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
124 dispatch!(execute_e2_impl, is_bne)
125 }
126
127 #[cfg(feature = "tco")]
128 fn metered_handler<Ctx>(
129 &self,
130 chip_idx: usize,
131 pc: u32,
132 inst: &Instruction<F>,
133 data: &mut [u8],
134 ) -> Result<Handler<F, Ctx>, StaticProgramError>
135 where
136 Ctx: MeteredExecutionCtxTrait,
137 {
138 let data: &mut E2PreCompute<BranchEqualPreCompute> = data.borrow_mut();
139 data.chip_idx = chip_idx as u32;
140 let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?;
141 dispatch!(execute_e2_tco_handler, is_bne)
142 }
143}
144
145#[inline(always)]
146unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
147 pre_compute: &BranchEqualPreCompute,
148 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
149) {
150 let rs1 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
151 let rs2 = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
152 if (rs1 == rs2) ^ IS_NE {
153 vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
154 } else {
155 vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
156 }
157 vm_state.instret += 1;
158}
159
160#[create_tco_handler]
161unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
162 pre_compute: &[u8],
163 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
164) {
165 let pre_compute: &BranchEqualPreCompute = pre_compute.borrow();
166 execute_e12_impl::<F, CTX, IS_NE>(pre_compute, vm_state);
167}
168
169#[create_tco_handler]
170unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
171 pre_compute: &[u8],
172 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
173) {
174 let pre_compute: &E2PreCompute<BranchEqualPreCompute> = pre_compute.borrow();
175 vm_state
176 .ctx
177 .on_height_change(pre_compute.chip_idx as usize, 1);
178 execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, vm_state);
179}