openvm_rv32im_circuit/divrem/cuda.rs

1use std::sync::Arc;
2
3use derive_new::new;
4use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
5use openvm_circuit_primitives::{
6    bitwise_op_lookup::BitwiseOperationLookupChipGPU, range_tuple::RangeTupleCheckerChipGPU,
7    var_range::VariableRangeCheckerChipGPU,
8};
9use openvm_cuda_backend::{
10    base::DeviceMatrix,
11    chip::{get_empty_air_proving_ctx, UInt2},
12    prover_backend::GpuBackend,
13    types::F,
14};
15use openvm_cuda_common::copy::MemCopyH2D;
16use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
17use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
18
19use crate::{
20    adapters::{Rv32MultAdapterCols, Rv32MultAdapterRecord},
21    cuda_abi::divrem_cuda::tracegen,
22    DivRemCoreCols, DivRemCoreRecord,
23};
24
// NOTE: `#[derive(new)]` generates `new(...)` with one argument per field in declaration
// order, so reordering these fields changes the constructor signature seen by callers.
/// GPU chip that generates the RV32 DivRem trace from records collected in a
/// `DenseRecordArena`.
///
/// The lookup chips are shared through `Arc`. Their `count` buffers are passed to the DivRem
/// trace-generation kernel in `generate_proving_ctx`.
#[derive(new)]
pub struct Rv32DivRemChipGpu {
    /// Variable range checker. Its `count` buffer is handed to the tracegen kernel.
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    /// Bitwise-operation lookup parameterized by `RV32_CELL_BITS`. Its `count` buffer is
    /// handed to the tracegen kernel together with `RV32_CELL_BITS`.
    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
    /// Range tuple checker over 2-tuples. Both its `sizes` and its `count` buffer are passed
    /// to the tracegen kernel.
    pub range_tuple_checker: Arc<RangeTupleCheckerChipGPU<2>>,
    /// Not read by `generate_proving_ctx` in this file. Presumably kept for constructor
    /// parity with the other RV32 GPU chips — TODO(review): confirm it is intentionally unused.
    pub pointer_max_bits: usize,
    /// Forwarded to the tracegen kernel as a `u32` (presumably the maximum timestamp
    /// bit-width, going by its name).
    pub timestamp_max_bits: usize,
}
33
34impl Chip<DenseRecordArena, GpuBackend> for Rv32DivRemChipGpu {
35    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
36        const RECORD_SIZE: usize = size_of::<(
37            Rv32MultAdapterRecord,
38            DivRemCoreRecord<RV32_REGISTER_NUM_LIMBS>,
39        )>();
40        let records = arena.allocated();
41        if records.is_empty() {
42            return get_empty_air_proving_ctx::<GpuBackend>();
43        }
44        debug_assert_eq!(records.len() % RECORD_SIZE, 0);
45
46        let trace_width = DivRemCoreCols::<F, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>::width()
47            + Rv32MultAdapterCols::<F>::width();
48        let height = records.len() / RECORD_SIZE;
49        let padded_height = next_power_of_two_or_zero(height);
50
51        let tuple_checker_sizes = self.range_tuple_checker.sizes;
52        let tuple_checker_sizes = UInt2::new(tuple_checker_sizes[0], tuple_checker_sizes[1]);
53
54        let d_records = records.to_device().unwrap();
55        let d_trace = DeviceMatrix::<F>::with_capacity(padded_height, trace_width);
56        unsafe {
57            tracegen(
58                d_trace.buffer(),
59                padded_height,
60                trace_width,
61                &d_records,
62                &self.range_checker.count,
63                &self.bitwise_lookup.count,
64                RV32_CELL_BITS as u32,
65                &self.range_tuple_checker.count,
66                tuple_checker_sizes,
67                self.timestamp_max_bits as u32,
68            )
69            .unwrap();
70        }
71
72        AirProvingContext::simple_no_pis(d_trace)
73    }
74}