1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7 arch::*,
8 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11 range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip},
12 AlignedBytesBorrow,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
16use openvm_rv32im_transpiler::MulOpcode;
17use openvm_stark_backend::{
18 interaction::InteractionBuilder,
19 p3_air::BaseAir,
20 p3_field::{Field, FieldAlgebra, PrimeField32},
21 rap::BaseAirWithPublicValues,
22};
23
/// Core trace columns for one row of the multiplication chip.
///
/// Each operand and the result are stored as `NUM_LIMBS` limbs, with index 0
/// being the least significant limb.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct MultiplicationCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Limbs of the result; exposed to the adapter as the single write.
    pub a: [T; NUM_LIMBS],
    /// Limbs of the first operand; exposed to the adapter as the first read.
    pub b: [T; NUM_LIMBS],
    /// Limbs of the second operand; exposed to the adapter as the second read.
    pub c: [T; NUM_LIMBS],
    /// Row selector, constrained to be boolean by the AIR.
    pub is_valid: T,
}
32
/// AIR for the multiplication core, parameterised by limb count and limb width.
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct MultiplicationCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Two-dimensional range-tuple bus on which `(a[i], carry[i])` pairs are sent.
    pub bus: RangeTupleCheckerBus<2>,
    /// Value reported by `start_offset`, i.e. the global offset of this opcode class.
    pub offset: usize,
}
38
39impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
40 for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
41{
42 fn width(&self) -> usize {
43 MultiplicationCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
44 }
45}
46impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
47 for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
48{
49}
50
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Emits the core constraints for one row and returns the context handed
    /// to the adapter (reads `b` and `c`, write `a`, opcode `MUL`).
    ///
    /// The only direct constraint is that `is_valid` is boolean. The limb
    /// relation is enforced through the range-tuple bus: for every limb the
    /// pair `(a[i], carry[i])` is sent with multiplicity `is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        builder.assert_bool(cols.is_valid);

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;

        // carry[i] is defined as an expression, not a column:
        //   carry[i] = (carry[i - 1] + sum_{k=0}^{i} b[k] * c[i - k] - a[i]) / 2^LIMB_BITS
        // where the division is a multiplication by the field inverse of 2^LIMB_BITS.
        // NOTE(review): soundness presumably relies on the bus bounding both a[i] and
        // carry[i], so that a[i] is forced to be the column sum modulo 2^LIMB_BITS.
        let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();

        for i in 0..NUM_LIMBS {
            // Incoming carry (zero for the lowest limb) plus the i-th column of
            // the schoolbook product.
            let expected_limb = if i == 0 {
                AB::Expr::ZERO
            } else {
                carry[i - 1].clone()
            } + (0..=i).fold(AB::Expr::ZERO, |acc, k| acc + (b[k] * c[i - k]));
            carry[i] = AB::Expr::from(carry_divide) * (expected_limb - a[i]);
        }

        // One bus interaction per limb, active only on valid rows.
        for (a, carry) in a.iter().zip(carry.iter()) {
            self.bus
                .send(vec![(*a).into(), carry.clone()])
                .eval(builder, cols.is_valid);
        }

        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, MulOpcode::MUL);

        AdapterAirContext {
            // No explicit next pc is supplied by the core.
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid: cols.is_valid.into(),
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Returns the offset this AIR was constructed with.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
112
/// Core execution record for one multiplication: the two operands as byte limbs.
///
/// The result is not stored; the trace filler recomputes it from `b` and `c`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct MultiplicationCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Limbs of the first operand as read by the adapter.
    pub b: [u8; NUM_LIMBS],
    /// Limbs of the second operand as read by the adapter.
    pub c: [u8; NUM_LIMBS],
}
119
/// Preflight executor for the multiplication opcode, generic over the adapter `A`.
#[derive(Clone, Copy, derive_new::new)]
pub struct MultiplicationExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Adapter used to read the operands and write the result.
    adapter: A,
    /// Global opcode offset subtracted to obtain the local `MulOpcode` index.
    pub offset: usize,
}
125
/// Trace filler for the multiplication chip, generic over the adapter `A`.
#[derive(Clone, Debug)]
pub struct MultiplicationFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Adapter that fills the adapter part of each row.
    adapter: A,
    /// Global opcode offset of this chip.
    pub offset: usize,
    /// Shared range-tuple chip that receives one `(limb, carry)` count per result limb.
    pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
}
132
133impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize>
134 MultiplicationFiller<A, NUM_LIMBS, LIMB_BITS>
135{
136 pub fn new(
137 adapter: A,
138 range_tuple_chip: SharedRangeTupleCheckerChip<2>,
139 offset: usize,
140 ) -> Self {
141 debug_assert!(
145 range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
146 "First element of RangeTupleChecker must have size {}",
147 1 << LIMB_BITS
148 );
149 debug_assert!(
150 range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * NUM_LIMBS as u32,
151 "Second element of RangeTupleChecker must have size of at least {}",
152 (1 << LIMB_BITS) * NUM_LIMBS as u32
153 );
154
155 Self {
156 adapter,
157 offset,
158 range_tuple_chip,
159 }
160 }
161}
162
impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for MultiplicationExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut MultiplicationCoreRecord<NUM_LIMBS, LIMB_BITS>,
        ),
    >,
{
    /// Returns the `Debug` name of the local opcode obtained by subtracting
    /// `self.offset` from the global `opcode`.
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", MulOpcode::from_usize(opcode - self.offset))
    }

    /// Executes one `MUL` instruction while recording what the trace filler needs.
    ///
    /// Reads both operands through the adapter, computes the low `NUM_LIMBS`
    /// limbs of their product, stores the operands in the core record, writes
    /// the result through the adapter and advances the pc by `DEFAULT_PC_STEP`.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        // This executor only handles MUL; checked in debug builds only.
        debug_assert_eq!(
            MulOpcode::from_usize(opcode.local_opcode_idx(self.offset)),
            MulOpcode::MUL
        );
        // Allocate the adapter record and the core record for this instruction.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        let [rs1, rs2] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        // Only the result limbs are needed here; the carries are recomputed
        // when the trace row is filled.
        let (a, _) = run_mul::<NUM_LIMBS, LIMB_BITS>(&rs1, &rs2);

        core_record.b = rs1;
        core_record.c = rs2;

        self.adapter
            .write(state.memory, instruction, [a].into(), &mut adapter_record);

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
        Ok(())
    }
}
218
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for MultiplicationFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Fills one trace row: the adapter part first, then the core columns.
    ///
    /// The core record is read out of the core part of `row_slice` itself, the
    /// product and carries are recomputed with `run_mul`, each `(limb, carry)`
    /// pair is counted on the range-tuple chip, and finally the core columns
    /// are written.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY (assumed, confirm against callers): `row_slice` must contain at
        // least `A::WIDTH` elements for the unchecked split to be in bounds.
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY (assumed, confirm against the executor): `core_row` is expected to
        // already hold a `MultiplicationCoreRecord` written during execution.
        let record: &MultiplicationCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut MultiplicationCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();

        let (a, carry) = run_mul::<NUM_LIMBS, LIMB_BITS>(&record.b, &record.c);

        for (a, carry) in a.iter().zip(carry.iter()) {
            self.range_tuple_chip.add_count(&[*a as u32, *carry]);
        }

        // Columns are written from the last field to the first.
        // NOTE(review): `record` was taken from the same buffer as `core_row`, so this
        // order presumably keeps the record bytes intact until they have been read;
        // do not reorder these writes without confirming the aliasing.
        core_row.is_valid = F::ONE;
        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = a.map(F::from_canonical_u8);
    }
}
250
251#[inline(always)]
253pub(super) fn run_mul<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
254 x: &[u8; NUM_LIMBS],
255 y: &[u8; NUM_LIMBS],
256) -> ([u8; NUM_LIMBS], [u32; NUM_LIMBS]) {
257 let mut result = [0u8; NUM_LIMBS];
258 let mut carry = [0u32; NUM_LIMBS];
259 for i in 0..NUM_LIMBS {
260 let mut res = 0u32;
261 if i > 0 {
262 res = carry[i - 1];
263 }
264 for j in 0..=i {
265 res += (x[j] as u32) * (y[i - j] as u32);
266 }
267 carry[i] = res >> LIMB_BITS;
268 res %= 1u32 << LIMB_BITS;
269 result[i] = res as u8;
270 }
271 (result, carry)
272}