openvm_rv32im_circuit/mul/
core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip},
12    AlignedBytesBorrow,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
16use openvm_rv32im_transpiler::MulOpcode;
17use openvm_stark_backend::{
18    interaction::InteractionBuilder,
19    p3_air::BaseAir,
20    p3_field::{Field, FieldAlgebra, PrimeField32},
21    rap::BaseAirWithPublicValues,
22};
23
24#[repr(C)]
25#[derive(AlignedBorrow)]
26pub struct MultiplicationCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
27    pub a: [T; NUM_LIMBS],
28    pub b: [T; NUM_LIMBS],
29    pub c: [T; NUM_LIMBS],
30    pub is_valid: T,
31}
32
33#[derive(Copy, Clone, Debug, derive_new::new)]
34pub struct MultiplicationCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
35    pub bus: RangeTupleCheckerBus<2>,
36    pub offset: usize,
37}
38
39impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
40    for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
41{
42    fn width(&self) -> usize {
43        MultiplicationCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
44    }
45}
46impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
47    for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
48{
49}
50
51impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
52    for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
53where
54    AB: InteractionBuilder,
55    I: VmAdapterInterface<AB::Expr>,
56    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
57    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
58    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
59{
60    fn eval(
61        &self,
62        builder: &mut AB,
63        local_core: &[AB::Var],
64        _from_pc: AB::Var,
65    ) -> AdapterAirContext<AB::Expr, I> {
66        let cols: &MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
67        builder.assert_bool(cols.is_valid);
68
69        let a = &cols.a;
70        let b = &cols.b;
71        let c = &cols.c;
72
73        // Define carry[i] = (sum_{k=0}^{i} b[k] * c[i - k] + carry[i - 1] - a[i]) / 2^LIMB_BITS.
74        // If 0 <= a[i], carry[i] < 2^LIMB_BITS, it can be proven that a[i] = sum_{k=0}^{i} (b[k] *
75        // c[i - k]) % 2^LIMB_BITS as necessary.
76        let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
77        let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();
78
79        for i in 0..NUM_LIMBS {
80            let expected_limb = if i == 0 {
81                AB::Expr::ZERO
82            } else {
83                carry[i - 1].clone()
84            } + (0..=i).fold(AB::Expr::ZERO, |acc, k| acc + (b[k] * c[i - k]));
85            carry[i] = AB::Expr::from(carry_divide) * (expected_limb - a[i]);
86        }
87
88        for (a, carry) in a.iter().zip(carry.iter()) {
89            self.bus
90                .send(vec![(*a).into(), carry.clone()])
91                .eval(builder, cols.is_valid);
92        }
93
94        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, MulOpcode::MUL);
95
96        AdapterAirContext {
97            to_pc: None,
98            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
99            writes: [cols.a.map(Into::into)].into(),
100            instruction: MinimalInstruction {
101                is_valid: cols.is_valid.into(),
102                opcode: expected_opcode,
103            }
104            .into(),
105        }
106    }
107
108    fn start_offset(&self) -> usize {
109        self.offset
110    }
111}
112
113#[repr(C)]
114#[derive(AlignedBytesBorrow, Debug)]
115pub struct MultiplicationCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
116    pub b: [u8; NUM_LIMBS],
117    pub c: [u8; NUM_LIMBS],
118}
119
120#[derive(Clone, Copy, derive_new::new)]
121pub struct MultiplicationExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
122    adapter: A,
123    pub offset: usize,
124}
125
126#[derive(Clone, Debug)]
127pub struct MultiplicationFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
128    adapter: A,
129    pub offset: usize,
130    pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
131}
132
133impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize>
134    MultiplicationFiller<A, NUM_LIMBS, LIMB_BITS>
135{
136    pub fn new(
137        adapter: A,
138        range_tuple_chip: SharedRangeTupleCheckerChip<2>,
139        offset: usize,
140    ) -> Self {
141        // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i
142        // < NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1 bytes
143        // (with LIMB_BITS bits).
144        debug_assert!(
145            range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
146            "First element of RangeTupleChecker must have size {}",
147            1 << LIMB_BITS
148        );
149        debug_assert!(
150            range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * NUM_LIMBS as u32,
151            "Second element of RangeTupleChecker must have size of at least {}",
152            (1 << LIMB_BITS) * NUM_LIMBS as u32
153        );
154
155        Self {
156            adapter,
157            offset,
158            range_tuple_chip,
159        }
160    }
161}
162
163impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
164    for MultiplicationExecutor<A, NUM_LIMBS, LIMB_BITS>
165where
166    F: PrimeField32,
167    A: 'static
168        + AdapterTraceExecutor<
169            F,
170            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
171            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
172        >,
173    for<'buf> RA: RecordArena<
174        'buf,
175        EmptyAdapterCoreLayout<F, A>,
176        (
177            A::RecordMut<'buf>,
178            &'buf mut MultiplicationCoreRecord<NUM_LIMBS, LIMB_BITS>,
179        ),
180    >,
181{
182    fn get_opcode_name(&self, opcode: usize) -> String {
183        format!("{:?}", MulOpcode::from_usize(opcode - self.offset))
184    }
185
186    fn execute(
187        &self,
188        state: VmStateMut<F, TracingMemory, RA>,
189        instruction: &Instruction<F>,
190    ) -> Result<(), ExecutionError> {
191        let Instruction { opcode, .. } = instruction;
192
193        debug_assert_eq!(
194            MulOpcode::from_usize(opcode.local_opcode_idx(self.offset)),
195            MulOpcode::MUL
196        );
197        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
198
199        A::start(*state.pc, state.memory, &mut adapter_record);
200
201        let [rs1, rs2] = self
202            .adapter
203            .read(state.memory, instruction, &mut adapter_record)
204            .into();
205
206        let (a, _) = run_mul::<NUM_LIMBS, LIMB_BITS>(&rs1, &rs2);
207
208        core_record.b = rs1;
209        core_record.c = rs2;
210
211        self.adapter
212            .write(state.memory, instruction, [a].into(), &mut adapter_record);
213
214        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
215        Ok(())
216    }
217}
218
219impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
220    for MultiplicationFiller<A, NUM_LIMBS, LIMB_BITS>
221where
222    F: PrimeField32,
223    A: 'static + AdapterTraceFiller<F>,
224{
225    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
226        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
227        // MultiplicationCoreCols::width() elements
228        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
229        self.adapter.fill_trace_row(mem_helper, adapter_row);
230        // SAFETY: core_row contains a valid MultiplicationCoreRecord written by the executor
231        // during trace generation
232        let record: &MultiplicationCoreRecord<NUM_LIMBS, LIMB_BITS> =
233            unsafe { get_record_from_slice(&mut core_row, ()) };
234
235        let core_row: &mut MultiplicationCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
236
237        let (a, carry) = run_mul::<NUM_LIMBS, LIMB_BITS>(&record.b, &record.c);
238
239        for (a, carry) in a.iter().zip(carry.iter()) {
240            self.range_tuple_chip.add_count(&[*a as u32, *carry]);
241        }
242
243        // write in reverse order
244        core_row.is_valid = F::ONE;
245        core_row.c = record.c.map(F::from_canonical_u8);
246        core_row.b = record.b.map(F::from_canonical_u8);
247        core_row.a = a.map(F::from_canonical_u8);
248    }
249}
250
251// returns mul, carry
252#[inline(always)]
253pub(super) fn run_mul<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
254    x: &[u8; NUM_LIMBS],
255    y: &[u8; NUM_LIMBS],
256) -> ([u8; NUM_LIMBS], [u32; NUM_LIMBS]) {
257    let mut result = [0u8; NUM_LIMBS];
258    let mut carry = [0u32; NUM_LIMBS];
259    for i in 0..NUM_LIMBS {
260        let mut res = 0u32;
261        if i > 0 {
262            res = carry[i - 1];
263        }
264        for j in 0..=i {
265            res += (x[j] as u32) * (y[i - j] as u32);
266        }
267        carry[i] = res >> LIMB_BITS;
268        res %= 1u32 << LIMB_BITS;
269        result[i] = res as u8;
270    }
271    (result, carry)
272}