openvm_bigint_circuit/
branch_eq.rs

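//! Runtime (e1/e2) execution of the 256-bit BEQ/BNE branch instructions, reusing the
//! RV32 `BranchEqualExecutor` behind a two-operand heap-branch adapter.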
use std::borrow::{Borrow, BorrowMut};

use openvm_bigint_transpiler::Rv32BranchEqual256Opcode;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
use openvm_rv32im_circuit::BranchEqualExecutor;
use openvm_rv32im_transpiler::BranchEqualOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use crate::{common::bytes_to_u64_array, Rv32BranchEqual256Executor, INT256_NUM_LIMBS};

type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;

impl Rv32BranchEqual256Executor {
    pub fn new(adapter_step: AdapterExecutor, offset: usize, pc_step: u32) -> Self {
        Self(BranchEqualExecutor::new(adapter_step, offset, pc_step))
    }
}

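/// Instruction data decoded once at pre-compute time: the signed branch immediate
/// and the register-file addresses of the two heap pointers being compared.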
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchEqPreCompute {
    imm: isize,
    a: u8,
    b: u8,
}

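// Selects a handler monomorphized over the `IS_NE` const generic:
// `false` for BEQ, `true` for BNE.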
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        match $local_opcode {
            BranchEqualOpcode::BEQ => Ok($execute_impl::<_, _, false>),
            BranchEqualOpcode::BNE => Ok($execute_impl::<_, _, true>),
        }
    };
}

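// Plain (E1) execution: the pre-compute buffer holds a bare `BranchEqPreCompute`.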
impl<F: PrimeField32> Executor<F> for Rv32BranchEqual256Executor {
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchEqPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchEqPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchEqPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

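// Metered (E2) execution: the pre-compute buffer also carries the chip index so
// each executed instruction can be charged one row of trace height.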
impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchEqual256Executor {
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchEqPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

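/// Shared core of the E1/E2 handlers: reads the two pointer registers, loads both
/// 256-bit operands from memory, compares them, and either jumps by the immediate
/// or falls through by `DEFAULT_PC_STEP`.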
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &BranchEqPreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // The registers hold little-endian pointers into the memory address space.
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let cmp_result = u256_eq(rs1, rs2);
    // BEQ branches on equality; BNE (`IS_NE`) inverts the condition.
    if cmp_result ^ IS_NE {
        *pc = (*pc as isize + pre_compute.imm) as u32;
    } else {
        *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    }

    *instret += 1;
}

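// E1 wrapper: reinterpret the raw pre-compute bytes and delegate to the shared core.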
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchEqPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, instret, pc, exec_state);
}

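// E2 wrapper: record one row of height for the owning chip, then delegate.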
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchEqPreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, instret, pc, exec_state);
}

impl Rv32BranchEqual256Executor {
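    /// Decodes the instruction into a `BranchEqPreCompute`, converting the field
    /// immediate to a signed offset and validating the operand address spaces.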
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BranchEqPreCompute,
    ) -> Result<BranchEqualOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let c = c.as_canonical_u32();
        // A field element above p/2 encodes a negative immediate.
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = BranchEqPreCompute {
            imm,
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        let local_opcode = BranchEqualOpcode::from_usize(
            opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET),
        );
        Ok(local_opcode)
    }
}

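/// Equality of two 256-bit little-endian values, compared as four `u64` words.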
fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
    let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
    let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
    rs1_u64 == rs2_u64
}