openvm_bigint_circuit/
mult.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32Mul256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7    instruction::Instruction,
8    program::DEFAULT_PC_STEP,
9    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10    LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::MultiplicationExecutor;
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18    common::{bytes_to_u32_array, u32_array_to_bytes},
19    Rv32Multiplication256Executor, INT256_NUM_LIMBS,
20};
21
22type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32Multiplication256Executor {
25    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26        Self(MultiplicationExecutor::new(adapter, offset))
27    }
28}
29
/// Pre-computed operands for a single 256-bit multiplication instruction.
///
/// Each field stores the low byte of the corresponding instruction operand
/// (filled in by `pre_compute_impl`). At execution time every field is used as
/// an address in the register address space from which a 4-byte little-endian
/// pointer is read.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultPreCompute {
    /// Register slot holding the pointer the 256-bit product is written to.
    a: u8,
    /// Register slot holding the pointer to the first 256-bit factor.
    b: u8,
    /// Register slot holding the pointer to the second 256-bit factor.
    c: u8,
}
37
38impl<F: PrimeField32> Executor<F> for Rv32Multiplication256Executor {
39    fn pre_compute_size(&self) -> usize {
40        size_of::<MultPreCompute>()
41    }
42
43    #[cfg(not(feature = "tco"))]
44    fn pre_compute<Ctx>(
45        &self,
46        pc: u32,
47        inst: &Instruction<F>,
48        data: &mut [u8],
49    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
50    where
51        Ctx: ExecutionCtxTrait,
52    {
53        let data: &mut MultPreCompute = data.borrow_mut();
54        self.pre_compute_impl(pc, inst, data)?;
55        Ok(execute_e1_impl)
56    }
57
58    #[cfg(feature = "tco")]
59    fn handler<Ctx>(
60        &self,
61        pc: u32,
62        inst: &Instruction<F>,
63        data: &mut [u8],
64    ) -> Result<Handler<F, Ctx>, StaticProgramError>
65    where
66        Ctx: ExecutionCtxTrait,
67    {
68        let data: &mut MultPreCompute = data.borrow_mut();
69        self.pre_compute_impl(pc, inst, data)?;
70        Ok(execute_e1_handler)
71    }
72}
73
74impl<F: PrimeField32> MeteredExecutor<F> for Rv32Multiplication256Executor {
75    fn metered_pre_compute_size(&self) -> usize {
76        size_of::<E2PreCompute<MultPreCompute>>()
77    }
78
79    #[cfg(not(feature = "tco"))]
80    fn metered_pre_compute<Ctx>(
81        &self,
82        chip_idx: usize,
83        pc: u32,
84        inst: &Instruction<F>,
85        data: &mut [u8],
86    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
87    where
88        Ctx: MeteredExecutionCtxTrait,
89    {
90        let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
91        data.chip_idx = chip_idx as u32;
92        self.pre_compute_impl(pc, inst, &mut data.data)?;
93        Ok(execute_e2_impl)
94    }
95
96    #[cfg(feature = "tco")]
97    fn metered_handler<Ctx>(
98        &self,
99        chip_idx: usize,
100        pc: u32,
101        inst: &Instruction<F>,
102        data: &mut [u8],
103    ) -> Result<Handler<F, Ctx>, StaticProgramError>
104    where
105        Ctx: MeteredExecutionCtxTrait,
106    {
107        let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
108        data.chip_idx = chip_idx as u32;
109        self.pre_compute_impl(pc, inst, &mut data.data)?;
110        Ok(execute_e2_handler)
111    }
112}
113
114#[inline(always)]
115unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
116    pre_compute: &MultPreCompute,
117    instret: &mut u64,
118    pc: &mut u32,
119    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
120) {
121    let rs1_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
122    let rs2_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
123    let rd_ptr = exec_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
124    let rs1 =
125        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
126    let rs2 =
127        exec_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
128    let rd = u256_mul(rs1, rs2);
129    exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
130
131    *pc += DEFAULT_PC_STEP;
132    *instret += 1;
133}
134
135#[create_handler]
136#[inline(always)]
137unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
138    pre_compute: &[u8],
139    instret: &mut u64,
140    pc: &mut u32,
141    _instret_end: u64,
142    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
143) {
144    let pre_compute: &MultPreCompute = pre_compute.borrow();
145    execute_e12_impl(pre_compute, instret, pc, exec_state);
146}
147
148#[create_handler]
149#[inline(always)]
150unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
151    pre_compute: &[u8],
152    instret: &mut u64,
153    pc: &mut u32,
154    _arg: u64,
155    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
156) {
157    let pre_compute: &E2PreCompute<MultPreCompute> = pre_compute.borrow();
158    exec_state
159        .ctx
160        .on_height_change(pre_compute.chip_idx as usize, 1);
161    execute_e12_impl(&pre_compute.data, instret, pc, exec_state);
162}
163
164impl Rv32Multiplication256Executor {
165    fn pre_compute_impl<F: PrimeField32>(
166        &self,
167        pc: u32,
168        inst: &Instruction<F>,
169        data: &mut MultPreCompute,
170    ) -> Result<(), StaticProgramError> {
171        let Instruction {
172            opcode,
173            a,
174            b,
175            c,
176            d,
177            e,
178            ..
179        } = inst;
180        let e_u32 = e.as_canonical_u32();
181        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
182            return Err(StaticProgramError::InvalidInstruction(pc));
183        }
184        let local_opcode =
185            MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET));
186        assert_eq!(local_opcode, MulOpcode::MUL);
187        *data = MultPreCompute {
188            a: a.as_canonical_u32() as u8,
189            b: b.as_canonical_u32() as u8,
190            c: c.as_canonical_u32() as u8,
191        };
192        Ok(())
193    }
194}
195
196#[inline(always)]
197pub(crate) fn u256_mul(
198    rs1: [u8; INT256_NUM_LIMBS],
199    rs2: [u8; INT256_NUM_LIMBS],
200) -> [u8; INT256_NUM_LIMBS] {
201    let rs1_u64: [u32; 8] = bytes_to_u32_array(rs1);
202    let rs2_u64: [u32; 8] = bytes_to_u32_array(rs2);
203    let mut rd = [0u32; 8];
204    for i in 0..8 {
205        let mut carry = 0u64;
206        for j in 0..(8 - i) {
207            let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry;
208            rd[i + j] = res as u32;
209            carry = res >> 32;
210        }
211    }
212    u32_array_to_bytes(rd)
213}
214
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{common::u64_array_to_bytes, mult::u256_mul, INT256_NUM_LIMBS};

    /// Cross-checks `u256_mul` against `U256::wrapping_mul` on 10000 random
    /// operand pairs drawn from a fixed-seed RNG.
    #[test]
    fn test_u256_mul() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            let lhs_limbs: [u64; 4] = rng.gen();
            let rhs_limbs: [u64; 4] = rng.gen();
            let expected = U256::from_limbs(lhs_limbs).wrapping_mul(U256::from_limbs(rhs_limbs));

            let lhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(lhs_limbs);
            let rhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(rhs_limbs);
            let actual = U256::from_le_bytes(u256_mul(lhs_bytes, rhs_bytes));

            assert_eq!(actual, expected);
        }
    }
}