1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32Mul256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7 instruction::Instruction,
8 program::DEFAULT_PC_STEP,
9 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10 LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::MultiplicationExecutor;
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18 common::{bytes_to_u32_array, u32_array_to_bytes},
19 Rv32Multiplication256Executor, INT256_NUM_LIMBS,
20};
21
/// Adapter executor type accepted by `Rv32Multiplication256Executor::new`.
// NOTE(review): the const arguments presumably mean "2 heap reads, with read and write blocks of
// `INT256_NUM_LIMBS` bytes" — confirm against the `Rv32HeapAdapterExecutor` definition.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32Multiplication256Executor {
25 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26 Self(MultiplicationExecutor::new(adapter, offset))
27 }
28}
29
/// Operands of one multiply instruction, decoded once by `pre_compute_impl` and
/// re-read from the pre-compute buffer on every execution.
///
/// Each field is the corresponding instruction operand narrowed to a byte and is
/// used as a pointer into the register address space.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultPreCompute {
    /// Operand `a`: register whose value is the memory address the product is written to.
    a: u8,
    /// Operand `b`: register whose value is the memory address of the first factor.
    b: u8,
    /// Operand `c`: register whose value is the memory address of the second factor.
    c: u8,
}
37
38impl<F: PrimeField32> InterpreterExecutor<F> for Rv32Multiplication256Executor {
39 fn pre_compute_size(&self) -> usize {
40 size_of::<MultPreCompute>()
41 }
42
43 #[cfg(not(feature = "tco"))]
44 fn pre_compute<Ctx>(
45 &self,
46 pc: u32,
47 inst: &Instruction<F>,
48 data: &mut [u8],
49 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
50 where
51 Ctx: ExecutionCtxTrait,
52 {
53 let data: &mut MultPreCompute = data.borrow_mut();
54 self.pre_compute_impl(pc, inst, data)?;
55 Ok(execute_e1_impl)
56 }
57
58 #[cfg(feature = "tco")]
59 fn handler<Ctx>(
60 &self,
61 pc: u32,
62 inst: &Instruction<F>,
63 data: &mut [u8],
64 ) -> Result<Handler<F, Ctx>, StaticProgramError>
65 where
66 Ctx: ExecutionCtxTrait,
67 {
68 let data: &mut MultPreCompute = data.borrow_mut();
69 self.pre_compute_impl(pc, inst, data)?;
70 Ok(execute_e1_handler)
71 }
72}
73
// Empty impl: no `AotExecutor` items are overridden for this executor.
#[cfg(feature = "aot")]
impl<F> AotExecutor<F> for Rv32Multiplication256Executor where F: PrimeField32 {}
76
77impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32Multiplication256Executor {
78 fn metered_pre_compute_size(&self) -> usize {
79 size_of::<E2PreCompute<MultPreCompute>>()
80 }
81
82 #[cfg(not(feature = "tco"))]
83 fn metered_pre_compute<Ctx>(
84 &self,
85 chip_idx: usize,
86 pc: u32,
87 inst: &Instruction<F>,
88 data: &mut [u8],
89 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
90 where
91 Ctx: MeteredExecutionCtxTrait,
92 {
93 let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
94 data.chip_idx = chip_idx as u32;
95 self.pre_compute_impl(pc, inst, &mut data.data)?;
96 Ok(execute_e2_impl)
97 }
98
99 #[cfg(feature = "tco")]
100 fn metered_handler<Ctx>(
101 &self,
102 chip_idx: usize,
103 pc: u32,
104 inst: &Instruction<F>,
105 data: &mut [u8],
106 ) -> Result<Handler<F, Ctx>, StaticProgramError>
107 where
108 Ctx: MeteredExecutionCtxTrait,
109 {
110 let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
111 data.chip_idx = chip_idx as u32;
112 self.pre_compute_impl(pc, inst, &mut data.data)?;
113 Ok(execute_e2_handler)
114 }
115}
116
// Empty impl: no `AotMeteredExecutor` items are overridden for this executor.
#[cfg(feature = "aot")]
impl<F> AotMeteredExecutor<F> for Rv32Multiplication256Executor where F: PrimeField32 {}
119
120#[inline(always)]
121unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
122 pre_compute: &MultPreCompute,
123 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
124) {
125 let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
126 let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
127 let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
128 let rs1 =
129 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
130 let rs2 =
131 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
132 let rd = u256_mul(rs1, rs2);
133 exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
134
135 let pc = exec_state.pc();
136 exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
137}
138
139#[create_handler]
140#[inline(always)]
141unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
142 pre_compute: *const u8,
143 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
144) {
145 let pre_compute: &MultPreCompute =
146 std::slice::from_raw_parts(pre_compute, size_of::<MultPreCompute>()).borrow();
147 execute_e12_impl(pre_compute, exec_state);
148}
149
150#[create_handler]
151#[inline(always)]
152unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
153 pre_compute: *const u8,
154 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
155) {
156 let pre_compute: &E2PreCompute<MultPreCompute> =
157 std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<MultPreCompute>>()).borrow();
158 exec_state
159 .ctx
160 .on_height_change(pre_compute.chip_idx as usize, 1);
161 execute_e12_impl(&pre_compute.data, exec_state);
162}
163
164impl Rv32Multiplication256Executor {
165 fn pre_compute_impl<F: PrimeField32>(
166 &self,
167 pc: u32,
168 inst: &Instruction<F>,
169 data: &mut MultPreCompute,
170 ) -> Result<(), StaticProgramError> {
171 let Instruction {
172 opcode,
173 a,
174 b,
175 c,
176 d,
177 e,
178 ..
179 } = inst;
180 let e_u32 = e.as_canonical_u32();
181 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
182 return Err(StaticProgramError::InvalidInstruction(pc));
183 }
184 let local_opcode =
185 MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET));
186 assert_eq!(local_opcode, MulOpcode::MUL);
187 *data = MultPreCompute {
188 a: a.as_canonical_u32() as u8,
189 b: b.as_canonical_u32() as u8,
190 c: c.as_canonical_u32() as u8,
191 };
192 Ok(())
193 }
194}
195
196#[inline(always)]
197pub(crate) fn u256_mul(
198 rs1: [u8; INT256_NUM_LIMBS],
199 rs2: [u8; INT256_NUM_LIMBS],
200) -> [u8; INT256_NUM_LIMBS] {
201 let rs1_u64: [u32; 8] = bytes_to_u32_array(rs1);
202 let rs2_u64: [u32; 8] = bytes_to_u32_array(rs2);
203 let mut rd = [0u32; 8];
204 for i in 0..8 {
205 let mut carry = 0u64;
206 for j in 0..(8 - i) {
207 let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry;
208 rd[i + j] = res as u32;
209 carry = res >> 32;
210 }
211 }
212 u32_array_to_bytes(rd)
213}
214
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{common::u64_array_to_bytes, mult::u256_mul, INT256_NUM_LIMBS};

    /// Cross-checks `u256_mul` against `U256::wrapping_mul` on seeded random operand pairs.
    #[test]
    fn test_u256_mul() {
        const ROUNDS: usize = 10000;

        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..ROUNDS {
            // Draw order (first operand, then second) is kept so the seeded sequence is unchanged.
            let lhs_limbs: [u64; 4] = rng.gen();
            let rhs_limbs: [u64; 4] = rng.gen();

            let expected = U256::from_limbs(lhs_limbs).wrapping_mul(U256::from_limbs(rhs_limbs));

            let lhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(lhs_limbs);
            let rhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(rhs_limbs);
            let actual = U256::from_le_bytes(u256_mul(lhs_bytes, rhs_bytes));

            assert_eq!(actual, expected);
        }
    }
}