// openvm_rv32im_circuit/branch_eq/cuda.rs

1use std::{mem::size_of, sync::Arc};
2
3use derive_new::new;
4use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
5use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
6use openvm_cuda_backend::{
7    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
8};
9use openvm_cuda_common::copy::MemCopyH2D;
10use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
11
12use crate::{
13    adapters::{Rv32BranchAdapterCols, Rv32BranchAdapterRecord, RV32_REGISTER_NUM_LIMBS},
14    cuda_abi::beq_cuda::tracegen,
15    BranchEqualCoreCols, BranchEqualCoreRecord,
16};
17
/// GPU-backed chip that generates the trace for the RV32 branch-equal AIR.
///
/// `#[derive(new)]` generates `new(range_checker, timestamp_max_bits)`, with
/// arguments in field declaration order. Reordering the fields therefore
/// changes the constructor signature seen by callers.
#[derive(new)]
pub struct Rv32BranchEqualChipGpu {
    /// Shared GPU variable range checker. Its `count` buffer is handed to the
    /// trace-generation kernel.
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    /// Timestamp bit-width parameter forwarded to the trace-generation kernel
    /// as a `u32` (presumably the maximum number of timestamp bits — confirm
    /// against the kernel).
    pub timestamp_max_bits: usize,
}
23
24impl Chip<DenseRecordArena, GpuBackend> for Rv32BranchEqualChipGpu {
25    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
26        const RECORD_SIZE: usize = size_of::<(
27            Rv32BranchAdapterRecord,
28            BranchEqualCoreRecord<RV32_REGISTER_NUM_LIMBS>,
29        )>();
30        let records = arena.allocated();
31        if records.is_empty() {
32            return get_empty_air_proving_ctx::<GpuBackend>();
33        }
34        debug_assert_eq!(records.len() % RECORD_SIZE, 0);
35
36        let trace_width = BranchEqualCoreCols::<F, RV32_REGISTER_NUM_LIMBS>::width()
37            + Rv32BranchAdapterCols::<F>::width();
38        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);
39
40        let d_records = records.to_device().unwrap();
41        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);
42
43        unsafe {
44            tracegen(
45                d_trace.buffer(),
46                trace_height,
47                &d_records,
48                &self.range_checker.count,
49                self.timestamp_max_bits as u32,
50            )
51            .unwrap();
52        }
53        AirProvingContext::simple_no_pis(d_trace)
54    }
55}