1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32Mul256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7 instruction::Instruction,
8 program::DEFAULT_PC_STEP,
9 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10 LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::MultiplicationExecutor;
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18 common::{bytes_to_u32_array, u32_array_to_bytes},
19 Rv32Multiplication256Executor, INT256_NUM_LIMBS,
20};
21
22type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32Multiplication256Executor {
25 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26 Self(MultiplicationExecutor::new(adapter, offset))
27 }
28}
29
/// Operands of a 256-bit multiply instruction, decoded once ahead of execution.
///
/// Each field stores an instruction operand truncated to `u8`. At execution time the
/// field is used as an address in the register address space from which a 4-byte
/// pointer into guest memory is read.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultPreCompute {
    /// Register address holding the pointer the product is written to.
    a: u8,
    /// Register address holding the pointer to the first factor.
    b: u8,
    /// Register address holding the pointer to the second factor.
    c: u8,
}
37
38impl<F: PrimeField32> Executor<F> for Rv32Multiplication256Executor {
39 fn pre_compute_size(&self) -> usize {
40 size_of::<MultPreCompute>()
41 }
42
43 #[cfg(not(feature = "tco"))]
44 fn pre_compute<Ctx>(
45 &self,
46 pc: u32,
47 inst: &Instruction<F>,
48 data: &mut [u8],
49 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
50 where
51 Ctx: ExecutionCtxTrait,
52 {
53 let data: &mut MultPreCompute = data.borrow_mut();
54 self.pre_compute_impl(pc, inst, data)?;
55 Ok(execute_e1_impl)
56 }
57
58 #[cfg(feature = "tco")]
59 fn handler<Ctx>(
60 &self,
61 pc: u32,
62 inst: &Instruction<F>,
63 data: &mut [u8],
64 ) -> Result<Handler<F, Ctx>, StaticProgramError>
65 where
66 Ctx: ExecutionCtxTrait,
67 {
68 let data: &mut MultPreCompute = data.borrow_mut();
69 self.pre_compute_impl(pc, inst, data)?;
70 Ok(execute_e1_handler)
71 }
72}
73
74impl<F: PrimeField32> MeteredExecutor<F> for Rv32Multiplication256Executor {
75 fn metered_pre_compute_size(&self) -> usize {
76 size_of::<E2PreCompute<MultPreCompute>>()
77 }
78
79 #[cfg(not(feature = "tco"))]
80 fn metered_pre_compute<Ctx>(
81 &self,
82 chip_idx: usize,
83 pc: u32,
84 inst: &Instruction<F>,
85 data: &mut [u8],
86 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
87 where
88 Ctx: MeteredExecutionCtxTrait,
89 {
90 let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
91 data.chip_idx = chip_idx as u32;
92 self.pre_compute_impl(pc, inst, &mut data.data)?;
93 Ok(execute_e2_impl)
94 }
95
96 #[cfg(feature = "tco")]
97 fn metered_handler<Ctx>(
98 &self,
99 chip_idx: usize,
100 pc: u32,
101 inst: &Instruction<F>,
102 data: &mut [u8],
103 ) -> Result<Handler<F, Ctx>, StaticProgramError>
104 where
105 Ctx: MeteredExecutionCtxTrait,
106 {
107 let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
108 data.chip_idx = chip_idx as u32;
109 self.pre_compute_impl(pc, inst, &mut data.data)?;
110 Ok(execute_e2_handler)
111 }
112}
113
114#[inline(always)]
115unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
116 pre_compute: &MultPreCompute,
117 instret: &mut u64,
118 pc: &mut u32,
119 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
120) {
121 let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
122 let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
123 let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
124 let rs1 =
125 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
126 let rs2 =
127 exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
128 let rd = u256_mul(rs1, rs2);
129 exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
130
131 *pc += DEFAULT_PC_STEP;
132 *instret += 1;
133}
134
135#[create_handler]
136#[inline(always)]
137unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
138 pre_compute: &[u8],
139 instret: &mut u64,
140 pc: &mut u32,
141 _instret_end: u64,
142 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
143) {
144 let pre_compute: &MultPreCompute = pre_compute.borrow();
145 execute_e12_impl(pre_compute, instret, pc, exec_state);
146}
147
148#[create_handler]
149#[inline(always)]
150unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
151 pre_compute: &[u8],
152 instret: &mut u64,
153 pc: &mut u32,
154 _arg: u64,
155 exec_state: &mut VmExecState<F, GuestMemory, CTX>,
156) {
157 let pre_compute: &E2PreCompute<MultPreCompute> = pre_compute.borrow();
158 exec_state
159 .ctx
160 .on_height_change(pre_compute.chip_idx as usize, 1);
161 execute_e12_impl(&pre_compute.data, instret, pc, exec_state);
162}
163
164impl Rv32Multiplication256Executor {
165 fn pre_compute_impl<F: PrimeField32>(
166 &self,
167 pc: u32,
168 inst: &Instruction<F>,
169 data: &mut MultPreCompute,
170 ) -> Result<(), StaticProgramError> {
171 let Instruction {
172 opcode,
173 a,
174 b,
175 c,
176 d,
177 e,
178 ..
179 } = inst;
180 let e_u32 = e.as_canonical_u32();
181 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
182 return Err(StaticProgramError::InvalidInstruction(pc));
183 }
184 let local_opcode =
185 MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET));
186 assert_eq!(local_opcode, MulOpcode::MUL);
187 *data = MultPreCompute {
188 a: a.as_canonical_u32() as u8,
189 b: b.as_canonical_u32() as u8,
190 c: c.as_canonical_u32() as u8,
191 };
192 Ok(())
193 }
194}
195
196#[inline(always)]
197pub(crate) fn u256_mul(
198 rs1: [u8; INT256_NUM_LIMBS],
199 rs2: [u8; INT256_NUM_LIMBS],
200) -> [u8; INT256_NUM_LIMBS] {
201 let rs1_u64: [u32; 8] = bytes_to_u32_array(rs1);
202 let rs2_u64: [u32; 8] = bytes_to_u32_array(rs2);
203 let mut rd = [0u32; 8];
204 for i in 0..8 {
205 let mut carry = 0u64;
206 for j in 0..(8 - i) {
207 let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry;
208 rd[i + j] = res as u32;
209 carry = res >> 32;
210 }
211 }
212 u32_array_to_bytes(rd)
213}
214
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{common::u64_array_to_bytes, mult::u256_mul, INT256_NUM_LIMBS};

    /// Number of random operand pairs compared against the reference implementation.
    const NUM_CASES: usize = 10000;

    /// Compares `u256_mul` with `U256::wrapping_mul` on operands drawn from a
    /// fixed-seed generator, so the run is reproducible.
    #[test]
    fn test_u256_mul() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..NUM_CASES {
            // Draw order (left operand first) is part of the reproducible sequence.
            let lhs_limbs: [u64; 4] = rng.gen();
            let rhs_limbs: [u64; 4] = rng.gen();

            let expected = U256::from_limbs(lhs_limbs).wrapping_mul(U256::from_limbs(rhs_limbs));

            let lhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(lhs_limbs);
            let rhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(rhs_limbs);
            let actual = U256::from_le_bytes(u256_mul(lhs_bytes, rhs_bytes));

            assert_eq!(actual, expected);
        }
    }
}