openvm_rv32im_circuit/mulh/
cuda.rs1use std::{mem::size_of, sync::Arc};
2
3use derive_new::new;
4use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
5use openvm_circuit_primitives::{
6 bitwise_op_lookup::BitwiseOperationLookupChipGPU, range_tuple::RangeTupleCheckerChipGPU,
7 var_range::VariableRangeCheckerChipGPU,
8};
9use openvm_cuda_backend::{
10 base::DeviceMatrix,
11 chip::{get_empty_air_proving_ctx, UInt2},
12 prover_backend::GpuBackend,
13 types::F,
14};
15use openvm_cuda_common::copy::MemCopyH2D;
16use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
17
18use crate::{
19 adapters::{
20 Rv32MultAdapterCols, Rv32MultAdapterRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
21 },
22 cuda_abi::mulh_cuda::tracegen,
23 MulHCoreCols, MulHCoreRecord,
24};
25
26#[derive(new)]
27pub struct Rv32MulHChipGpu {
28 pub range_checker: Arc<VariableRangeCheckerChipGPU>,
29 pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
30 pub range_tuple_checker: Arc<RangeTupleCheckerChipGPU<2>>,
31 pub timestamp_max_bits: usize,
32}
33
34impl Chip<DenseRecordArena, GpuBackend> for Rv32MulHChipGpu {
35 fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
36 const RECORD_SIZE: usize = size_of::<(
37 Rv32MultAdapterRecord,
38 MulHCoreRecord<RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>,
39 )>();
40 let records = arena.allocated();
41 if records.is_empty() {
42 return get_empty_air_proving_ctx::<GpuBackend>();
43 }
44 debug_assert_eq!(records.len() % RECORD_SIZE, 0);
45
46 let trace_width = MulHCoreCols::<F, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>::width()
47 + Rv32MultAdapterCols::<F>::width();
48 let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);
49
50 let tuple_checker_sizes = self.range_tuple_checker.sizes;
51 let tuple_checker_sizes = UInt2::new(tuple_checker_sizes[0], tuple_checker_sizes[1]);
52
53 let d_records = records.to_device().unwrap();
54 let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
55
56 unsafe {
57 tracegen(
58 d_trace.buffer(),
59 trace_height,
60 &d_records,
61 &self.range_checker.count,
62 &self.bitwise_lookup.count,
63 RV32_CELL_BITS,
64 &self.range_tuple_checker.count,
65 tuple_checker_sizes,
66 self.timestamp_max_bits as u32,
67 )
68 .unwrap();
69 }
70
71 AirProvingContext::simple_no_pis(d_trace)
72 }
73}