openvm_rv32im_circuit/mulh/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip},
13    AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
17use openvm_rv32im_transpiler::MulHOpcode;
18use openvm_stark_backend::{
19    interaction::InteractionBuilder,
20    p3_air::{AirBuilder, BaseAir},
21    p3_field::{Field, FieldAlgebra, PrimeField32},
22    rap::BaseAirWithPublicValues,
23};
24use strum::IntoEnumIterator;
25
/// Core (non-adapter) trace columns for one MULH / MULHSU / MULHU row.
///
/// All limb arrays are little-endian: index 0 is the least significant limb.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct MulHCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs: the high `NUM_LIMBS` limbs of the (sign-extended) product, written back.
    pub a: [T; NUM_LIMBS],
    /// First operand limbs (first adapter read).
    pub b: [T; NUM_LIMBS],
    /// Second operand limbs (second adapter read).
    pub c: [T; NUM_LIMBS],

    /// Low `NUM_LIMBS` limbs of the product; constrained but not written back.
    pub a_mul: [T; NUM_LIMBS],
    /// Sign-extension limb of `b`: either 0 or `2^LIMB_BITS - 1`.
    pub b_ext: T,
    /// Sign-extension limb of `c`: either 0 or `2^LIMB_BITS - 1`.
    pub c_ext: T,

    // Boolean opcode selectors; at most one is set (their sum is constrained boolean).
    // Their order must match `MulHOpcode::iter()`, which the AIR zips with them.
    pub opcode_mulh_flag: T,
    pub opcode_mulhsu_flag: T,
    pub opcode_mulhu_flag: T,
}
41
/// AIR for the core (non-adapter) columns of the MULH / MULHSU / MULHU chip.
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct MulHCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bus used to check the sign bit of the operands' most significant limbs.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Bus on which each (limb, carry) pair of the product is sent for range checking.
    pub range_tuple_bus: RangeTupleCheckerBus<2>,
}
47
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for MulHCoreAir<NUM_LIMBS, LIMB_BITS>
{
    /// Number of core columns, i.e. the width of `MulHCoreCols`.
    fn width(&self) -> usize {
        MulHCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// Relies entirely on the trait's default methods.
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for MulHCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
59
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for MulHCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains one MULH / MULHSU / MULHU row.
    ///
    /// Limbs are little-endian. The operands `b` and `c` are treated as `2 * NUM_LIMBS`-limb
    /// values whose upper limbs are all `b_ext` / `c_ext`. Their schoolbook product is
    /// `a_mul` (low half) followed by `a` (high half). Each column's carry is expressed as
    /// `(expected_column - limb) / 2^LIMB_BITS` and sent with its limb on `range_tuple_bus`.
    ///
    /// Returns an adapter context that reads `b` and `c`, writes `a`, and carries `is_valid`
    /// and the global opcode.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &MulHCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        // Order must match MulHOpcode::iter(); the two are zipped to build expected_opcode.
        let flags = [
            cols.opcode_mulh_flag,
            cols.opcode_mulhsu_flag,
            cols.opcode_mulhu_flag,
        ];

        // Each flag is boolean, and so is their sum, so at most one flag is set.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let b = &cols.b;
        let c = &cols.c;
        // 2^{-LIMB_BITS} in the field, used to turn (column - limb) into the carry.
        let carry_divide = AB::F::from_canonical_u32(1 << LIMB_BITS).inverse();

        // Viewed as 2 * NUM_LIMBS-limb values, b * c = (a << (NUM_LIMBS * LIMB_BITS)) + a_mul.
        // To constrain a we first need the carries generated by the low half a_mul;
        // carry_mul[i] is the carry out of column i.
        let a_mul = &cols.a_mul;
        let mut carry_mul: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for i in 0..NUM_LIMBS {
            // Column i: carry in plus sum_{k <= i} b[k] * c[i - k].
            let expected_limb = if i == 0 {
                AB::Expr::ZERO
            } else {
                carry_mul[i - 1].clone()
            } + (0..=i).fold(AB::Expr::ZERO, |ac, k| ac + (b[k] * c[i - k]));
            carry_mul[i] = AB::Expr::from(carry_divide) * (expected_limb - a_mul[i]);
        }

        for (a_mul, carry_mul) in a_mul.iter().zip(carry_mul.iter()) {
            self.range_tuple_bus
                .send(vec![(*a_mul).into(), carry_mul.clone()])
                .eval(builder, is_valid.clone());
        }

        // We can now constrain that a is correct using carry_mul[NUM_LIMBS - 1]
        let a = &cols.a;
        let mut carry: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);

        for j in 0..NUM_LIMBS {
            // Column NUM_LIMBS + j: carry in, the cross terms b[k] * c[NUM_LIMBS + j - k]
            // for k > j, and the sign-extension terms b[k] * c_ext + c[k] * b_ext for k <= j.
            let expected_limb = if j == 0 {
                carry_mul[NUM_LIMBS - 1].clone()
            } else {
                carry[j - 1].clone()
            } + ((j + 1)..NUM_LIMBS)
                .fold(AB::Expr::ZERO, |acc, k| acc + (b[k] * c[NUM_LIMBS + j - k]))
                + (0..(j + 1)).fold(AB::Expr::ZERO, |acc, k| {
                    acc + (b[k] * cols.c_ext) + (c[k] * cols.b_ext)
                });
            carry[j] = AB::Expr::from(carry_divide) * (expected_limb - a[j]);
        }

        for (a, carry) in a.iter().zip(carry.iter()) {
            self.range_tuple_bus
                .send(vec![(*a).into(), carry.clone()])
                .eval(builder, is_valid.clone());
        }

        // Check that b_ext and c_ext are correct using bitwise lookup. We check
        // both b and c when the opcode is MULH, and only b when MULHSU.
        // b_sign / c_sign are b_ext / c_ext divided by 2^LIMB_BITS - 1; being boolean forces
        // each ext limb to be 0 or all ones. MULHU forces b_sign = 0; MULHU and MULHSU force
        // c_sign = 0.
        let sign_mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
        let ext_inv = AB::F::from_canonical_u32((1 << LIMB_BITS) - 1).inverse();
        let b_sign = cols.b_ext * ext_inv;
        let c_sign = cols.c_ext * ext_inv;

        builder.assert_bool(b_sign.clone());
        builder.assert_bool(c_sign.clone());
        builder
            .when(cols.opcode_mulhu_flag)
            .assert_zero(b_sign.clone());
        builder
            .when(cols.opcode_mulhu_flag + cols.opcode_mulhsu_flag)
            .assert_zero(c_sign.clone());

        // Assuming send_range checks that both values fit in LIMB_BITS bits: for a
        // LIMB_BITS-bit top limb t, 2 * (t - sign * 2^(LIMB_BITS - 1)) fits exactly when t's
        // top bit equals sign. For MULHSU the c value is multiplied by 1 (c_sign is already 0),
        // so it only range checks c's top limb. Disabled entirely for MULHU.
        self.bitwise_lookup_bus
            .send_range(
                AB::Expr::from_canonical_u32(2) * (b[NUM_LIMBS - 1] - b_sign * sign_mask),
                (cols.opcode_mulh_flag + AB::Expr::ONE) * (c[NUM_LIMBS - 1] - c_sign * sign_mask),
            )
            .eval(builder, cols.opcode_mulh_flag + cols.opcode_mulhsu_flag);

        // The local opcode selected by the one-hot flags, converted to a global opcode.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags.iter().zip(MulHOpcode::iter()).fold(
                AB::Expr::ZERO,
                |acc, (flag, local_opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
                },
            ),
        );

        AdapterAirContext {
            // NOTE(review): None presumably leaves the next pc to the adapter's default.
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Opcode class offset of `MulHOpcode`.
    fn start_offset(&self) -> usize {
        MulHOpcode::CLASS_OFFSET
    }
}
184
/// Compact per-instruction record written by `MulHExecutor::execute` and expanded into
/// `MulHCoreCols` by `MulHFiller::fill_trace_row`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct MulHCoreRecord<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// First operand limbs, as returned by the adapter read.
    pub b: [u8; NUM_LIMBS],
    /// Second operand limbs, as returned by the adapter read.
    pub c: [u8; NUM_LIMBS],
    /// Opcode index relative to `MulHOpcode::CLASS_OFFSET`.
    pub local_opcode: u8,
}
192
/// Preflight executor for MULH / MULHSU / MULHU; register reads and writes go through `adapter`.
#[derive(Clone, Copy, derive_new::new)]
pub struct MulHExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    // NOTE(review): the code in this file uses MulHOpcode::CLASS_OFFSET directly and never
    // reads this field; confirm whether other callers rely on it.
    pub offset: usize,
}
198
/// Trace filler for MULH-family rows.
///
/// Adapter columns are filled by `adapter`. Core columns are rebuilt from `MulHCoreRecord`,
/// and the lookup multiplicities are recorded in the shared chips.
#[derive(Clone)]
pub struct MulHFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    /// Receives the sign-bit range requests on the operands' top limbs.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    /// Receives one (limb, carry) count per product column.
    pub range_tuple_chip: SharedRangeTupleCheckerChip<2>,
}
205
206impl<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> MulHFiller<A, NUM_LIMBS, LIMB_BITS> {
207    pub fn new(
208        adapter: A,
209        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
210        range_tuple_chip: SharedRangeTupleCheckerChip<2>,
211    ) -> Self {
212        // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i
213        // < 2 * NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1
214        // bytes (with LIMB_BITS bits). BitwiseOperationLookup is used to sign check bytes.
215        debug_assert!(
216            range_tuple_chip.sizes()[0] == 1 << LIMB_BITS,
217            "First element of RangeTupleChecker must have size {}",
218            1 << LIMB_BITS
219        );
220        debug_assert!(
221            range_tuple_chip.sizes()[1] >= (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32,
222            "Second element of RangeTupleChecker must have size of at least {}",
223            (1 << LIMB_BITS) * 2 * NUM_LIMBS as u32
224        );
225
226        Self {
227            adapter,
228            bitwise_lookup_chip,
229            range_tuple_chip,
230        }
231    }
232}
233
234impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
235    for MulHExecutor<A, NUM_LIMBS, LIMB_BITS>
236where
237    F: PrimeField32,
238    A: 'static
239        + AdapterTraceExecutor<
240            F,
241            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
242            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
243        >,
244    for<'buf> RA: RecordArena<
245        'buf,
246        EmptyAdapterCoreLayout<F, A>,
247        (
248            A::RecordMut<'buf>,
249            &'buf mut MulHCoreRecord<NUM_LIMBS, LIMB_BITS>,
250        ),
251    >,
252{
253    fn get_opcode_name(&self, opcode: usize) -> String {
254        format!(
255            "{:?}",
256            MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET)
257        )
258    }
259
260    fn execute(
261        &self,
262        state: VmStateMut<F, TracingMemory, RA>,
263        instruction: &Instruction<F>,
264    ) -> Result<(), ExecutionError> {
265        let Instruction { opcode, .. } = instruction;
266
267        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
268
269        A::start(*state.pc, state.memory, &mut adapter_record);
270
271        core_record.local_opcode = opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET) as u8;
272        let mulh_opcode = MulHOpcode::from_usize(core_record.local_opcode as usize);
273
274        [core_record.b, core_record.c] = self
275            .adapter
276            .read(state.memory, instruction, &mut adapter_record)
277            .into();
278
279        let (a, _, _, _, _) = run_mulh::<NUM_LIMBS, LIMB_BITS>(
280            mulh_opcode,
281            &core_record.b.map(u32::from),
282            &core_record.c.map(u32::from),
283        );
284
285        let a = a.map(|x| x as u8);
286        self.adapter
287            .write(state.memory, instruction, [a].into(), &mut adapter_record);
288
289        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
290
291        Ok(())
292    }
293}
294
impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for MulHFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Fills one MULH-family trace row in place.
    ///
    /// `row_slice` holds the adapter columns followed by the core columns. On entry, the core
    /// part holds the `MulHCoreRecord` written by `MulHExecutor::execute`. The method:
    /// - recomputes the product with `run_mulh`;
    /// - registers the lookups that `MulHCoreAir::eval` sends with the shared chips;
    /// - overwrites the record with `MulHCoreCols`.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // MulHCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid MulHCoreRecord written by the executor
        // during trace generation
        let record: &MulHCoreRecord<NUM_LIMBS, LIMB_BITS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut MulHCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();

        let opcode = MulHOpcode::from_usize(record.local_opcode as usize);
        let (a, a_mul, carry, b_ext, c_ext) = run_mulh::<NUM_LIMBS, LIMB_BITS>(
            opcode,
            &record.b.map(u32::from),
            &record.c.map(u32::from),
        );

        // Mirror the AIR's range_tuple_bus sends for every limb i:
        // (a_mul[i], low-half carry) and (a[i], high-half carry).
        for i in 0..NUM_LIMBS {
            self.range_tuple_chip.add_count(&[a_mul[i], carry[i]]);
            self.range_tuple_chip
                .add_count(&[a[i], carry[NUM_LIMBS + i]]);
        }

        // Mirror the AIR's bitwise sign check, which is enabled only for MULH and MULHSU.
        // Subtracting the sign mask removes the expected top bit. Doubling then makes any
        // leftover top bit overflow LIMB_BITS. c is doubled only for MULH; for MULHSU c_ext is
        // 0 and c's top limb is only range checked.
        if opcode != MulHOpcode::MULHU {
            let b_sign_mask = if b_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) };
            let c_sign_mask = if c_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) };
            self.bitwise_lookup_chip.request_range(
                (record.b[NUM_LIMBS - 1] as u32 - b_sign_mask) << 1,
                (record.c[NUM_LIMBS - 1] as u32 - c_sign_mask)
                    << ((opcode == MulHOpcode::MULH) as u32),
            );
        }

        // Write in reverse order
        // `record` was read out of this same row buffer. Fields are therefore assigned from
        // the last column to the first, and record.b / record.c are read before the front
        // columns are overwritten. Do not reorder these assignments.
        // NOTE(review): assumes the record bytes sit at the start of core_row; confirm
        // against get_record_from_slice.
        core_row.opcode_mulhu_flag = F::from_bool(opcode == MulHOpcode::MULHU);
        core_row.opcode_mulhsu_flag = F::from_bool(opcode == MulHOpcode::MULHSU);
        core_row.opcode_mulh_flag = F::from_bool(opcode == MulHOpcode::MULH);
        core_row.c_ext = F::from_canonical_u32(c_ext);
        core_row.b_ext = F::from_canonical_u32(b_ext);
        core_row.a_mul = a_mul.map(F::from_canonical_u32);
        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = a.map(F::from_canonical_u32);
    }
}
347
348// returns mulh[[s]u], mul, carry, x_ext, y_ext
349#[inline(always)]
350pub(super) fn run_mulh<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
351    opcode: MulHOpcode,
352    x: &[u32; NUM_LIMBS],
353    y: &[u32; NUM_LIMBS],
354) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS], Vec<u32>, u32, u32) {
355    let mut mul = [0; NUM_LIMBS];
356    let mut carry = vec![0; 2 * NUM_LIMBS];
357    for i in 0..NUM_LIMBS {
358        if i > 0 {
359            mul[i] = carry[i - 1];
360        }
361        for j in 0..=i {
362            mul[i] += x[j] * y[i - j];
363        }
364        carry[i] = mul[i] >> LIMB_BITS;
365        mul[i] %= 1 << LIMB_BITS;
366    }
367
368    let x_ext = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
369        * if opcode == MulHOpcode::MULHU {
370            0
371        } else {
372            (1 << LIMB_BITS) - 1
373        };
374    let y_ext = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1))
375        * if opcode == MulHOpcode::MULH {
376            (1 << LIMB_BITS) - 1
377        } else {
378            0
379        };
380
381    let mut mulh = [0; NUM_LIMBS];
382    let mut x_prefix = 0;
383    let mut y_prefix = 0;
384
385    for i in 0..NUM_LIMBS {
386        x_prefix += x[i];
387        y_prefix += y[i];
388        mulh[i] = carry[NUM_LIMBS + i - 1] + x_prefix * y_ext + y_prefix * x_ext;
389        for j in (i + 1)..NUM_LIMBS {
390            mulh[i] += x[j] * y[NUM_LIMBS + i - j];
391        }
392        carry[NUM_LIMBS + i] = mulh[i] >> LIMB_BITS;
393        mulh[i] %= 1 << LIMB_BITS;
394    }
395
396    (mulh, mul, carry, x_ext, y_ext)
397}