use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::arch::{
    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
    VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip},
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, LocalOpcode};
use openvm_rv32im_transpiler::MulHOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use strum::IntoEnumIterator;

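/// Core columns for the MULH/MULHSU/MULHU chip. `a` holds the upper NUM_LIMBS limbs of the
/// product, `a_mul` the lower NUM_LIMBS limbs, and `b_ext`/`c_ext` the sign-extension limbs
/// (0 or 2^LIMB_BITS - 1) of the operands `b` and `c`.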
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct MulHCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [T; NUM_LIMBS],
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],

    pub a_mul: [T; NUM_LIMBS],
    pub b_ext: T,
    pub c_ext: T,

    pub opcode_mulh_flag: T,
    pub opcode_mulhsu_flag: T,
    pub opcode_mulhu_flag: T,
}

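/// AIR for the MULH core chip. Limb/carry pairs are range checked through a 2-tuple range
/// checker bus, and operand sign bits through a bitwise-operation lookup bus.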
#[derive(Copy, Clone, Debug)]
pub struct MulHCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    pub range_tuple_bus: RangeTupleCheckerBus<2>,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for MulHCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        MulHCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for MulHCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for MulHCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &MulHCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_mulh_flag,
            cols.opcode_mulhsu_flag,
            cols.opcode_mulhu_flag,
        ];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let b = &cols.b;
        let c = &cols.c;
        let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();

        let a_mul = &cols.a_mul;
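        // Constrain the lower half of b * c: for each limb i, the convolution sum plus the
        // incoming carry must equal a_mul[i] + carry_mul[i] * 2^LIMB_BITS, so multiplying the
        // excess over a_mul[i] by carry_divide = 2^(-LIMB_BITS) recovers the carry.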
        let mut carry_mul: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for i in 0..NUM_LIMBS {
            let expected_limb = if i == 0 {
                AB::Expr::ZERO
            } else {
                carry_mul[i - 1].clone()
            } + (0..=i).fold(AB::Expr::ZERO, |acc, k| acc + (b[k] * c[i - k]));
            carry_mul[i] = AB::Expr::from(carry_divide) * (expected_limb - a_mul[i]);
        }

        for (a_mul, carry_mul) in a_mul.iter().zip(carry_mul.iter()) {
            self.range_tuple_bus
                .send(vec![(*a_mul).into(), carry_mul.clone()])
                .eval(builder, is_valid.clone());
        }

        let a = &cols.a;
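        // Constrain the upper half a of the product. The first carry is the top carry of the
        // lower half, and each limb also absorbs the sign-extension cross terms
        // b[k] * c_ext + c[k] * b_ext.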
        let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for j in 0..NUM_LIMBS {
            let expected_limb = if j == 0 {
                carry_mul[NUM_LIMBS - 1].clone()
            } else {
                carry[j - 1].clone()
            } + ((j + 1)..NUM_LIMBS)
                .fold(AB::Expr::ZERO, |acc, k| acc + (b[k] * c[NUM_LIMBS + j - k]))
                + (0..(j + 1)).fold(AB::Expr::ZERO, |acc, k| {
                    acc + (b[k] * cols.c_ext) + (c[k] * cols.b_ext)
                });
            carry[j] = AB::Expr::from(carry_divide) * (expected_limb - a[j]);
        }

        for (a, carry) in a.iter().zip(carry.iter()) {
            self.range_tuple_bus
                .send(vec![(*a).into(), carry.clone()])
                .eval(builder, is_valid.clone());
        }

        let sign_mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
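        // b_ext and c_ext are either 0 or 2^LIMB_BITS - 1, so multiplying them by the inverse
        // of 2^LIMB_BITS - 1 recovers the boolean sign of each operand.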
        let ext_inv = AB::F::from_canonical_u32((1 << LIMB_BITS) - 1).inverse();
        let b_sign = cols.b_ext * ext_inv;
        let c_sign = cols.c_ext * ext_inv;

        builder.assert_bool(b_sign.clone());
        builder.assert_bool(c_sign.clone());
        builder
            .when(cols.opcode_mulhu_flag)
            .assert_zero(b_sign.clone());
        builder
            .when(cols.opcode_mulhu_flag + cols.opcode_mulhsu_flag)
            .assert_zero(c_sign.clone());

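        // Check that b_sign (and, for MULH, c_sign) really is the top bit of the most
        // significant limb: after subtracting sign * 2^(LIMB_BITS - 1), doubling must still
        // fit in LIMB_BITS bits. For MULHSU the factor on the c operand is 1, which only
        // range checks the limb itself.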
        self.bitwise_lookup_bus
            .send_range(
                AB::Expr::from_canonical_u32(2) * (b[NUM_LIMBS - 1] - b_sign * sign_mask),
                (cols.opcode_mulh_flag + AB::Expr::ONE) * (c[NUM_LIMBS - 1] - c_sign * sign_mask),
            )
            .eval(builder, cols.opcode_mulh_flag + cols.opcode_mulhsu_flag);

        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags.iter().zip(MulHOpcode::iter()).fold(
                AB::Expr::ZERO,
                |acc, (flag, local_opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
                },
            ),
        );

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        MulHOpcode::CLASS_OFFSET
    }
}

pub struct MulHCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub air: MulHCoreAir<NUM_LIMBS, LIMB_BITS>,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
}

impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> MulHCoreChip<NUM_LIMBS, LIMB_BITS> {
    pub fn new(
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
        range_tuple_chip: SharedRangeTupleCheckerChip<2>,
    ) -> Self {
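        // The range tuple checker must be large enough for (limb, carry) pairs: limbs take
        // LIMB_BITS bits and every carry in the constraint system is less than
        // 2 * NUM_LIMBS * 2^LIMB_BITS.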
        debug_assert!(
            range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
            "First element of RangeTupleChecker must have size {}",
            1 << LIMB_BITS
        );
        debug_assert!(
            range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32,
            "Second element of RangeTupleChecker must have size of at least {}",
            (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32
        );

        Self {
            air: MulHCoreAir {
                bitwise_lookup_bus: bitwise_lookup_chip.bus(),
                range_tuple_bus: *range_tuple_chip.bus(),
            },
            bitwise_lookup_chip,
            range_tuple_chip,
        }
    }
}

#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MulHCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub opcode: MulHOpcode,
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
    #[serde(with = "BigArray")]
    pub a_mul: [T; NUM_LIMBS],
    pub b_ext: T,
    pub c_ext: T,
}

impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    VmCoreChip<F, I> for MulHCoreChip<NUM_LIMBS, LIMB_BITS>
where
    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
    I::Writes: From<[[F; NUM_LIMBS]; 1]>,
{
    type Record = MulHCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
    type Air = MulHCoreAir<NUM_LIMBS, LIMB_BITS>;

    #[allow(clippy::type_complexity)]
    fn execute_instruction(
        &self,
        instruction: &Instruction<F>,
        _from_pc: u32,
        reads: I::Reads,
    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
        let Instruction { opcode, .. } = instruction;
        let mulh_opcode = MulHOpcode::from_usize(opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET));

        let data: [[F; NUM_LIMBS]; 2] = reads.into();
        let b = data[0].map(|x| x.as_canonical_u32());
        let c = data[1].map(|y| y.as_canonical_u32());
        let (a, a_mul, carry, b_ext, c_ext) = run_mulh::<NUM_LIMBS, LIMB_BITS>(mulh_opcode, &b, &c);

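        // Log the (limb, carry) lookups that the AIR's range tuple interactions will consume.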
        for i in 0..NUM_LIMBS {
            self.range_tuple_chip.add_count(&[a_mul[i], carry[i]]);
            self.range_tuple_chip
                .add_count(&[a[i], carry[NUM_LIMBS + i]]);
        }

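        // Mirror the AIR's sign-bit check on the top limbs; for MULHU both operands are
        // unsigned and the check is skipped.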
        if mulh_opcode != MulHOpcode::MULHU {
            let b_sign_mask = if b_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) };
            let c_sign_mask = if c_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) };
            self.bitwise_lookup_chip.request_range(
                (b[NUM_LIMBS - 1] - b_sign_mask) << 1,
                (c[NUM_LIMBS - 1] - c_sign_mask) << ((mulh_opcode == MulHOpcode::MULH) as u32),
            );
        }

        let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]);
        let record = MulHCoreRecord {
            opcode: mulh_opcode,
            a: a.map(F::from_canonical_u32),
            b: data[0],
            c: data[1],
            a_mul: a_mul.map(F::from_canonical_u32),
            b_ext: F::from_canonical_u32(b_ext),
            c_ext: F::from_canonical_u32(c_ext),
        };

        Ok((output, record))
    }

    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET)
        )
    }

    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
        let row_slice: &mut MulHCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
        row_slice.a = record.a;
        row_slice.b = record.b;
        row_slice.c = record.c;
        row_slice.a_mul = record.a_mul;
        row_slice.b_ext = record.b_ext;
        row_slice.c_ext = record.c_ext;
        row_slice.opcode_mulh_flag = F::from_bool(record.opcode == MulHOpcode::MULH);
        row_slice.opcode_mulhsu_flag = F::from_bool(record.opcode == MulHOpcode::MULHSU);
        row_slice.opcode_mulhu_flag = F::from_bool(record.opcode == MulHOpcode::MULHU);
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}

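/// Computes `x * y` with the signedness implied by `opcode`, limb by limb. Returns the upper
/// NUM_LIMBS limbs, the lower NUM_LIMBS limbs, all 2 * NUM_LIMBS carries, and the
/// sign-extension limbs of `x` and `y`.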
pub(super) fn run_mulh<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    opcode: MulHOpcode,
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS], Vec<u32>, u32, u32) {
    let mut mul = [0; NUM_LIMBS];
    let mut carry = vec![0; 2 * NUM_LIMBS];
    for i in 0..NUM_LIMBS {
        if i > 0 {
            mul[i] = carry[i - 1];
        }
        for j in 0..=i {
            mul[i] += x[j] * y[i - j];
        }
        carry[i] = mul[i] >> LIMB_BITS;
        mul[i] %= 1 << LIMB_BITS;
    }

    let x_ext = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
        * if opcode == MulHOpcode::MULHU {
            0
        } else {
            (1 << LIMB_BITS) - 1
        };
    let y_ext = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
        * if opcode == MulHOpcode::MULH {
            (1 << LIMB_BITS) - 1
        } else {
            0
        };

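    // Upper half: cross terms x[j] * y[NUM_LIMBS + i - j] plus the sign-extension
    // contributions x_prefix * y_ext + y_prefix * x_ext, carrying limb by limb.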
    let mut mulh = [0; NUM_LIMBS];
    let mut x_prefix = 0;
    let mut y_prefix = 0;

    for i in 0..NUM_LIMBS {
        x_prefix += x[i];
        y_prefix += y[i];
        mulh[i] = carry[NUM_LIMBS + i - 1] + x_prefix * y_ext + y_prefix * x_ext;
        for j in (i + 1)..NUM_LIMBS {
            mulh[i] += x[j] * y[NUM_LIMBS + i - j];
        }
        carry[NUM_LIMBS + i] = mulh[i] >> LIMB_BITS;
        mulh[i] %= 1 << LIMB_BITS;
    }

    (mulh, mul, carry, x_ext, y_ext)
}
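
// A minimal, illustrative test sketch (not part of the original crate), assuming the RV32
// configuration of NUM_LIMBS = 4 and LIMB_BITS = 8. The module and test names are
// hypothetical; the test only exercises `run_mulh` above with one negative operand.
#[cfg(test)]
mod run_mulh_sketch {
    use super::*;

    #[test]
    fn mulh_sign_extends_negative_operand() {
        // x = -1 (0xFFFF_FFFF) and y = 2: the full signed product is -2, so the high word
        // returned for MULH must be 0xFFFF_FFFF and the low word 0xFFFF_FFFE.
        let x = [0xff_u32; 4];
        let y = [2_u32, 0, 0, 0];
        let (high, low, carry, x_ext, y_ext) = run_mulh::<4, 8>(MulHOpcode::MULH, &x, &y);
        assert_eq!(high, [0xff; 4]);
        assert_eq!(low, [0xfe, 0xff, 0xff, 0xff]);
        assert_eq!(carry.len(), 2 * 4);
        assert_eq!(x_ext, (1 << 8) - 1); // x is negative, so its sign-extension limb is 2^8 - 1
        assert_eq!(y_ext, 0); // y is non-negative
    }
}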