openvm_rv32im_circuit/mul/
cuda.rs

1use std::{mem::size_of, sync::Arc};
2
3use derive_new::new;
4use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
5use openvm_circuit_primitives::{
6    range_tuple::RangeTupleCheckerChipGPU, var_range::VariableRangeCheckerChipGPU,
7};
8use openvm_cuda_backend::{
9    base::DeviceMatrix,
10    chip::{get_empty_air_proving_ctx, UInt2},
11    prover_backend::GpuBackend,
12    types::F,
13};
14use openvm_cuda_common::copy::MemCopyH2D;
15use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
16
17use crate::{
18    adapters::{
19        Rv32MultAdapterCols, Rv32MultAdapterRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
20    },
21    cuda_abi::mul_cuda::tracegen,
22    MultiplicationCoreCols, MultiplicationCoreRecord,
23};
24
25#[derive(new)]
26pub struct Rv32MultiplicationChipGpu {
27    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
28    pub range_tuple_checker: Arc<RangeTupleCheckerChipGPU<2>>,
29    pub timestamp_max_bits: usize,
30}
31
32impl Chip<DenseRecordArena, GpuBackend> for Rv32MultiplicationChipGpu {
33    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
34        const RECORD_SIZE: usize = size_of::<(
35            Rv32MultAdapterRecord,
36            MultiplicationCoreRecord<RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>,
37        )>();
38        let records = arena.allocated();
39        if records.is_empty() {
40            return get_empty_air_proving_ctx::<GpuBackend>();
41        }
42        debug_assert_eq!(records.len() % RECORD_SIZE, 0);
43
44        let trace_width =
45            MultiplicationCoreCols::<F, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>::width()
46                + Rv32MultAdapterCols::<F>::width();
47
48        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);
49
50        let tuple_checker_sizes = self.range_tuple_checker.sizes;
51        let tuple_checker_sizes = UInt2::new(tuple_checker_sizes[0], tuple_checker_sizes[1]);
52
53        let d_records = records.to_device().unwrap();
54        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
55
56        unsafe {
57            tracegen(
58                d_trace.buffer(),
59                trace_height,
60                &d_records,
61                &self.range_checker.count,
62                self.range_checker.count.len(),
63                &self.range_tuple_checker.count,
64                tuple_checker_sizes,
65                self.timestamp_max_bits as u32,
66            )
67            .unwrap();
68        }
69
70        AirProvingContext::simple_no_pis(d_trace)
71    }
72}