openvm_bigint_circuit/
mult.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32Mul256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7    instruction::Instruction,
8    program::DEFAULT_PC_STEP,
9    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10    LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::MultiplicationExecutor;
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18    common::{bytes_to_u32_array, u32_array_to_bytes},
19    Rv32Multiplication256Executor, INT256_NUM_LIMBS,
20};
21
/// Heap adapter executor wrapped by [`Rv32Multiplication256Executor`].
///
/// NOTE(review): the generic arguments are presumably `<NUM_READS = 2, READ_SIZE, WRITE_SIZE>`,
/// i.e. two heap operands of `INT256_NUM_LIMBS` bytes read and one of `INT256_NUM_LIMBS` bytes
/// written — confirm against `Rv32HeapAdapterExecutor`'s definition.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32Multiplication256Executor {
25    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26        Self(MultiplicationExecutor::new(adapter, offset))
27    }
28}
29
/// Per-instruction record filled by `pre_compute_impl` and read back by `execute_e12_impl`.
///
/// The record is stored in a raw byte buffer and reinterpreted in place (see `execute_e1_impl`),
/// so field order is part of its layout; keep `#[repr(C)]` and the field order as-is.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultPreCompute {
    /// Pointer into `RV32_REGISTER_AS` whose 4 bytes give the destination address in
    /// `RV32_MEMORY_AS`.
    a: u8,
    /// Pointer into `RV32_REGISTER_AS` whose 4 bytes give the first operand's address.
    b: u8,
    /// Pointer into `RV32_REGISTER_AS` whose 4 bytes give the second operand's address.
    c: u8,
}
37
/// Non-metered interpreter support for the 256-bit `MUL` instruction.
///
/// Each instruction's pre-compute buffer holds exactly one `MultPreCompute`.
impl<F: PrimeField32> InterpreterExecutor<F> for Rv32Multiplication256Executor {
    /// Size in bytes of the per-instruction pre-compute record.
    fn pre_compute_size(&self) -> usize {
        size_of::<MultPreCompute>()
    }

    /// Validates `inst`, stores its register pointers in `data`, and returns
    /// `execute_e1_impl` as the function to run for this instruction.
    ///
    /// Compiled only when the `tco` feature is disabled.
    #[cfg(not(feature = "tco"))]
    fn pre_compute<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        // View the raw byte buffer as this instruction's pre-compute record.
        let data: &mut MultPreCompute = data.borrow_mut();
        self.pre_compute_impl(pc, inst, data)?;
        Ok(execute_e1_impl)
    }

    /// Same validation as `pre_compute`, but returns `execute_e1_handler`.
    /// NOTE(review): `execute_e1_handler` is presumably generated by `#[create_handler]` on
    /// `execute_e1_impl` — confirm in the macro.
    ///
    /// Compiled only when the `tco` feature is enabled.
    #[cfg(feature = "tco")]
    fn handler<Ctx>(
        &self,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: ExecutionCtxTrait,
    {
        // View the raw byte buffer as this instruction's pre-compute record.
        let data: &mut MultPreCompute = data.borrow_mut();
        self.pre_compute_impl(pc, inst, data)?;
        Ok(execute_e1_handler)
    }
}
73
/// AOT (non-metered) execution for 256-bit `MUL`. The impl body is empty, so it relies
/// entirely on the trait's provided items. Compiled only with the `aot` feature.
#[cfg(feature = "aot")]
impl<F: PrimeField32> AotExecutor<F> for Rv32Multiplication256Executor {}
76
/// Metered interpreter support for the 256-bit `MUL` instruction.
///
/// The pre-compute buffer holds an `E2PreCompute<MultPreCompute>`: the chip index that
/// `execute_e2_impl` reports height changes against, plus the same record the non-metered
/// path uses.
impl<F: PrimeField32> InterpreterMeteredExecutor<F> for Rv32Multiplication256Executor {
    /// Size in bytes of the metered per-instruction pre-compute record.
    fn metered_pre_compute_size(&self) -> usize {
        size_of::<E2PreCompute<MultPreCompute>>()
    }

    /// Records `chip_idx`, validates `inst` into the inner record, and returns
    /// `execute_e2_impl`.
    ///
    /// Compiled only when the `tco` feature is disabled.
    #[cfg(not(feature = "tco"))]
    fn metered_pre_compute<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
        // Read back by `execute_e2_impl` to report a height change for this chip.
        // NOTE(review): `as u32` would silently truncate a `chip_idx` above `u32::MAX`.
        data.chip_idx = chip_idx as u32;
        self.pre_compute_impl(pc, inst, &mut data.data)?;
        Ok(execute_e2_impl)
    }

    /// Same as `metered_pre_compute`, but returns `execute_e2_handler`.
    /// NOTE(review): `execute_e2_handler` is presumably generated by `#[create_handler]` on
    /// `execute_e2_impl` — confirm in the macro.
    ///
    /// Compiled only when the `tco` feature is enabled.
    #[cfg(feature = "tco")]
    fn metered_handler<Ctx>(
        &self,
        chip_idx: usize,
        pc: u32,
        inst: &Instruction<F>,
        data: &mut [u8],
    ) -> Result<Handler<F, Ctx>, StaticProgramError>
    where
        Ctx: MeteredExecutionCtxTrait,
    {
        let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
        // Read back by `execute_e2_impl` to report a height change for this chip.
        data.chip_idx = chip_idx as u32;
        self.pre_compute_impl(pc, inst, &mut data.data)?;
        Ok(execute_e2_handler)
    }
}
116
/// AOT metered execution for 256-bit `MUL`. The impl body is empty, so it relies entirely
/// on the trait's provided items. Compiled only with the `aot` feature.
#[cfg(feature = "aot")]
impl<F: PrimeField32> AotMeteredExecutor<F> for Rv32Multiplication256Executor {}
119
/// Executes one 256-bit `MUL` shared by the metered and non-metered paths.
///
/// The function proceeds in four steps:
/// 1. It reads the first and second operand addresses from registers `b` and `c`, and the
///    destination address from register `a`.
/// 2. It loads both `INT256_NUM_LIMBS`-byte operands from `RV32_MEMORY_AS`.
/// 3. It writes their product modulo 2^256 (see `u256_mul`) to the destination.
/// 4. It advances the pc by `DEFAULT_PC_STEP`.
///
/// # Safety
/// NOTE(review): the caller must uphold whatever contract `VmExecState::vm_read` /
/// `vm_write` impose on the addresses read from registers — confirm against their docs.
#[inline(always)]
unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
    pre_compute: &MultPreCompute,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    // Keep this access order (rs1, rs2, rd pointers, then operands, then the write);
    // check whether the execution context depends on access order before changing it.
    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
    let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
    // Register contents are little-endian 32-bit addresses into RV32_MEMORY_AS.
    let rs1 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
    let rs2 =
        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
    let rd = u256_mul(rs1, rs2);
    exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);

    let pc = exec_state.pc();
    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
}
138
/// Non-metered entry point: reinterprets the raw pre-compute bytes as a `MultPreCompute`
/// and executes the instruction.
///
/// # Safety
/// `pre_compute` must point to at least `size_of::<MultPreCompute>()` readable bytes that
/// hold a record written by `pre_compute_impl`, valid for the duration of the call
/// (required by `std::slice::from_raw_parts`).
#[create_handler]
#[inline(always)]
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &MultPreCompute =
        std::slice::from_raw_parts(pre_compute, size_of::<MultPreCompute>()).borrow();
    execute_e12_impl(pre_compute, exec_state);
}
149
/// Metered entry point: reinterprets the raw pre-compute bytes as an
/// `E2PreCompute<MultPreCompute>`, reports a height increase of 1 for the recorded chip,
/// then executes the instruction.
///
/// # Safety
/// `pre_compute` must point to at least `size_of::<E2PreCompute<MultPreCompute>>()` readable
/// bytes that hold a record written by `metered_pre_compute`/`metered_handler`, valid for the
/// duration of the call (required by `std::slice::from_raw_parts`).
#[create_handler]
#[inline(always)]
unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
    pre_compute: *const u8,
    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
) {
    let pre_compute: &E2PreCompute<MultPreCompute> =
        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<MultPreCompute>>()).borrow();
    // Each executed instruction adds one row to this chip's trace height.
    exec_state
        .ctx
        .on_height_change(pre_compute.chip_idx as usize, 1);
    execute_e12_impl(&pre_compute.data, exec_state);
}
163
164impl Rv32Multiplication256Executor {
165    fn pre_compute_impl<F: PrimeField32>(
166        &self,
167        pc: u32,
168        inst: &Instruction<F>,
169        data: &mut MultPreCompute,
170    ) -> Result<(), StaticProgramError> {
171        let Instruction {
172            opcode,
173            a,
174            b,
175            c,
176            d,
177            e,
178            ..
179        } = inst;
180        let e_u32 = e.as_canonical_u32();
181        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
182            return Err(StaticProgramError::InvalidInstruction(pc));
183        }
184        let local_opcode =
185            MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET));
186        assert_eq!(local_opcode, MulOpcode::MUL);
187        *data = MultPreCompute {
188            a: a.as_canonical_u32() as u8,
189            b: b.as_canonical_u32() as u8,
190            c: c.as_canonical_u32() as u8,
191        };
192        Ok(())
193    }
194}
195
196#[inline(always)]
197pub(crate) fn u256_mul(
198    rs1: [u8; INT256_NUM_LIMBS],
199    rs2: [u8; INT256_NUM_LIMBS],
200) -> [u8; INT256_NUM_LIMBS] {
201    let rs1_u64: [u32; 8] = bytes_to_u32_array(rs1);
202    let rs2_u64: [u32; 8] = bytes_to_u32_array(rs2);
203    let mut rd = [0u32; 8];
204    for i in 0..8 {
205        let mut carry = 0u64;
206        for j in 0..(8 - i) {
207            let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry;
208            rd[i + j] = res as u32;
209            carry = res >> 32;
210        }
211    }
212    u32_array_to_bytes(rd)
213}
214
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{common::u64_array_to_bytes, mult::u256_mul, INT256_NUM_LIMBS};

    /// Asserts that `u256_mul` agrees with `U256::wrapping_mul` for the given
    /// little-endian `u64` limbs.
    fn check(limbs_a: [u64; 4], limbs_b: [u64; 4]) {
        let a = U256::from_limbs(limbs_a);
        let b = U256::from_limbs(limbs_b);
        let a_u8: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(limbs_a);
        let b_u8: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(limbs_b);
        assert_eq!(
            U256::from_le_bytes(u256_mul(a_u8, b_u8)),
            a.wrapping_mul(b),
            "a = {a:?}, b = {b:?}"
        );
    }

    #[test]
    fn test_u256_mul() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            // Arguments are evaluated left to right, so the random sequence matches
            // drawing `limbs_a` and then `limbs_b`.
            check(rng.gen(), rng.gen());
        }
    }

    /// Deterministic boundary values that random sampling is unlikely to hit: zero, one,
    /// all-ones (maximum carry propagation), the top bit alone (wraparound), and the
    /// 128-bit maximum (carries crossing the middle limb boundary).
    #[test]
    fn test_u256_mul_edge_cases() {
        const ZERO: [u64; 4] = [0; 4];
        const ONE: [u64; 4] = [1, 0, 0, 0];
        const MAX: [u64; 4] = [u64::MAX; 4];
        const HIGH_BIT: [u64; 4] = [0, 0, 0, 1 << 63];
        const LOW_HALF_MAX: [u64; 4] = [u64::MAX, u64::MAX, 0, 0];
        let cases = [ZERO, ONE, MAX, HIGH_BIT, LOW_HALF_MAX];
        for a in cases {
            for b in cases {
                check(a, b);
            }
        }
    }
}
235}