1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32Mul256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7 instruction::Instruction,
8 program::DEFAULT_PC_STEP,
9 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10 LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::MultiplicationExecutor;
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18 common::{bytes_to_u32_array, u32_array_to_bytes},
19 Rv32Multiplication256Executor, INT256_NUM_LIMBS,
20};
21
/// Adapter executor wrapped by [`Rv32Multiplication256Executor`].
///
/// NOTE(review): the generic arguments `<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>` presumably mean
/// two heap reads of `INT256_NUM_LIMBS` bytes and one `INT256_NUM_LIMBS`-byte write; confirm
/// against the `Rv32HeapAdapterExecutor` definition.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32Multiplication256Executor {
25 pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26 Self(MultiplicationExecutor::new(adapter, offset))
27 }
28}
29
/// Operands of a 256-bit `MUL` instruction, decoded once before execution.
///
/// Each field is the corresponding instruction operand truncated to `u8` (see
/// `pre_compute_impl`). The execute path uses each one as an address in `RV32_REGISTER_AS`. The
/// 4-byte word stored there is read as a little-endian pointer into `RV32_MEMORY_AS`.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultPreCompute {
    /// Register address whose word points at the destination (`rd`).
    a: u8,
    /// Register address whose word points at the first source (`rs1`).
    b: u8,
    /// Register address whose word points at the second source (`rs2`).
    c: u8,
}
37
38impl<F: PrimeField32> Executor<F> for Rv32Multiplication256Executor {
39 fn pre_compute_size(&self) -> usize {
40 size_of::<MultPreCompute>()
41 }
42
43 fn pre_compute<Ctx>(
44 &self,
45 pc: u32,
46 inst: &Instruction<F>,
47 data: &mut [u8],
48 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
49 where
50 Ctx: ExecutionCtxTrait,
51 {
52 let data: &mut MultPreCompute = data.borrow_mut();
53 self.pre_compute_impl(pc, inst, data)?;
54 Ok(execute_e1_impl)
55 }
56
57 #[cfg(feature = "tco")]
58 fn handler<Ctx>(
59 &self,
60 pc: u32,
61 inst: &Instruction<F>,
62 data: &mut [u8],
63 ) -> Result<Handler<F, Ctx>, StaticProgramError>
64 where
65 Ctx: ExecutionCtxTrait,
66 {
67 let data: &mut MultPreCompute = data.borrow_mut();
68 self.pre_compute_impl(pc, inst, data)?;
69 Ok(execute_e1_tco_handler)
70 }
71}
72
73impl<F: PrimeField32> MeteredExecutor<F> for Rv32Multiplication256Executor {
74 fn metered_pre_compute_size(&self) -> usize {
75 size_of::<E2PreCompute<MultPreCompute>>()
76 }
77
78 fn metered_pre_compute<Ctx>(
79 &self,
80 chip_idx: usize,
81 pc: u32,
82 inst: &Instruction<F>,
83 data: &mut [u8],
84 ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
85 where
86 Ctx: MeteredExecutionCtxTrait,
87 {
88 let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
89 data.chip_idx = chip_idx as u32;
90 self.pre_compute_impl(pc, inst, &mut data.data)?;
91 Ok(execute_e2_impl)
92 }
93
94 #[cfg(feature = "tco")]
95 fn metered_handler<Ctx>(
96 &self,
97 chip_idx: usize,
98 pc: u32,
99 inst: &Instruction<F>,
100 data: &mut [u8],
101 ) -> Result<Handler<F, Ctx>, StaticProgramError>
102 where
103 Ctx: MeteredExecutionCtxTrait,
104 {
105 let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
106 data.chip_idx = chip_idx as u32;
107 self.pre_compute_impl(pc, inst, &mut data.data)?;
108 Ok(execute_e2_tco_handler)
109 }
110}
111
112#[inline(always)]
113unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
114 pre_compute: &MultPreCompute,
115 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
116) {
117 let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
118 let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
119 let rd_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
120 let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
121 let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
122 let rd = u256_mul(rs1, rs2);
123 vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
124
125 vm_state.pc += DEFAULT_PC_STEP;
126 vm_state.instret += 1;
127}
128
129#[create_tco_handler]
130unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
131 pre_compute: &[u8],
132 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
133) {
134 let pre_compute: &MultPreCompute = pre_compute.borrow();
135 execute_e12_impl(pre_compute, vm_state);
136}
137
138#[create_tco_handler]
139unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
140 pre_compute: &[u8],
141 vm_state: &mut VmExecState<F, GuestMemory, CTX>,
142) {
143 let pre_compute: &E2PreCompute<MultPreCompute> = pre_compute.borrow();
144 vm_state
145 .ctx
146 .on_height_change(pre_compute.chip_idx as usize, 1);
147 execute_e12_impl(&pre_compute.data, vm_state);
148}
149
150impl Rv32Multiplication256Executor {
151 fn pre_compute_impl<F: PrimeField32>(
152 &self,
153 pc: u32,
154 inst: &Instruction<F>,
155 data: &mut MultPreCompute,
156 ) -> Result<(), StaticProgramError> {
157 let Instruction {
158 opcode,
159 a,
160 b,
161 c,
162 d,
163 e,
164 ..
165 } = inst;
166 let e_u32 = e.as_canonical_u32();
167 if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
168 return Err(StaticProgramError::InvalidInstruction(pc));
169 }
170 let local_opcode =
171 MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET));
172 assert_eq!(local_opcode, MulOpcode::MUL);
173 *data = MultPreCompute {
174 a: a.as_canonical_u32() as u8,
175 b: b.as_canonical_u32() as u8,
176 c: c.as_canonical_u32() as u8,
177 };
178 Ok(())
179 }
180}
181
182#[inline(always)]
183pub(crate) fn u256_mul(
184 rs1: [u8; INT256_NUM_LIMBS],
185 rs2: [u8; INT256_NUM_LIMBS],
186) -> [u8; INT256_NUM_LIMBS] {
187 let rs1_u64: [u32; 8] = bytes_to_u32_array(rs1);
188 let rs2_u64: [u32; 8] = bytes_to_u32_array(rs2);
189 let mut rd = [0u32; 8];
190 for i in 0..8 {
191 let mut carry = 0u64;
192 for j in 0..(8 - i) {
193 let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry;
194 rd[i + j] = res as u32;
195 carry = res >> 32;
196 }
197 }
198 u32_array_to_bytes(rd)
199}
200
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{common::u64_array_to_bytes, mult::u256_mul, INT256_NUM_LIMBS};

    /// Asserts that `u256_mul` agrees with `U256::wrapping_mul` for operands given as 64-bit
    /// limbs, in the order accepted by `U256::from_limbs`.
    fn assert_matches_reference(limbs_a: [u64; 4], limbs_b: [u64; 4]) {
        let a_u8: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(limbs_a);
        let b_u8: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(limbs_b);
        let expected = U256::from_limbs(limbs_a).wrapping_mul(U256::from_limbs(limbs_b));
        assert_eq!(
            U256::from_le_bytes(u256_mul(a_u8, b_u8)),
            expected,
            "a = {limbs_a:?}, b = {limbs_b:?}"
        );
    }

    #[test]
    fn test_u256_mul() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            let limbs_a: [u64; 4] = rng.gen();
            let limbs_b: [u64; 4] = rng.gen();
            assert_matches_reference(limbs_a, limbs_b);
        }
    }

    /// Deterministic extremes that random sampling is unlikely to hit: zero, one, all-ones,
    /// the top bit alone, and a full low half.
    #[test]
    fn test_u256_mul_edge_cases() {
        let cases: [[u64; 4]; 5] = [
            [0; 4],
            [1, 0, 0, 0],
            [u64::MAX; 4],
            [0, 0, 0, 1 << 63],
            [u64::MAX, u64::MAX, 0, 0],
        ];
        for &a in &cases {
            for &b in &cases {
                assert_matches_reference(a, b);
            }
        }
    }
}