1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use openvm_bigint_transpiler::Rv32BaseAlu256Opcode;
7use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
8use openvm_circuit_primitives_derive::AlignedBytesBorrow;
9use openvm_instructions::{
10 instruction::Instruction,
11 program::DEFAULT_PC_STEP,
12 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
13 LocalOpcode,
14};
15use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
16use openvm_rv32im_circuit::BaseAluExecutor;
17use openvm_rv32im_transpiler::BaseAluOpcode;
18use openvm_stark_backend::p3_field::PrimeField32;
19
20use crate::{
21 common::{bytes_to_u64_array, u64_array_to_bytes},
22 Rv32BaseAlu256Executor, INT256_NUM_LIMBS,
23};
24
/// Adapter executor type accepted by `Rv32BaseAlu256Executor::new`.
///
/// This is `Rv32HeapAdapterExecutor` instantiated with the const arguments `2`,
/// `INT256_NUM_LIMBS`, `INT256_NUM_LIMBS`.
// NOTE(review): presumably "2 operand reads, read width, write width" — confirm against the
// adapter's definition, which is not visible in this file.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
26
27impl Rv32BaseAlu256Executor {
28 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
29 Self(BaseAluExecutor::new(adapter, offset))
30 }
31}
32
/// Per-instruction data filled in once by `pre_compute_impl` and read back by the execute
/// functions.
///
/// Each field is the instruction operand of the same name truncated to `u8`. At execution
/// time it is used as an address in `RV32_REGISTER_AS` from which a 4-byte little-endian
/// pointer into `RV32_MEMORY_AS` is read.
#[derive(AlignedBytesBorrow)]
struct BaseAluPreCompute {
    /// Register whose value is the address the result is written to.
    a: u8,
    /// Register whose value is the address of the first operand.
    b: u8,
    /// Register whose value is the address of the second operand.
    c: u8,
}
39
/// Selects the instantiation of an execute function that matches a decoded `BaseAluOpcode`.
///
/// * `$execute_impl` — identifier of a function with three generic parameters; the first two
///   are left to inference (`_`), the third is set to the `AluOp` marker type for the opcode.
/// * `$local_opcode` — identifier bound to the `BaseAluOpcode` value to match on.
///
/// The expansion is an `Ok(..)` expression, so it is meant to be used as the tail expression
/// of a function returning `Result`.
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident) => {
        // No wildcard arm: adding a variant to `BaseAluOpcode` makes this fail to compile
        // until a matching `AluOp` marker is wired in here.
        Ok(match $local_opcode {
            BaseAluOpcode::ADD => $execute_impl::<_, _, AddOp>,
            BaseAluOpcode::SUB => $execute_impl::<_, _, SubOp>,
            BaseAluOpcode::XOR => $execute_impl::<_, _, XorOp>,
            BaseAluOpcode::OR => $execute_impl::<_, _, OrOp>,
            BaseAluOpcode::AND => $execute_impl::<_, _, AndOp>,
        })
    };
}
51
52impl<F: PrimeField32> Executor<F> for Rv32BaseAlu256Executor {
53 fn pre_compute_size(&self) -> usize {
54 size_of::<BaseAluPreCompute>()
55 }
56
57 #[cfg(not(feature = "tco"))]
58 fn pre_compute<Ctx>(
59 &self,
60 pc: u32,
61 inst: &Instruction<F>,
62 data: &mut [u8],
63 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
64 where
65 Ctx: ExecutionCtxTrait,
66 {
67 let data: &mut BaseAluPreCompute = data.borrow_mut();
68 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
69
70 dispatch!(execute_e1_handler, local_opcode)
71 }
72
73 #[cfg(feature = "tco")]
74 fn handler<Ctx>(
75 &self,
76 pc: u32,
77 inst: &Instruction<F>,
78 data: &mut [u8],
79 ) -> Result<Handler<F, Ctx>, StaticProgramError>
80 where
81 Ctx: ExecutionCtxTrait,
82 {
83 let data: &mut BaseAluPreCompute = data.borrow_mut();
84 let local_opcode = self.pre_compute_impl(pc, inst, data)?;
85
86 dispatch!(execute_e1_handler, local_opcode)
87 }
88}
89
90impl<F: PrimeField32> MeteredExecutor<F> for Rv32BaseAlu256Executor {
91 fn metered_pre_compute_size(&self) -> usize {
92 size_of::<E2PreCompute<BaseAluPreCompute>>()
93 }
94
95 #[cfg(not(feature = "tco"))]
96 fn metered_pre_compute<Ctx>(
97 &self,
98 chip_idx: usize,
99 pc: u32,
100 inst: &Instruction<F>,
101 data: &mut [u8],
102 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
103 where
104 Ctx: MeteredExecutionCtxTrait,
105 {
106 let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
107 data.chip_idx = chip_idx as u32;
108 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
109
110 dispatch!(execute_e2_handler, local_opcode)
111 }
112
113 #[cfg(feature = "tco")]
114 fn metered_handler<Ctx>(
115 &self,
116 chip_idx: usize,
117 pc: u32,
118 inst: &Instruction<F>,
119 data: &mut [u8],
120 ) -> Result<Handler<F, Ctx>, StaticProgramError>
121 where
122 Ctx: MeteredExecutionCtxTrait,
123 {
124 let data: &mut E2PreCompute<BaseAluPreCompute> = data.borrow_mut();
125 data.chip_idx = chip_idx as u32;
126 let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?;
127
128 dispatch!(execute_e2_handler, local_opcode)
129 }
130}
131
132#[inline(always)]
133unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
134 pre_compute: &BaseAluPreCompute,
135 instret: &mut u64,
136 pc: &mut u32,
137 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
138) {
139 let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
140 let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
141 let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
142 let rs1 =
143 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
144 let rs2 =
145 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
146 let rd = <OP as AluOp>::compute(rs1, rs2);
147 exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
148 *pc = pc.wrapping_add(DEFAULT_PC_STEP);
149 *instret += 1;
150}
151
152#[create_handler]
153#[inline(always)]
154unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
155 pre_compute: &[u8],
156 instret: &mut u64,
157 pc: &mut u32,
158 _instret_end: u64,
159 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
160) {
161 let pre_compute: &BaseAluPreCompute = pre_compute.borrow();
162 execute_e12_impl::<F, CTX, OP>(pre_compute, instret, pc, exec_state);
163}
164
165#[create_handler]
166#[inline(always)]
167unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait, OP: AluOp>(
168 pre_compute: &[u8],
169 instret: &mut u64,
170 pc: &mut u32,
171 _arg: u64,
172 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
173) {
174 let pre_compute: &E2PreCompute<BaseAluPreCompute> = pre_compute.borrow();
175 exec_state
176 .ctx
177 .on_height_change(pre_compute.chip_idx as usize, 1);
178 execute_e12_impl::<F, CTX, OP>(&pre_compute.data, instret, pc, exec_state);
179}
180
181impl Rv32BaseAlu256Executor {
182 fn pre_compute_impl<F: PrimeField32>(
183 &self,
184 pc: u32,
185 inst: &Instruction<F>,
186 data: &mut BaseAluPreCompute,
187 ) -> Result<BaseAluOpcode, StaticProgramError> {
188 let Instruction {
189 opcode,
190 a,
191 b,
192 c,
193 d,
194 e,
195 ..
196 } = inst;
197 let e_u32 = e.as_canonical_u32();
198 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
199 return Err(StaticProgramError::InvalidInstruction(pc));
200 }
201 *data = BaseAluPreCompute {
202 a: a.as_canonical_u32() as u8,
203 b: b.as_canonical_u32() as u8,
204 c: c.as_canonical_u32() as u8,
205 };
206 let local_opcode =
207 BaseAluOpcode::from_usize(opcode.local_opcode_idx(Rv32BaseAlu256Opcode::CLASS_OFFSET));
208 Ok(local_opcode)
209 }
210}
211
/// A binary operation on two `INT256_NUM_LIMBS`-byte operands, chosen at compile time.
///
/// The execute functions in this file take an `OP: AluOp` type parameter and call
/// `OP::compute`, so each opcode gets its own monomorphized code path.
trait AluOp {
    /// Combines `rs1` and `rs2` and returns the result as an `INT256_NUM_LIMBS`-byte array.
    fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS];
}
/// Marker type for addition; selected by `dispatch!` for `BaseAluOpcode::ADD`.
struct AddOp;
/// Marker type for subtraction; selected by `dispatch!` for `BaseAluOpcode::SUB`.
struct SubOp;
/// Marker type for bitwise XOR; selected by `dispatch!` for `BaseAluOpcode::XOR`.
struct XorOp;
/// Marker type for bitwise OR; selected by `dispatch!` for `BaseAluOpcode::OR`.
struct OrOp;
/// Marker type for bitwise AND; selected by `dispatch!` for `BaseAluOpcode::AND`.
struct AndOp;
220impl AluOp for AddOp {
221 #[inline(always)]
222 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
223 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
224 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
225 let mut rd_u64 = [0u64; 4];
226 let (res, mut carry) = rs1_u64[0].overflowing_add(rs2_u64[0]);
227 rd_u64[0] = res;
228 for i in 1..4 {
229 let (res1, c1) = rs1_u64[i].overflowing_add(rs2_u64[i]);
230 let (res2, c2) = res1.overflowing_add(carry as u64);
231 carry = c1 || c2;
232 rd_u64[i] = res2;
233 }
234 u64_array_to_bytes(rd_u64)
235 }
236}
237impl AluOp for SubOp {
238 #[inline(always)]
239 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
240 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
241 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
242 let mut rd_u64 = [0u64; 4];
243 let (res, mut borrow) = rs1_u64[0].overflowing_sub(rs2_u64[0]);
244 rd_u64[0] = res;
245 for i in 1..4 {
246 let (res1, c1) = rs1_u64[i].overflowing_sub(rs2_u64[i]);
247 let (res2, c2) = res1.overflowing_sub(borrow as u64);
248 borrow = c1 || c2;
249 rd_u64[i] = res2;
250 }
251 u64_array_to_bytes(rd_u64)
252 }
253}
254impl AluOp for XorOp {
255 #[inline(always)]
256 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
257 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
258 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
259 let mut rd_u64 = [0u64; 4];
260 for i in 0..4 {
262 rd_u64[i] = rs1_u64[i] ^ rs2_u64[i];
263 }
264 u64_array_to_bytes(rd_u64)
265 }
266}
267impl AluOp for OrOp {
268 #[inline(always)]
269 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
270 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
271 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
272 let mut rd_u64 = [0u64; 4];
273 for i in 0..4 {
275 rd_u64[i] = rs1_u64[i] | rs2_u64[i];
276 }
277 u64_array_to_bytes(rd_u64)
278 }
279}
280impl AluOp for AndOp {
281 #[inline(always)]
282 fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] {
283 let rs1_u64: [u64; 4] = bytes_to_u64_array(rs1);
284 let rs2_u64: [u64; 4] = bytes_to_u64_array(rs2);
285 let mut rd_u64 = [0u64; 4];
286 for i in 0..4 {
288 rd_u64[i] = rs1_u64[i] & rs2_u64[i];
289 }
290 u64_array_to_bytes(rd_u64)
291 }
292}