openvm_bigint_circuit/
base_alu.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32BaseAlu256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7    instruction::Instruction,
8    program::DEFAULT_PC_STEP,
9    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10    LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::BaseAluExecutor;
14use openvm_rv32im_transpiler::BaseAluOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18    common::{bytes_to_u64_array, u64_array_to_bytes},
19    Rv32BaseAlu256Executor, INT256_NUM_LIMBS,
20};
21
/// Adapter type wrapped by the 256-bit base ALU executor: `Rv32HeapAdapterExecutor`
/// instantiated with `2` and with `INT256_NUM_LIMBS` for both remaining const parameters.
// NOTE(review): presumably `2` is the number of heap operands read and the two
// `INT256_NUM_LIMBS` are the read/write widths in bytes — confirm against the adapter.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32BaseAlu256Executor {
25    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26        Self(BaseAluExecutor::new(adapter, offset))
27    }
28}
29
30#[derive(AlignedBytesBorrow)]
31struct BaseAluPreCompute {
32    a: u8,
33    b: u8,
34    c: u8,
35}
36
// Selects the monomorphized execution function for a decoded `BaseAluOpcode`.
//
// `$execute_impl` is the name of a generic function whose third type parameter is the
// `AluOp` implementation; `$local_opcode` is an identifier bound to a `BaseAluOpcode`.
// The first two generic arguments are left as `_` and are inferred from the return type
// expected at the call site. The whole expansion is an `Ok(..)` expression, so the macro
// is meant to be the tail expression of a function returning `Result`.
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        Ok(match $local_opcode {
            BaseAluOpcode::ADD => $execute_impl::<_, _, AddOp>,
            BaseAluOpcode::SUB => $execute_impl::<_, _, SubOp>,
            BaseAluOpcode::XOR => $execute_impl::<_, _, XorOp>,
            BaseAluOpcode::OR => $execute_impl::<_, _, OrOp>,
            BaseAluOpcode::AND => $execute_impl::<_, _, AndOp>,
        })
    };
}
48
49impl<F: PrimeField32> Executor<F> for Rv32BaseAlu256Executor {
50    fn pre_compute_size(&self) -> usize {
51        size_of::<BaseAluPreCompute>()
52    }
53
54    fn pre_compute<Ctx>(
55        &self,
56        pc: u32,
57        inst: &Instruction<F>,
58        data: &mut [u8],
59    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
60    where
61        Ctx: ExecutionCtxTrait,
62    {
63        let data: &mut BaseAluPreCompute = data.borrow_mut();
64        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
65
66        dispatch!(execute_e1_impl, local_opcode)
67    }
68
69    #[cfg(feature = "tco")]
70    fn handler<Ctx>(
71        &self,
72        pc: u32,
73        inst: &Instruction<F>,
74        data: &mut [u8],
75    ) -> Result<Handler<F, Ctx>, StaticProgramError>
76    where
77        Ctx: ExecutionCtxTrait,
78    {
79        let data: &mut BaseAluPreCompute = data.borrow_mut();
80        let local_opcode = self.pre_compute_impl(pc, inst, data)?;
81
82        dispatch!(execute_e1_tco_handler, local_opcode)
83    }
84}
85
86impl<F: PrimeField32> MeteredExecutor<F> for Rv32BaseAlu256Executor {
87    fn metered_pre_compute_size(&self) -> usize {
88        size_of::<E2PreCompute<BaseAluPreCompute>>()
89    }
90
91    fn metered_pre_compute<Ctx>(
92        &self,
93        chip_idx: usize,
94        pc: u32,
95        inst: &Instruction<F>,
96        data: &mut [u8],
97    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
98    where
99        Ctx: MeteredExecutionCtxTrait,
100    {
101        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
102        data.chip_idx = chip_idx as u32;
103        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
104
105        dispatch!(execute_e2_impl, local_opcode)
106    }
107
108    #[cfg(feature = "tco")]
109    fn metered_handler<Ctx>(
110        &self,
111        chip_idx: usize,
112        pc: u32,
113        inst: &Instruction<F>,
114        data: &mut [u8],
115    ) -> Result<Handler<F, Ctx>, StaticProgramError>
116    where
117        Ctx: MeteredExecutionCtxTrait,
118    {
119        let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
120        data.chip_idx = chip_idx as u32;
121        let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
122
123        dispatch!(execute_e2_tco_handler, local_opcode)
124    }
125}
126
127#[inline(always)]
128unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
129    pre_compute: &BaseAluPreCompute,
130    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
131) {
132    let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
133    let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
134    let rd_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
135    let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
136    let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
137    let rd = <OP as AluOp>::compute(rs1, rs2);
138    vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
139    vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP);
140    vm_state.instret += 1;
141}
142
// Plain (non-metered) execution entry point for one 256-bit ALU instruction.
//
// `pre_compute` is the raw pre-compute buffer; it is viewed as a `BaseAluPreCompute` and
// forwarded, together with `vm_state`, to `execute_e12_impl` with the same `OP`.
// The `#[create_tco_handler]` attribute is applied to this function; the `dispatch!` call
// in `Executor::handler` refers to `execute_e1_tco_handler`, presumably generated by it.
#[create_tco_handler]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Reinterpret the byte buffer as the typed pre-compute record.
    let pre_compute: &BaseAluPreCompute = pre_compute.borrow();
    execute_e12_impl::<F, CTX, OP>(pre_compute, vm_state);
}
151
// Metered execution entry point for one 256-bit ALU instruction.
//
// `pre_compute` is the raw pre-compute buffer; it is viewed as an
// `E2PreCompute<BaseAluPreCompute>`. Before executing, the context is notified of a
// height change of 1 for the chip identified by `chip_idx`; the instruction itself is
// then run by `execute_e12_impl` with the same `OP`.
// The `#[create_tco_handler]` attribute is applied to this function; the `dispatch!` call
// in `MeteredExecutor::metered_handler` refers to `execute_e2_tco_handler`, presumably
// generated by it.
#[create_tco_handler]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: AluOp>(
    pre_compute: &[u8],
    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Reinterpret the byte buffer as the typed metered pre-compute record.
    let pre_compute: &E2PreCompute<BaseAluPreCompute> = pre_compute.borrow();
    // Metering: account for one additional row on this chip.
    vm_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl::<F, CTX, OP>(&pre_compute.data, vm_state);
}
163
164impl Rv32BaseAlu256Executor {
165    fn pre_compute_impl<F: PrimeField32>(
166        &self,
167        pc: u32,
168        inst: &Instruction<F>,
169        data: &mut BaseAluPreCompute,
170    ) -> Result<BaseAluOpcode, StaticProgramError> {
171        let Instruction {
172            opcode,
173            a,
174            b,
175            c,
176            d,
177            e,
178            ..
179        } = inst;
180        let e_u32 = e.as_canonical_u32();
181        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
182            return Err(StaticProgramError::InvalidInstruction(pc));
183        }
184        *data = BaseAluPreCompute {
185            a: a.as_canonical_u32() as u8,
186            b: b.as_canonical_u32() as u8,
187            c: c.as_canonical_u32() as u8,
188        };
189        let local_opcode =
190            BaseAluOpcode::from_usize(opcode.local_opcode_idx(Rv32BaseAlu256Opcode::CLASS_OFFSET));
191        Ok(local_opcode)
192    }
193}
194
/// A binary operation on two 256-bit operands given as `INT256_NUM_LIMBS`-byte arrays.
///
/// Implementors are zero-sized marker types so the operation can be chosen at compile
/// time through a generic parameter (see the `dispatch!` macro).
trait AluOp {
    /// Combines `rs1` and `rs2` and returns the result in the same byte-array form.
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
}
/// Marker for addition (selected for `BaseAluOpcode::ADD`).
struct AddOp;
/// Marker for subtraction (selected for `BaseAluOpcode::SUB`).
struct SubOp;
/// Marker for bitwise XOR (selected for `BaseAluOpcode::XOR`).
struct XorOp;
/// Marker for bitwise OR (selected for `BaseAluOpcode::OR`).
struct OrOp;
/// Marker for bitwise AND (selected for `BaseAluOpcode::AND`).
struct AndOp;
203impl AluOp for AddOp {
204    #[inline(always)]
205    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
206        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
207        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
208        let mut rd_u64 = [0u64; 4];
209        let (res, mut carry) = rs1_u64[0].overflowing_add(rs2_u64[0]);
210        rd_u64[0] = res;
211        for i in 1..4 {
212            let (res1, c1) = rs1_u64[i].overflowing_add(rs2_u64[i]);
213            let (res2, c2) = res1.overflowing_add(carry as u64);
214            carry = c1 || c2;
215            rd_u64[i] = res2;
216        }
217        u64_array_to_bytes(rd_u64)
218    }
219}
220impl AluOp for SubOp {
221    #[inline(always)]
222    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
223        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
224        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
225        let mut rd_u64 = [0u64; 4];
226        let (res, mut borrow) = rs1_u64[0].overflowing_sub(rs2_u64[0]);
227        rd_u64[0] = res;
228        for i in 1..4 {
229            let (res1, c1) = rs1_u64[i].overflowing_sub(rs2_u64[i]);
230            let (res2, c2) = res1.overflowing_sub(borrow as u64);
231            borrow = c1 || c2;
232            rd_u64[i] = res2;
233        }
234        u64_array_to_bytes(rd_u64)
235    }
236}
237impl AluOp for XorOp {
238    #[inline(always)]
239    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
240        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
241        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
242        let mut rd_u64 = [0u64; 4];
243        // Compiler will expand this loop.
244        for i in 0..4 {
245            rd_u64[i] = rs1_u64[i] ^ rs2_u64[i];
246        }
247        u64_array_to_bytes(rd_u64)
248    }
249}
250impl AluOp for OrOp {
251    #[inline(always)]
252    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
253        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
254        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
255        let mut rd_u64 = [0u64; 4];
256        // Compiler will expand this loop.
257        for i in 0..4 {
258            rd_u64[i] = rs1_u64[i] | rs2_u64[i];
259        }
260        u64_array_to_bytes(rd_u64)
261    }
262}
263impl AluOp for AndOp {
264    #[inline(always)]
265    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
266        let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
267        let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
268        let mut rd_u64 = [0u64; 4];
269        // Compiler will expand this loop.
270        for i in 0..4 {
271            rd_u64[i] = rs1_u64[i] & rs2_u64[i];
272        }
273        u64_array_to_bytes(rd_u64)
274    }
275}