openvm_bigint_circuit/
mult.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_bigint_transpiler::Rv32Mul256Opcode;
4use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
5use openvm_circuit_primitives_derive::AlignedBytesBorrow;
6use openvm_instructions::{
7    instruction::Instruction,
8    program::DEFAULT_PC_STEP,
9    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
10    LocalOpcode,
11};
12use openvm_rv32_adapters::Rv32HeapAdapterExecutor;
13use openvm_rv32im_circuit::MultiplicationExecutor;
14use openvm_rv32im_transpiler::MulOpcode;
15use openvm_stark_backend::p3_field::PrimeField32;
16
17use crate::{
18    common::{bytes_to_u32_array, u32_array_to_bytes},
19    Rv32Multiplication256Executor, INT256_NUM_LIMBS,
20};
21
/// Heap adapter used by the 256-bit multiplication executor.
///
/// NOTE(review): the const arguments presumably mean "2 operand reads, each
/// `INT256_NUM_LIMBS` bytes wide, and one `INT256_NUM_LIMBS`-byte write" —
/// confirm against `Rv32HeapAdapterExecutor`.
type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>;
23
24impl Rv32Multiplication256Executor {
25    pub fn new(adapter: AdapterExecutor, offset: usize) -> Self {
26        Self(MultiplicationExecutor::new(adapter, offset))
27    }
28}
29
/// Pre-decoded operands of one 256-bit multiplication instruction.
///
/// Each field holds the low 8 bits of the corresponding instruction operand
/// and is later used as an address inside the register address space. The
/// struct is `#[repr(C)]` and viewed in place over a byte buffer, so field
/// order must not change.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct MultPreCompute {
    /// Register holding the pointer to the destination buffer (operand `a`).
    a: u8,
    /// Register holding the pointer to the first factor (operand `b`).
    b: u8,
    /// Register holding the pointer to the second factor (operand `c`).
    c: u8,
}
37
38impl<F: PrimeField32> Executor<F> for Rv32Multiplication256Executor {
39    fn pre_compute_size(&self) -> usize {
40        size_of::<MultPreCompute>()
41    }
42
43    fn pre_compute<Ctx>(
44        &self,
45        pc: u32,
46        inst: &Instruction<F>,
47        data: &mut [u8],
48    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
49    where
50        Ctx: ExecutionCtxTrait,
51    {
52        let data: &mut MultPreCompute = data.borrow_mut();
53        self.pre_compute_impl(pc, inst, data)?;
54        Ok(execute_e1_impl)
55    }
56
57    #[cfg(feature = "tco")]
58    fn handler<Ctx>(
59        &self,
60        pc: u32,
61        inst: &Instruction<F>,
62        data: &mut [u8],
63    ) -> Result<Handler<F, Ctx>, StaticProgramError>
64    where
65        Ctx: ExecutionCtxTrait,
66    {
67        let data: &mut MultPreCompute = data.borrow_mut();
68        self.pre_compute_impl(pc, inst, data)?;
69        Ok(execute_e1_tco_handler)
70    }
71}
72
73impl<F: PrimeField32> MeteredExecutor<F> for Rv32Multiplication256Executor {
74    fn metered_pre_compute_size(&self) -> usize {
75        size_of::<E2PreCompute<MultPreCompute>>()
76    }
77
78    fn metered_pre_compute<Ctx>(
79        &self,
80        chip_idx: usize,
81        pc: u32,
82        inst: &Instruction<F>,
83        data: &mut [u8],
84    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
85    where
86        Ctx: MeteredExecutionCtxTrait,
87    {
88        let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
89        data.chip_idx = chip_idx as u32;
90        self.pre_compute_impl(pc, inst, &mut data.data)?;
91        Ok(execute_e2_impl)
92    }
93
94    #[cfg(feature = "tco")]
95    fn metered_handler<Ctx>(
96        &self,
97        chip_idx: usize,
98        pc: u32,
99        inst: &Instruction<F>,
100        data: &mut [u8],
101    ) -> Result<Handler<F, Ctx>, StaticProgramError>
102    where
103        Ctx: MeteredExecutionCtxTrait,
104    {
105        let data: &mut E2PreCompute<MultPreCompute> = data.borrow_mut();
106        data.chip_idx = chip_idx as u32;
107        self.pre_compute_impl(pc, inst, &mut data.data)?;
108        Ok(execute_e2_tco_handler)
109    }
110}
111
112#[inline(always)]
113unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
114    pre_compute: &MultPreCompute,
115    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
116) {
117    let rs1_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.b as u32);
118    let rs2_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.c as u32);
119    let rd_ptr = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, pre_compute.a as u32);
120    let rs1 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr));
121    let rs2 = vm_state.vm_read::<u8, INT256_NUM_LIMBS>(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr));
122    let rd = u256_mul(rs1, rs2);
123    vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd);
124
125    vm_state.pc += DEFAULT_PC_STEP;
126    vm_state.instret += 1;
127}
128
129#[create_tco_handler]
130unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
131    pre_compute: &[u8],
132    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
133) {
134    let pre_compute: &MultPreCompute = pre_compute.borrow();
135    execute_e12_impl(pre_compute, vm_state);
136}
137
138#[create_tco_handler]
139unsafe fn execute_e2_impl<F: PrimeField32, CTX: MeteredExecutionCtxTrait>(
140    pre_compute: &[u8],
141    vm_state: &mut VmExecState<F, GuestMemory, CTX>,
142) {
143    let pre_compute: &E2PreCompute<MultPreCompute> = pre_compute.borrow();
144    vm_state
145        .ctx
146        .on_height_change(pre_compute.chip_idx as usize, 1);
147    execute_e12_impl(&pre_compute.data, vm_state);
148}
149
150impl Rv32Multiplication256Executor {
151    fn pre_compute_impl<F: PrimeField32>(
152        &self,
153        pc: u32,
154        inst: &Instruction<F>,
155        data: &mut MultPreCompute,
156    ) -> Result<(), StaticProgramError> {
157        let Instruction {
158            opcode,
159            a,
160            b,
161            c,
162            d,
163            e,
164            ..
165        } = inst;
166        let e_u32 = e.as_canonical_u32();
167        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS {
168            return Err(StaticProgramError::InvalidInstruction(pc));
169        }
170        let local_opcode =
171            MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET));
172        assert_eq!(local_opcode, MulOpcode::MUL);
173        *data = MultPreCompute {
174            a: a.as_canonical_u32() as u8,
175            b: b.as_canonical_u32() as u8,
176            c: c.as_canonical_u32() as u8,
177        };
178        Ok(())
179    }
180}
181
182#[inline(always)]
183pub(crate) fn u256_mul(
184    rs1: [u8; INT256_NUM_LIMBS],
185    rs2: [u8; INT256_NUM_LIMBS],
186) -> [u8; INT256_NUM_LIMBS] {
187    let rs1_u64: [u32; 8] = bytes_to_u32_array(rs1);
188    let rs2_u64: [u32; 8] = bytes_to_u32_array(rs2);
189    let mut rd = [0u32; 8];
190    for i in 0..8 {
191        let mut carry = 0u64;
192        for j in 0..(8 - i) {
193            let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry;
194            rd[i + j] = res as u32;
195            carry = res >> 32;
196        }
197    }
198    u32_array_to_bytes(rd)
199}
200
#[cfg(test)]
mod tests {
    use alloy_primitives::U256;
    use rand::{prelude::StdRng, Rng, SeedableRng};

    use crate::{common::u64_array_to_bytes, mult::u256_mul, INT256_NUM_LIMBS};

    /// Cross-checks `u256_mul` against `U256::wrapping_mul` on 10000 operand
    /// pairs drawn from a fixed-seed RNG.
    #[test]
    fn test_u256_mul() {
        let mut rng = StdRng::from_seed([42; 32]);
        for _ in 0..10000 {
            // Draw order (left operand first) is part of the deterministic stream.
            let lhs_limbs: [u64; 4] = rng.gen();
            let rhs_limbs: [u64; 4] = rng.gen();

            let expected = U256::from_limbs(lhs_limbs).wrapping_mul(U256::from_limbs(rhs_limbs));

            let lhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(lhs_limbs);
            let rhs_bytes: [u8; INT256_NUM_LIMBS] = u64_array_to_bytes(rhs_limbs);
            let actual = U256::from_le_bytes(u256_mul(lhs_bytes, rhs_bytes));

            assert_eq!(actual, expected);
        }
    }
}