openvm_rv32im_circuit/load_sign_extend/
cuda.rs

1use std::sync::Arc;
2
3use derive_new::new;
4use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
5use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
6use openvm_cuda_backend::{
7    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
8};
9use openvm_cuda_common::copy::MemCopyH2D;
10use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS;
11use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
12
13use crate::{
14    adapters::{Rv32LoadStoreAdapterCols, Rv32LoadStoreAdapterRecord},
15    cuda_abi::load_sign_extend_cuda::tracegen,
16    LoadSignExtendCoreCols, LoadSignExtendCoreRecord,
17};
18
19#[derive(new)]
20pub struct Rv32LoadSignExtendChipGpu {
21    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
22    pub pointer_max_bits: usize,
23    pub timestamp_max_bits: usize,
24}
25
26impl Chip<DenseRecordArena, GpuBackend> for Rv32LoadSignExtendChipGpu {
27    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
28        const RECORD_SIZE: usize = size_of::<(
29            Rv32LoadStoreAdapterRecord,
30            LoadSignExtendCoreRecord<RV32_REGISTER_NUM_LIMBS>,
31        )>();
32        let records = arena.allocated();
33        if records.is_empty() {
34            return get_empty_air_proving_ctx::<GpuBackend>();
35        }
36        debug_assert_eq!(records.len() % RECORD_SIZE, 0);
37
38        let trace_width = Rv32LoadStoreAdapterCols::<F>::width()
39            + LoadSignExtendCoreCols::<F, RV32_REGISTER_NUM_LIMBS>::width();
40        let height = records.len() / RECORD_SIZE;
41        let padded_height = next_power_of_two_or_zero(height);
42
43        let d_records = records.to_device().unwrap();
44        let d_trace = DeviceMatrix::<F>::with_capacity(padded_height, trace_width);
45
46        unsafe {
47            tracegen(
48                d_trace.buffer(),
49                padded_height,
50                trace_width,
51                &d_records,
52                self.pointer_max_bits,
53                &self.range_checker.count,
54                self.timestamp_max_bits as u32,
55            )
56            .unwrap();
57        }
58
59        AirProvingContext::simple_no_pis(d_trace)
60    }
61}