openvm_rv32im_circuit/jalr/
cuda.rs

1use std::{mem::size_of, sync::Arc};
2
3use derive_new::new;
4use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
5use openvm_circuit_primitives::{
6    bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
7};
8use openvm_cuda_backend::{
9    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
10};
11use openvm_cuda_common::copy::MemCopyH2D;
12use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
13
14use crate::{
15    adapters::{Rv32JalrAdapterCols, Rv32JalrAdapterRecord, RV32_CELL_BITS},
16    cuda_abi::jalr_cuda::tracegen,
17    Rv32JalrCoreCols, Rv32JalrCoreRecord,
18};
19#[derive(new)]
20pub struct Rv32JalrChipGpu {
21    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
22    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
23    pub timestamp_max_bits: usize,
24}
25
26impl Chip<DenseRecordArena, GpuBackend> for Rv32JalrChipGpu {
27    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
28        const RECORD_SIZE: usize = size_of::<(Rv32JalrAdapterRecord, Rv32JalrCoreRecord)>();
29        let records = arena.allocated();
30        if records.is_empty() {
31            return get_empty_air_proving_ctx::<GpuBackend>();
32        }
33        debug_assert_eq!(records.len() % RECORD_SIZE, 0);
34
35        let trace_width = Rv32JalrCoreCols::<F>::width() + Rv32JalrAdapterCols::<F>::width();
36        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);
37
38        let d_records = records.to_device().unwrap();
39        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
40
41        unsafe {
42            tracegen(
43                d_trace.buffer(),
44                trace_height,
45                &d_records,
46                &self.range_checker.count,
47                &self.bitwise_lookup.count,
48                RV32_CELL_BITS,
49                self.timestamp_max_bits as u32,
50            )
51            .unwrap();
52        }
53        AirProvingContext::simple_no_pis(d_trace)
54    }
55}