openvm_bigint_circuit/
branch_lt.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7    instruction::Instruction,
8    program::DEFAULT_PC_STEP,
9    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10    LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
13use openvm_rv32im_circuit::BranchLessThanExecutor;
14use openvm_rv32im_transpiler::BranchLessThanOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18    common::{i256_lt, u256_lt},
19    Rv32BranchLessThan256Executor, INT256_NUM_LIMBS,
20};
21
/// Adapter executor type accepted by [`Rv32BranchLessThan256Executor::new`].
// NOTE(review): the const arguments `2` and `INT256_NUM_LIMBS` presumably mean "two heap
// operands of `INT256_NUM_LIMBS` bytes each" -- confirm against the adapter's definition.
type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;
23
24impl Rv32BranchLessThan256Executor {
25    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26        Self(BranchLessThanExecutor::new(adapter, offset))
27    }
28}
29
/// Operands of one 256-bit branch instruction, decoded ahead of execution.
///
/// The record is written into (and later read back from) a raw byte buffer through the
/// `Borrow`/`BorrowMut` views generated by `AlignedBytesBorrow`, so the field layout is fixed
/// with `#[repr(C)]`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLtPreCompute {
    /// Signed offset added to the program counter when the branch is taken.
    imm: isize,
    /// Address in the register address space of the register that holds the pointer to the
    /// first operand.
    a: u8,
    /// Address in the register address space of the register that holds the pointer to the
    /// second operand.
    b: u8,
}
37
/// Picks the monomorphized execute function for a decoded local opcode.
///
/// `$handler` is the name of a function generic over `<F, CTX, OP>` and `$opcode` is a local
/// variable holding a `BranchLessThanOpcode`. The macro expands to `Ok(f)`, where `f` is
/// `$handler` instantiated with the operation marker matching the opcode; the first two generic
/// arguments are left to inference from the surrounding return type.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        Ok(match $opcode {
            // Unsigned comparisons.
            BranchLessThanOpcode::BLTU => $handler::<_, _, BltuOp>,
            BranchLessThanOpcode::BGEU => $handler::<_, _, BgeuOp>,
            // Signed comparisons.
            BranchLessThanOpcode::BLT => $handler::<_, _, BltOp>,
            BranchLessThanOpcode::BGE => $handler::<_, _, BgeOp>,
        })
    };
}
48
49impl<F: PrimeField32> Executor<F> for Rv32BranchLessThan256Executor {
50    fn pre_compute_size(&self) -> usize {
51        size_of::<BranchLtPreCompute>()
52    }
53
54    fn pre_compute<Ctx>(
55        &self,
56        pc: u32,
57        inst: &Instruction<F>,
58        data: &mut [u8],
59    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
60    where
61        Ctx: ExecutionCtxTrait,
62    {
63        let data: &mut BranchLtPreCompute = data.borrow_mut();
64        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
65        dispatch!(execute_e1_impl, local_opcode)
66    }
67
68    #[cfg(feature = "tco")]
69    fn handler<Ctx>(
70        &self,
71        pc: u32,
72        inst: &Instruction<F>,
73        data: &mut [u8],
74    ) -> Result<Handler<F, Ctx>, StaticProgramError>
75    where
76        Ctx: ExecutionCtxTrait,
77    {
78        let data: &mut BranchLtPreCompute = data.borrow_mut();
79        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
80        dispatch!(execute_e1_tco_handler, local_opcode)
81    }
82}
83
84impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchLessThan256Executor {
85    fn metered_pre_compute_size(&self) -> usize {
86        size_of::<E2PreCompute<BranchLtPreCompute>>()
87    }
88
89    fn metered_pre_compute<Ctx>(
90        &self,
91        chip_idx: usize,
92        pc: u32,
93        inst: &Instruction<F>,
94        data: &mut [u8],
95    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
96    where
97        Ctx: MeteredExecutionCtxTrait,
98    {
99        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
100        data.chip_idx = chip_idx as u32;
101        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
102        dispatch!(execute_e2_impl, local_opcode)
103    }
104
105    #[cfg(feature = "tco")]
106    fn metered_handler<Ctx>(
107        &self,
108        chip_idx: usize,
109        pc: u32,
110        inst: &Instruction<F>,
111        data: &mut [u8],
112    ) -> Result<Handler<F, Ctx>, StaticProgramError>
113    where
114        Ctx: MeteredExecutionCtxTrait,
115    {
116        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
117        data.chip_idx = chip_idx as u32;
118        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
119        dispatch!(execute_e2_tco_handler, local_opcode)
120    }
121}
122
123#[inline(always)]
124unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
125    pre_compute: &BranchLtPreCompute,
126    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
127) {
128    let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
129    let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
130    let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
131    let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
132    let cmp_result = OP::compute(rs1, rs2);
133    if cmp_result {
134        vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32;
135    } else {
136        vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
137    }
138    vm_state.instret += 1;
139}
140
141#[create_tco_handler]
142unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
143    pre_compute: &[u8],
144    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
145) {
146    let pre_compute: &BranchLtPreCompute = pre_compute.borrow();
147    execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
148}
149
150#[create_tco_handler]
151unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
152    pre_compute: &[u8],
153    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
154) {
155    let pre_compute: &E2PreCompute<BranchLtPreCompute> = pre_compute.borrow();
156    vm_state
157        .ctx
158        .on_height_change(pre_compute.chip_idx as usize, 1);
159    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
160}
161
162impl Rv32BranchLessThan256Executor {
163    fn pre_compute_impl<F: PrimeField32>(
164        &self,
165        pc: u32,
166        inst: &Instruction<F>,
167        data: &mut BranchLtPreCompute,
168    ) -> Result<BranchLessThanOpcode, StaticProgramError> {
169        let Instruction {
170            opcode,
171            a,
172            b,
173            c,
174            d,
175            e,
176            ..
177        } = inst;
178        let c = c.as_canonical_u32();
179        let imm = if F::ORDER_U32 - c < c {
180            -((F::ORDER_U32 - c) as isize)
181        } else {
182            c as isize
183        };
184        let e_u32 = e.as_canonical_u32();
185        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
186            return Err(StaticProgramError::InvalidInstruction(pc));
187        }
188        *data = BranchLtPreCompute {
189            imm,
190            a: a.as_canonical_u32() as u8,
191            b: b.as_canonical_u32() as u8,
192        };
193        let local_opcode = BranchLessThanOpcode::from_usize(
194            opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET),
195        );
196        Ok(local_opcode)
197    }
198}
199
/// Comparison selected at compile time for one branch opcode.
///
/// The execute functions are generic over an implementor of this trait, so each opcode gets
/// its own monomorphized function (see the `dispatch!` macro).
trait BranchLessThanOp {
    /// Returns `true` when the branch must be taken for operands `rs1` and `rs2`.
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool;
}
/// Operation marker used for `BranchLessThanOpcode::BLT`.
struct BltOp;
/// Operation marker used for `BranchLessThanOpcode::BLTU`.
struct BltuOp;
/// Operation marker used for `BranchLessThanOpcode::BGE`.
struct BgeOp;
/// Operation marker used for `BranchLessThanOpcode::BGEU`.
struct BgeuOp;
207
impl BranchLessThanOp for BltOp {
    /// Branch is taken exactly when `i256_lt(rs1, rs2)` is `true`.
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BltuOp {
    /// Branch is taken exactly when `u256_lt(rs1, rs2)` is `true`.
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        u256_lt(rs1, rs2)
    }
}
220impl BranchLessThanOp for BgeOp {
221    #[inline(always)]
222    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
223        !i256_lt(rs1, rs2)
224    }
225}
226impl BranchLessThanOp for BgeuOp {
227    #[inline(always)]
228    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
229        !u256_lt(rs1, rs2)
230    }
231}