openvm_rv32im_circuit/mul/
core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip};
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_instructions::{instruction::Instruction, LocalOpcode};
13use openvm_rv32im_transpiler::MulOpcode;
14use openvm_stark_backend::{
15    interaction::InteractionBuilder,
16    p3_air::BaseAir,
17    p3_field::{Field, FieldAlgebra, PrimeField32},
18    rap::BaseAirWithPublicValues,
19};
20use serde::{de::DeserializeOwned, Deserialize, Serialize};
21use serde_big_array::BigArray;
22
/// Trace columns of the multiplication core AIR.
///
/// `b` and `c` are the two operands handed to the adapter as reads and `a` is the
/// value handed to the adapter as the write; each is split into `NUM_LIMBS` limbs.
/// The AIR range checks every `a[i]` together with its carry, so limb `i` is expected
/// to be limb `i` of `b * c` truncated to `NUM_LIMBS` limbs.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct MultiplicationCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Limbs of the result (the adapter write).
    pub a: [T; NUM_LIMBS],
    /// Limbs of the first operand (first adapter read).
    pub b: [T; NUM_LIMBS],
    /// Limbs of the second operand (second adapter read).
    pub c: [T; NUM_LIMBS],
    /// Boolean flag; constrained to {0, 1} and used as the multiplicity of the
    /// range-tuple bus sends. Set to one by `generate_trace_row`.
    pub is_valid: T,
}
31
/// AIR for the multiplication core: holds the range-tuple bus used to range check
/// `(a[i], carry[i])` pairs and the opcode offset of this chip.
#[derive(Copy, Clone, Debug)]
pub struct MultiplicationCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bus on which `(limb, carry)` tuples are sent for range checking.
    pub bus: RangeTupleCheckerBus<2>,
    /// Value returned by `start_offset`; also used to turn a global opcode into a
    /// local `MulOpcode` index.
    pub offset: usize,
}
37
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
{
    /// The core trace width is the width reported by `MultiplicationCoreCols`.
    fn width(&self) -> usize {
        MultiplicationCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// No items are overridden here: the trait's provided defaults are used as-is.
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
49
50impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
51    for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
52where
53    AB: InteractionBuilder,
54    I: VmAdapterInterface<AB::Expr>,
55    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
56    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
57    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
58{
59    fn eval(
60        &self,
61        builder: &mut AB,
62        local_core: &[AB::Var],
63        _from_pc: AB::Var,
64    ) -> AdapterAirContext<AB::Expr, I> {
65        let cols: &MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
66        builder.assert_bool(cols.is_valid);
67
68        let a = &cols.a;
69        let b = &cols.b;
70        let c = &cols.c;
71
72        // Define carry[i] = (sum_{k=0}^{i} b[k] * c[i - k] + carry[i - 1] - a[i]) / 2^LIMB_BITS.
73        // If 0 <= a[i], carry[i] < 2^LIMB_BITS, it can be proven that a[i] = sum_{k=0}^{i} (b[k] * c[i - k]) % 2^LIMB_BITS as necessary.
74        let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
75        let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();
76
77        for i in 0..NUM_LIMBS {
78            let expected_limb = if i == 0 {
79                AB::Expr::ZERO
80            } else {
81                carry[i - 1].clone()
82            } + (0..=i).fold(AB::Expr::ZERO, |acc, k| acc + (b[k] * c[i - k]));
83            carry[i] = AB::Expr::from(carry_divide) * (expected_limb - a[i]);
84        }
85
86        for (a, carry) in a.iter().zip(carry.iter()) {
87            self.bus
88                .send(vec![(*a).into(), carry.clone()])
89                .eval(builder, cols.is_valid);
90        }
91
92        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, MulOpcode::MUL);
93
94        AdapterAirContext {
95            to_pc: None,
96            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
97            writes: [cols.a.map(Into::into)].into(),
98            instruction: MinimalInstruction {
99                is_valid: cols.is_valid.into(),
100                opcode: expected_opcode,
101            }
102            .into(),
103        }
104    }
105
106    fn start_offset(&self) -> usize {
107        self.offset
108    }
109}
110
/// Core chip for multiplication: pairs the AIR with the shared range-tuple checker
/// whose counts are incremented for every `(a[i], carry[i])` pair produced during
/// execution.
#[derive(Debug)]
pub struct MultiplicationCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// AIR definition (bus + opcode offset) for this chip.
    pub air: MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>,
    /// Shared range-tuple checker that receives the `(limb, carry)` counts.
    pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
}
116
117impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> MultiplicationCoreChip<NUM_LIMBS, LIMB_BITS> {
118    pub fn new(range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize) -> Self {
119        // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i
120        // < NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1 bytes
121        // (with LIMB_BITS bits).
122        debug_assert!(
123            range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
124            "First element of RangeTupleChecker must have size {}",
125            1 << LIMB_BITS
126        );
127        debug_assert!(
128            range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * NUM_LIMBS as u32,
129            "Second element of RangeTupleChecker must have size of at least {}",
130            (1 << LIMB_BITS) * NUM_LIMBS as u32
131        );
132
133        Self {
134            air: MultiplicationCoreAir {
135                bus: *range_tuple_chip.bus(),
136                offset,
137            },
138            range_tuple_chip,
139        }
140    }
141}
142
/// Execution record for one multiplication: the operand limbs `b`, `c` and the
/// result limbs `a`, as produced by `execute_instruction` and later copied into a
/// trace row by `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct MultiplicationCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs.
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    /// First operand limbs.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    /// Second operand limbs.
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
}
154
155impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
156    VmCoreChip<F, I> for MultiplicationCoreChip<NUM_LIMBS, LIMB_BITS>
157where
158    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
159    I::Writes: From<[[F; NUM_LIMBS]; 1]>,
160{
161    type Record = MultiplicationCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
162    type Air = MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>;
163
164    #[allow(clippy::type_complexity)]
165    fn execute_instruction(
166        &self,
167        instruction: &Instruction<F>,
168        _from_pc: u32,
169        reads: I::Reads,
170    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
171        let Instruction { opcode, .. } = instruction;
172        assert_eq!(
173            MulOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)),
174            MulOpcode::MUL
175        );
176
177        let data: [[F; NUM_LIMBS]; 2] = reads.into();
178        let b = data[0].map(|x| x.as_canonical_u32());
179        let c = data[1].map(|y| y.as_canonical_u32());
180        let (a, carry) = run_mul::<NUM_LIMBS, LIMB_BITS>(&b, &c);
181
182        for (a, carry) in a.iter().zip(carry.iter()) {
183            self.range_tuple_chip.add_count(&[*a, *carry]);
184        }
185
186        let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]);
187        let record = MultiplicationCoreRecord {
188            a: a.map(F::from_canonical_u32),
189            b: data[0],
190            c: data[1],
191        };
192
193        Ok((output, record))
194    }
195
196    fn get_opcode_name(&self, opcode: usize) -> String {
197        format!("{:?}", MulOpcode::from_usize(opcode - self.air.offset))
198    }
199
200    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
201        let row_slice: &mut MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> =
202            row_slice.borrow_mut();
203        row_slice.a = record.a;
204        row_slice.b = record.b;
205        row_slice.c = record.c;
206        row_slice.is_valid = F::ONE;
207    }
208
209    fn air(&self) -> &Self::Air {
210        &self.air
211    }
212}
213
214// returns mul, carry
215pub(super) fn run_mul<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
216    x: &[u32; NUM_LIMBS],
217    y: &[u32; NUM_LIMBS],
218) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS]) {
219    let mut result = [0; NUM_LIMBS];
220    let mut carry = [0; NUM_LIMBS];
221    for i in 0..NUM_LIMBS {
222        if i > 0 {
223            result[i] = carry[i - 1];
224        }
225        for j in 0..=i {
226            result[i] += x[j] * y[i - j];
227        }
228        carry[i] = result[i] >> LIMB_BITS;
229        result[i] %= 1 << LIMB_BITS;
230    }
231    (result, carry)
232}