openvm_bigint_circuit/
branch_lt.rs

use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode;
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor;
use openvm_rv32im_circuit::BranchLessThanExecutor;
use openvm_rv32im_transpiler::BranchLessThanOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use crate::{
    common::{i256_lt, u256_lt},
    Rv32BranchLessThan256Executor, INT256_NUM_LIMBS,
};

type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>;

impl Rv32BranchLessThan256Executor {
    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
        Self(BranchLessThanExecutor::new(adapter, offset))
    }
}

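/// Per-instruction data decoded ahead of execution: the signed branch offset and
/// the indices of the two registers that hold pointers to the 256-bit operands.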
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct BranchLtPreCompute {
    imm: isize,
    a: u8,
    b: u8,
}

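// Selects the handler monomorphized for the comparison operator of the decoded
// opcode, so each of BLT/BLTU/BGE/BGEU gets its own specialized function.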
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            BranchLessThanOpcode::BLT => $execute_impl::<_, _, BltOp>,
            BranchLessThanOpcode::BLTU => $execute_impl::<_, _, BltuOp>,
            BranchLessThanOpcode::BGE => $execute_impl::<_, _, BgeOp>,
            BranchLessThanOpcode::BGEU => $execute_impl::<_, _, BgeuOp>,
        })
    };
}

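// Unmetered (E1) execution: decode the instruction once into `BranchLtPreCompute`,
// then hand back the opcode-specialized handler (plain or tail-call, depending on
// the `tco` feature).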
impl<F: PrimeField32> Executor<F> for Rv32BranchLessThan256Executor {
    fn pre_compute_size(&self) -> usize {
        size_of::<BranchLtPreCompute>()
    }

    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        let data: &mut BranchLtPreCompute = data.borrow_mut();
        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
        dispatch!(execute_e1_handler, local_opcode)
    }
}

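// Metered (E2) execution: identical decoding, but the pre-compute data is wrapped
// in `E2PreCompute` so the handler can attribute a trace-height increment to this
// chip.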
impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchLessThan256Executor {
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<BranchLtPreCompute>>()
    }

    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }

    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<BranchLtPreCompute> = data.borrow_mut();
        data.chip_idx = chip_idx as u32;
        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
        dispatch!(execute_e2_handler, local_opcode)
    }
}

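/// Shared execution body: read the two 4-byte register values (pointers into the
/// memory address space), load both 256-bit operands, run the comparison, and
/// either jump by `imm` or fall through by `DEFAULT_PC_STEP`. The retired
/// instruction counter is bumped either way.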
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &BranchLtPreCompute,
    instret: &mut u64,
    pc: &mut u32,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let cmp_result = OP::compute(rs1, rs2);
    if cmp_result {
        *pc = (*pc as isize + pre_compute.imm) as u32;
    } else {
        *pc = pc.wrapping_add(DEFAULT_PC_STEP);
    }
    *instret += 1;
}

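// E1 entry point: reinterpret the raw pre-compute bytes and run the shared body.
// `#[create_handler]` is expected to supply the `execute_e1_handler` referenced by
// `dispatch!` above.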
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _instret_end: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &BranchLtPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, instret, pc, exec_state);
}

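// E2 entry point: report a height increment of one row for this chip before
// delegating to the shared body.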
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: BranchLessThanOp>(
    pre_compute: &[u8],
    instret: &mut u64,
    pc: &mut u32,
    _arg: u64,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<BranchLtPreCompute> = pre_compute.borrow();
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, instret, pc, exec_state);
}

impl Rv32BranchLessThan256Executor {
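    /// Decodes the instruction: maps the field-element immediate to a signed offset
    /// (canonical values above p/2 are treated as negative), rejects instructions
    /// whose address spaces are not the RV32 register/memory spaces, and recovers
    /// the local `BranchLessThanOpcode` from the global opcode.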
    fn pre_compute_impl<F: PrimeField32>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut BranchLtPreCompute,
    ) -> Result<BranchLessThanOpcode, StaticProgramError> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            ..
        } = inst;
        let c = c.as_canonical_u32();
        let imm = if F::ORDER_U32 - c < c {
            -((F::ORDER_U32 - c) as isize)
        } else {
            c as isize
        };
        let e_u32 = e.as_canonical_u32();
        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
            return Err(StaticProgramError::InvalidInstruction(pc));
        }
        *data = BranchLtPreCompute {
            imm,
            a: a.as_canonical_u32() as u8,
            b: b.as_canonical_u32() as u8,
        };
        let local_opcode = BranchLessThanOpcode::from_usize(
            opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET),
        );
        Ok(local_opcode)
    }
}

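/// Comparison strategy plugged into the shared execution body. Each zero-sized
/// marker below maps one opcode to the signed or unsigned 256-bit comparison from
/// `crate::common`.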
trait BranchLessThanOp {
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool;
}
struct BltOp;
struct BltuOp;
struct BgeOp;
struct BgeuOp;

impl BranchLessThanOp for BltOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BltuOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        u256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !i256_lt(rs1, rs2)
    }
}
impl BranchLessThanOp for BgeuOp {
    #[inline(always)]
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool {
        !u256_lt(rs1, rs2)
    }
}
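
// A minimal sketch of a unit test for the comparison markers (not part of the
// original file). It assumes `u256_lt`/`i256_lt` compare 256-bit little-endian
// limb arrays, so each `Bge*` op must be the exact negation of the matching
// `Blt*` op.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lt_and_ge_ops_are_complementary() {
        let zero = [0u8; INT256_NUM_LIMBS];
        let mut one = [0u8; INT256_NUM_LIMBS];
        one[0] = 1;
        // 0 < 1 (unsigned), so BLTU branches and BGEU does not.
        assert!(BltuOp::compute(zero, one));
        assert!(!BgeuOp::compute(zero, one));
        // Equal operands: signed "less than" is false, "greater or equal" is true.
        assert!(!BltOp::compute(zero, zero));
        assert!(BgeOp::compute(zero, zero));
    }
}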