openvm_rv32im_circuit/mul/
core.rs
1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
8 VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip};
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_instructions::{instruction::Instruction, LocalOpcode};
13use openvm_rv32im_transpiler::MulOpcode;
14use openvm_stark_backend::{
15 interaction::InteractionBuilder,
16 p3_air::BaseAir,
17 p3_field::{Field, FieldAlgebra, PrimeField32},
18 rap::BaseAirWithPublicValues,
19};
20use serde::{de::DeserializeOwned, Deserialize, Serialize};
21use serde_big_array::BigArray;
22
23#[repr(C)]
24#[derive(AlignedBorrow)]
25pub struct MultiplicationCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
26 pub a: [T; NUM_LIMBS],
27 pub b: [T; NUM_LIMBS],
28 pub c: [T; NUM_LIMBS],
29 pub is_valid: T,
30}
31
32#[derive(Copy, Clone, Debug)]
33pub struct MultiplicationCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
34 pub bus: RangeTupleCheckerBus<2>,
35 pub offset: usize,
36}
37
38impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
39 for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
40{
41 fn width(&self) -> usize {
42 MultiplicationCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
43 }
44}
45impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
46 for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
47{
48}
49
50impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
51 for MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>
52where
53 AB: InteractionBuilder,
54 I: VmAdapterInterface<AB::Expr>,
55 I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
56 I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
57 I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
58{
59 fn eval(
60 &self,
61 builder: &mut AB,
62 local_core: &[AB::Var],
63 _from_pc: AB::Var,
64 ) -> AdapterAirContext<AB::Expr, I> {
65 let cols: &MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
66 builder.assert_bool(cols.is_valid);
67
68 let a = &cols.a;
69 let b = &cols.b;
70 let c = &cols.c;
71
72 let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
75 let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();
76
77 for i in 0..NUM_LIMBS {
78 let expected_limb = if i == 0 {
79 AB::Expr::ZERO
80 } else {
81 carry[i - 1].clone()
82 } + (0..=i).fold(AB::Expr::ZERO, |acc, k| acc + (b[k] * c[i - k]));
83 carry[i] = AB::Expr::from(carry_divide) * (expected_limb - a[i]);
84 }
85
86 for (a, carry) in a.iter().zip(carry.iter()) {
87 self.bus
88 .send(vec![(*a).into(), carry.clone()])
89 .eval(builder, cols.is_valid);
90 }
91
92 let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, MulOpcode::MUL);
93
94 AdapterAirContext {
95 to_pc: None,
96 reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
97 writes: [cols.a.map(Into::into)].into(),
98 instruction: MinimalInstruction {
99 is_valid: cols.is_valid.into(),
100 opcode: expected_opcode,
101 }
102 .into(),
103 }
104 }
105
106 fn start_offset(&self) -> usize {
107 self.offset
108 }
109}
110
111#[derive(Debug)]
112pub struct MultiplicationCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
113 pub air: MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>,
114 pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
115}
116
117impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> MultiplicationCoreChip<NUM_LIMBS, LIMB_BITS> {
118 pub fn new(range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize) -> Self {
119 debug_assert!(
123 range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
124 "First element of RangeTupleChecker must have size {}",
125 1 << LIMB_BITS
126 );
127 debug_assert!(
128 range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * NUM_LIMBS as u32,
129 "Second element of RangeTupleChecker must have size of at least {}",
130 (1 << LIMB_BITS) * NUM_LIMBS as u32
131 );
132
133 Self {
134 air: MultiplicationCoreAir {
135 bus: *range_tuple_chip.bus(),
136 offset,
137 },
138 range_tuple_chip,
139 }
140 }
141}
142
143#[repr(C)]
144#[derive(Clone, Debug, Serialize, Deserialize)]
145#[serde(bound = "T: Serialize + DeserializeOwned")]
146pub struct MultiplicationCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
147 #[serde(with = "BigArray")]
148 pub a: [T; NUM_LIMBS],
149 #[serde(with = "BigArray")]
150 pub b: [T; NUM_LIMBS],
151 #[serde(with = "BigArray")]
152 pub c: [T; NUM_LIMBS],
153}
154
155impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
156 VmCoreChip<F, I> for MultiplicationCoreChip<NUM_LIMBS, LIMB_BITS>
157where
158 I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
159 I::Writes: From<[[F; NUM_LIMBS]; 1]>,
160{
161 type Record = MultiplicationCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
162 type Air = MultiplicationCoreAir<NUM_LIMBS, LIMB_BITS>;
163
164 #[allow(clippy::type_complexity)]
165 fn execute_instruction(
166 &self,
167 instruction: &Instruction<F>,
168 _from_pc: u32,
169 reads: I::Reads,
170 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
171 let Instruction { opcode, .. } = instruction;
172 assert_eq!(
173 MulOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)),
174 MulOpcode::MUL
175 );
176
177 let data: [[F; NUM_LIMBS]; 2] = reads.into();
178 let b = data[0].map(|x| x.as_canonical_u32());
179 let c = data[1].map(|y| y.as_canonical_u32());
180 let (a, carry) = run_mul::<NUM_LIMBS, LIMB_BITS>(&b, &c);
181
182 for (a, carry) in a.iter().zip(carry.iter()) {
183 self.range_tuple_chip.add_count(&[*a, *carry]);
184 }
185
186 let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]);
187 let record = MultiplicationCoreRecord {
188 a: a.map(F::from_canonical_u32),
189 b: data[0],
190 c: data[1],
191 };
192
193 Ok((output, record))
194 }
195
196 fn get_opcode_name(&self, opcode: usize) -> String {
197 format!("{:?}", MulOpcode::from_usize(opcode - self.air.offset))
198 }
199
200 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
201 let row_slice: &mut MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> =
202 row_slice.borrow_mut();
203 row_slice.a = record.a;
204 row_slice.b = record.b;
205 row_slice.c = record.c;
206 row_slice.is_valid = F::ONE;
207 }
208
209 fn air(&self) -> &Self::Air {
210 &self.air
211 }
212}
213
214pub(super) fn run_mul<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
216 x: &[u32; NUM_LIMBS],
217 y: &[u32; NUM_LIMBS],
218) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS]) {
219 let mut result = [0; NUM_LIMBS];
220 let mut carry = [0; NUM_LIMBS];
221 for i in 0..NUM_LIMBS {
222 if i > 0 {
223 result[i] = carry[i - 1];
224 }
225 for j in 0..=i {
226 result[i] += x[j] * y[i - j];
227 }
228 carry[i] = result[i] >> LIMB_BITS;
229 result[i] %= 1 << LIMB_BITS;
230 }
231 (result, carry)
232}