openvm_bigint_circuit/
branch_lt.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
use openvm_rv32im_circuit::BranchLessThanExecutor;
use openvm_rv32im_transpiler::BranchLessThanOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use crate::{
    common::{i256_lt, u256_lt},
    Rv32BranchLessThan256Executor, INT256_NUM_LIMBS,
};

type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;

impl Rv32BranchLessThan256Executor {
    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
        Self(BranchLessThanExecutor::new(adapter, offset))
    }
}

/// Operands decoded once per instruction at pre-compute time: the branch immediate as a
/// signed PC offset and the indices of the two registers holding pointers to the 256-bit
/// operands in memory.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLtPreCompute {
    imm: isize,
    a: u8,
    b: u8,
}

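// `dispatch!` selects a monomorphized handler per comparison opcode: BLT/BGE use the signed
// 256-bit comparison and BLTU/BGEU the unsigned one, so the opcode check happens once at
// pre-compute time rather than on every execution.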
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            BranchLessThanOpcode::BLT => $execute_impl::<_, _, BltOp>,
            BranchLessThanOpcode::BLTU => $execute_impl::<_, _, BltuOp>,
            BranchLessThanOpcode::BGE => $execute_impl::<_, _, BgeOp>,
            BranchLessThanOpcode::BGEU => $execute_impl::<_, _, BgeuOp>,
        })
    };
}

impl<F: PrimeField32> InterpreterExecutor<F> for Rv32BranchLessThan256Executor {
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchLtPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotExecutor<F> for Rv32BranchLessThan256Executor {}

impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32BranchLessThan256Executor {
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchLtPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_impl, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

#[cfg(feature = "aot")]
impl<F: PrimeField32> AotMeteredExecutor<F> for Rv32BranchLessThan256Executor {}

/// Shared execution core for the plain (E1) and metered (E2) paths.
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &BranchLtPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let mut pc = exec_state.pc();
    // The instruction's register operands hold heap pointers to the 256-bit values.
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let cmp_result = OP::compute(rs1, rs2);
    // Take the branch if the comparison holds; otherwise fall through to the next instruction.
    if cmp_result {
        pc = (pc as isize + pre_compute.imm) as u32;
    } else {
        pc = pc.wrapping_add(DEFAULT_PC_STEP);
    }
    exec_state.set_pc(pc);
}

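// Both entry points below delegate to `execute_e12_impl`; the E2 (metered) variant additionally
// reports a trace height increment of 1 for this chip before executing.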
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchLtPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<BranchLtPreCompute>()).borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
}

#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchLtPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<BranchLtPreCompute>>())
            .borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
}

impl Rv32BranchLessThan256Executor {
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BranchLtPreCompute,
    ) -> Result<BranchLessThanOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        // Decode the branch immediate from its field representation into a signed PC offset:
        // field elements in the upper half of the field encode negative offsets.
        let c = c.as_canonical_u32();
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        // Register pointers must live in the register address space and the 256-bit operands
        // in the memory address space; anything else is an invalid instruction.
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = BranchLtPreCompute {
            imm,
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        let local_opcode = BranchLessThanOpcode::from_usize(
            opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET),
        );
        Ok(local_opcode)
    }
}

/// Comparison strategy selected once at pre-compute time; one implementation per branch opcode.
trait BranchLessThanOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool;
}
struct BltOp;
struct BltuOp;
struct BgeOp;
struct BgeuOp;

impl BranchLessThanOp for BltOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BltuOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        u256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeuOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !u256_lt(rs1, rs2)
    }
}
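
// A minimal sanity-check sketch (not part of the original file): it exercises the four
// comparison strategies above, assuming `i256_lt`/`u256_lt` implement signed/unsigned 256-bit
// less-than on the limb arrays as their names suggest. The module and test names are
// illustrative only.
#[cfg(test)]
mod branch_lt_op_sketch {
    use super::*;

    #[test]
    fn signed_vs_unsigned_ordering() {
        let neg_one = [0xFFu8; INT256_NUM_LIMBS]; // two's-complement -1, also the maximum u256
        let zero = [0u8; INT256_NUM_LIMBS];
        assert!(BltOp::compute(neg_one, zero)); // signed: -1 < 0
        assert!(!BltuOp::compute(neg_one, zero)); // unsigned: 2^256 - 1 is not < 0
        assert!(BgeOp::compute(zero, neg_one)); // BGE is the negation of BLT
        assert!(BgeuOp::compute(neg_one, zero)); // BGEU is the negation of BLTU
    }
}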