openvm_bigint_circuit/
base_alu.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_bigint_transpiler::Rv32BaseAlu256Opcode;
7use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
8use openvm_circuit_primitives_derive::AlignedBytesBorrow;
9use openvm_instructions::{
10    instruction::Instruction,
11    program::DEFAULT_PC_STEP,
12    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
13    LocalOpcode,
14};
15use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
16use openvm_rv32im_circuit::BaseAluExecutor;
17use openvm_rv32im_transpiler::BaseAluOpcode;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::{
21    common::{bytes_to_u64_array, u64_array_to_bytes},
22    Rv32BaseAlu256Executor, INT256_NUM_LIMBS,
23};
24
25type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
26
27impl Rv32BaseAlu256Executor {
28    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
29        Self(BaseAluExecutor::new(adapter, offset))
30    }
31}
32
33#[derive(AlignedBytesBorrow)]
34struct BaseAluPreCompute {
35    a: u8,
36    b: u8,
37    c: u8,
38}
39
/// Selects the monomorphised handler for a decoded [`BaseAluOpcode`].
///
/// `$handler` names a function generic over three parameters; the first two are left to
/// inference and the third is fixed to the `AluOp` marker type matching `$opcode`. The chosen
/// function is returned wrapped in `Ok(..)`, so the macro is meant to be used as the tail
/// expression of a function returning `Result`.
macro_rules! dispatch {
    ($handler:ident, $opcode:ident) => {
        Ok(match $opcode {
            BaseAluOpcode::ADD => $handler::<_, _, AddOp>,
            BaseAluOpcode::SUB => $handler::<_, _, SubOp>,
            BaseAluOpcode::XOR => $handler::<_, _, XorOp>,
            BaseAluOpcode::OR => $handler::<_, _, OrOp>,
            BaseAluOpcode::AND => $handler::<_, _, AndOp>,
        })
    };
}
51
52impl<F: PrimeField32> InterpreterExecutor<F> for Rv32BaseAlu256Executor {
53    fn pre_compute_size(&self) -> usize {
54        size_of::<BaseAluPreCompute>()
55    }
56
57    #[cfg(not(feature = "tco"))]
58    fn pre_compute<Ctx>(
59        &self,
60        pc: u32,
61        inst: &Instruction<F>,
62        data: &mut [u8],
63    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
64    where
65        Ctx: ExecutionCtxTrait,
66    {
67        let data: &mut BaseAluPreCompute = data.borrow_mut();
68        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
69
70        dispatch!(execute_e1_handler, local_opcode)
71    }
72
73    #[cfg(feature = "tco")]
74    fn handler<Ctx>(
75        &self,
76        pc: u32,
77        inst: &Instruction<F>,
78        data: &mut [u8],
79    ) -> Result<Handler<F, Ctx>, StaticProgramError>
80    where
81        Ctx: ExecutionCtxTrait,
82    {
83        let data: &mut BaseAluPreCompute = data.borrow_mut();
84        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
85
86        dispatch!(execute_e1_handler, local_opcode)
87    }
88}
89
// With the `aot` feature enabled, opt this executor into `AotExecutor`. The impl body is empty,
// so no trait items are overridden here.
#[cfg(feature = "aot")]
impl<F> AotExecutor<F> for Rv32BaseAlu256Executor where F: PrimeField32 {}
92
93impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32BaseAlu256Executor {
94    fn metered_pre_compute_size(&self) -> usize {
95        size_of::<E2PreCompute<BaseAluPreCompute>>()
96    }
97
98    #[cfg(not(feature = "tco"))]
99    fn metered_pre_compute<Ctx>(
100        &self,
101        chip_idx: usize,
102        pc: u32,
103        inst: &Instruction<F>,
104        data: &mut [u8],
105    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
106    where
107        Ctx: MeteredExecutionCtxTrait,
108    {
109        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
110        data.chip_idx = chip_idx as u32;
111        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
112
113        dispatch!(execute_e2_handler, local_opcode)
114    }
115
116    #[cfg(feature = "tco")]
117    fn metered_handler<Ctx>(
118        &self,
119        chip_idx: usize,
120        pc: u32,
121        inst: &Instruction<F>,
122        data: &mut [u8],
123    ) -> Result<Handler<F, Ctx>, StaticProgramError>
124    where
125        Ctx: MeteredExecutionCtxTrait,
126    {
127        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
128        data.chip_idx = chip_idx as u32;
129        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
130
131        dispatch!(execute_e2_handler, local_opcode)
132    }
133}
// With the `aot` feature enabled, opt this executor into `AotMeteredExecutor`. The impl body is
// empty, so no trait items are overridden here.
#[cfg(feature = "aot")]
impl<F> AotMeteredExecutor<F> for Rv32BaseAlu256Executor where F: PrimeField32 {}
136
137#[inline(always)]
138unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
139    pre_compute: &BaseAluPreCompute,
140    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
141) {
142    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
143    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
144    let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
145    let rs1 =
146        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
147    let rs2 =
148        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
149    let rd = <OP as AluOp>::compute(rs1, rs2);
150    exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
151    let pc = exec_state.pc();
152    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
153}
154
155#[create_handler]
156#[inline(always)]
157unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
158    pre_compute: *const u8,
159    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
160) {
161    let pre_compute: &BaseAluPreCompute =
162        std::slice::from_raw_parts(pre_compute, size_of::<BaseAluPreCompute>()).borrow();
163    execute_e12_impl::<F, CTX, OP>(pre_compute, exec_state);
164}
165
166#[create_handler]
167#[inline(always)]
168unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: AluOp>(
169    pre_compute: *const u8,
170    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
171) {
172    let pre_compute: &E2PreCompute<BaseAluPreCompute> =
173        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<BaseAluPreCompute>>())
174            .borrow();
175    exec_state
176        .ctx
177        .on_height_change(pre_compute.chip_idx as usize, 1);
178    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, exec_state);
179}
180
181impl Rv32BaseAlu256Executor {
182    fn pre_compute_impl<F: PrimeField32>(
183        &self,
184        pc: u32,
185        inst: &Instruction<F>,
186        data: &mut BaseAluPreCompute,
187    ) -> Result<BaseAluOpcode, StaticProgramError> {
188        let Instruction {
189            opcode,
190            a,
191            b,
192            c,
193            d,
194            e,
195            ..
196        } = inst;
197        let e_u32 = e.as_canonical_u32();
198        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
199            return Err(StaticProgramError::InvalidInstruction(pc));
200        }
201        *data = BaseAluPreCompute {
202            a: a.as_canonical_u32() as u8,
203            b: b.as_canonical_u32() as u8,
204            c: c.as_canonical_u32() as u8,
205        };
206        let local_opcode =
207            BaseAluOpcode::from_usize(opcode.local_opcode_idx(Rv32BaseAlu256Opcode::CLASS_OFFSET));
208        Ok(local_opcode)
209    }
210}
211
212trait AluOp {
213    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
214}
/// Marker selecting 256-bit addition.
struct AddOp;
/// Marker selecting 256-bit subtraction.
struct SubOp;
/// Marker selecting bitwise XOR.
struct XorOp;
/// Marker selecting bitwise OR.
struct OrOp;
/// Marker selecting bitwise AND.
struct AndOp;
220impl AluOp for AddOp {
221    #[inline(always)]
222    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
223        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
224        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
225        let mut rd_u64 = [0u64; 4];
226        let (res, mut carry) = rs1_u64[0].overflowing_add(rs2_u64[0]);
227        rd_u64[0] = res;
228        for i in 1..4 {
229            let (res1, c1) = rs1_u64[i].overflowing_add(rs2_u64[i]);
230            let (res2, c2) = res1.overflowing_add(carry as u64);
231            carry = c1 || c2;
232            rd_u64[i] = res2;
233        }
234        u64_array_to_bytes(rd_u64)
235    }
236}
237impl AluOp for SubOp {
238    #[inline(always)]
239    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
240        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
241        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
242        let mut rd_u64 = [0u64; 4];
243        let (res, mut borrow) = rs1_u64[0].overflowing_sub(rs2_u64[0]);
244        rd_u64[0] = res;
245        for i in 1..4 {
246            let (res1, c1) = rs1_u64[i].overflowing_sub(rs2_u64[i]);
247            let (res2, c2) = res1.overflowing_sub(borrow as u64);
248            borrow = c1 || c2;
249            rd_u64[i] = res2;
250        }
251        u64_array_to_bytes(rd_u64)
252    }
253}
254impl AluOp for XorOp {
255    #[inline(always)]
256    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
257        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
258        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
259        let mut rd_u64 = [0u64; 4];
260        // Compiler will expand this loop.
261        for i in 0..4 {
262            rd_u64[i] = rs1_u64[i] ^ rs2_u64[i];
263        }
264        u64_array_to_bytes(rd_u64)
265    }
266}
267impl AluOp for OrOp {
268    #[inline(always)]
269    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
270        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
271        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
272        let mut rd_u64 = [0u64; 4];
273        // Compiler will expand this loop.
274        for i in 0..4 {
275            rd_u64[i] = rs1_u64[i] | rs2_u64[i];
276        }
277        u64_array_to_bytes(rd_u64)
278    }
279}
280impl AluOp for AndOp {
281    #[inline(always)]
282    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
283        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
284        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
285        let mut rd_u64 = [0u64; 4];
286        // Compiler will expand this loop.
287        for i in 0..4 {
288            rd_u64[i] = rs1_u64[i] & rs2_u64[i];
289        }
290        u64_array_to_bytes(rd_u64)
291    }
292}