openvm_bigint_circuit/
branch_eq.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32BranchEqual256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7    instruction::Instruction,
8    program::DEFAULT_PC_STEP,
9    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10    LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
13use openvm_rv32im_circuit::BranchEqualExecutor;
14use openvm_rv32im_transpiler::BranchEqualOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{common::bytes_to_u64_array, Rv32BranchEqual256Executor, INT256_NUM_LIMBS};
18
19type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;
20
21impl Rv32BranchEqual256Executor {
22    pub fn new(adapter_step: AdapterExecutor, offset: usize, pc_step: u32) -> Self {
23        Self(BranchEqualExecutor::new(adapter_step, offset, pc_step))
24    }
25}
26
27#[derive(AlignedBytesBorrow, Clone)]
28#[repr(C)]
29struct BranchEqPreCompute {
30    imm: isize,
31    a: u8,
32    b: u8,
33}
34
/// Selects the monomorphised handler for a decoded [`BranchEqualOpcode`].
///
/// `$handler` names a function generic over `<_, _, const IS_NE: bool>`.
/// `BNE` picks `IS_NE = true` and `BEQ` picks `IS_NE = false`. The expansion
/// is wrapped in `Ok(..)` so it can be the tail expression of a
/// `Result`-returning method; the expected return type drives the coercion to
/// a function pointer.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        match $opcode {
            BranchEqualOpcode::BNE => Ok($handler::<_, _, true>),
            BranchEqualOpcode::BEQ => Ok($handler::<_, _, false>),
        }
    };
}
43
44impl<F: PrimeField32> Executor<F> for Rv32BranchEqual256Executor {
45    fn pre_compute_size(&self) -> usize {
46        size_of::<BranchEqPreCompute>()
47    }
48
49    fn pre_compute<Ctx>(
50        &self,
51        pc: u32,
52        inst: &Instruction<F>,
53        data: &mut [u8],
54    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
55    where
56        Ctx: ExecutionCtxTrait,
57    {
58        let data: &mut BranchEqPreCompute = data.borrow_mut();
59        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
60        dispatch!(execute_e1_impl, local_opcode)
61    }
62
63    #[cfg(feature = "tco")]
64    fn handler<Ctx>(
65        &self,
66        pc: u32,
67        inst: &Instruction<F>,
68        data: &mut [u8],
69    ) -> Result<Handler<F, Ctx>, StaticProgramError>
70    where
71        Ctx: ExecutionCtxTrait,
72    {
73        let data: &mut BranchEqPreCompute = data.borrow_mut();
74        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
75        dispatch!(execute_e1_tco_handler, local_opcode)
76    }
77}
78
79impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchEqual256Executor {
80    fn metered_pre_compute_size(&self) -> usize {
81        size_of::<E2PreCompute<BranchEqPreCompute>>()
82    }
83
84    fn metered_pre_compute<Ctx>(
85        &self,
86        chip_idx: usize,
87        pc: u32,
88        inst: &Instruction<F>,
89        data: &mut [u8],
90    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
91    where
92        Ctx: MeteredExecutionCtxTrait,
93    {
94        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
95        data.chip_idx = chip_idx as u32;
96        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
97        dispatch!(execute_e2_impl, local_opcode)
98    }
99
100    #[cfg(feature = "tco")]
101    fn metered_handler<Ctx>(
102        &self,
103        chip_idx: usize,
104        pc: u32,
105        inst: &Instruction<F>,
106        data: &mut [u8],
107    ) -> Result<Handler<F, Ctx>, StaticProgramError>
108    where
109        Ctx: MeteredExecutionCtxTrait,
110    {
111        let data: &mut E2PreCompute<BranchEqPreCompute> = data.borrow_mut();
112        data.chip_idx = chip_idx as u32;
113        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
114        dispatch!(execute_e2_tco_handler, local_opcode)
115    }
116}
117
118#[inline(always)]
119unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
120    pre_compute: &BranchEqPreCompute,
121    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
122) {
123    let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
124    let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
125    let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
126    let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
127    let cmp_result = u256_eq(rs1, rs2);
128    if cmp_result ^ IS_NE {
129        vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
130    } else {
131        vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
132    }
133
134    vm_state.instret += 1;
135}
136
137#[create_tco_handler]
138unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
139    pre_compute: &[u8],
140    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
141) {
142    let pre_compute: &BranchEqPreCompute = pre_compute.borrow();
143    execute_e12_impl::<F, CTX, IS_NE>(pre_compute, vm_state);
144}
145
146#[create_tco_handler]
147unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, const IS_NE: bool>(
148    pre_compute: &[u8],
149    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
150) {
151    let pre_compute: &E2PreCompute<BranchEqPreCompute> = pre_compute.borrow();
152    vm_state
153        .ctx
154        .on_height_change(pre_compute.chip_idx as usize, 1);
155    execute_e12_impl::<F, CTX, IS_NE>(&pre_compute.data, vm_state);
156}
157
158impl Rv32BranchEqual256Executor {
159    fn pre_compute_impl<F: PrimeField32>(
160        &self,
161        pc: u32,
162        inst: &Instruction<F>,
163        data: &mut BranchEqPreCompute,
164    ) -> Result<BranchEqualOpcode, StaticProgramError> {
165        let Instruction {
166            opcode,
167            a,
168            b,
169            c,
170            d,
171            e,
172            ..
173        } = inst;
174        let c = c.as_canonical_u32();
175        let imm = if F::ORDER_U32 - c < c {
176            -((F::ORDER_U32 - c) as isize)
177        } else {
178            c as isize
179        };
180        let e_u32 = e.as_canonical_u32();
181        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
182            return Err(StaticProgramError::InvalidInstruction(pc));
183        }
184        *data = BranchEqPreCompute {
185            imm,
186            a: a.as_canonical_u32() as u8,
187            b: b.as_canonical_u32() as u8,
188        };
189        let local_opcode = BranchEqualOpcode::from_usize(
190            opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET),
191        );
192        Ok(local_opcode)
193    }
194}
195
196fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
197    let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
198    let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
199    for i in 0..4 {
200        if rs1_u64[i] != rs2_u64[i] {
201            return false;
202        }
203    }
204    true
205}