openvm_rv32im_circuit/branch_eq/cuda.rs

use std::{mem::size_of, sync::Arc};

use derive_new::new;
use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero};
use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
use openvm_cuda_backend::{
    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
};
use openvm_cuda_common::copy::MemCopyH2D;
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};

use crate::{
    adapters::{Rv32BranchAdapterCols, Rv32BranchAdapterRecord, RV32_REGISTER_NUM_LIMBS},
    cuda_abi::beq_cuda::tracegen,
    BranchEqualCoreCols, BranchEqualCoreRecord,
};

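/// GPU trace-generation chip for the RV32 branch-equal instructions. Holds a
/// handle to the shared range-checker chip and the maximum number of bits used
/// for timestamps.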
#[derive(new)]
pub struct Rv32BranchEqualChipGpu {
    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
    pub timestamp_max_bits: usize,
}

impl Chip<DenseRecordArena, GpuBackend> for Rv32BranchEqualChipGpu {
    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
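        // Size in bytes of one record: a branch adapter record paired with a
        // branch-equal core record.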
        const RECORD_SIZE: usize = size_of::<(
            Rv32BranchAdapterRecord,
            BranchEqualCoreRecord<RV32_REGISTER_NUM_LIMBS>,
        )>();
        let records = arena.allocated();
        if records.is_empty() {
            return get_empty_air_proving_ctx::<GpuBackend>();
        }
        debug_assert_eq!(records.len() % RECORD_SIZE, 0);

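        // One trace row per record: the width is the sum of the core and adapter
        // column counts, and the height is padded up to a power of two.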
        let trace_width = BranchEqualCoreCols::<F, RV32_REGISTER_NUM_LIMBS>::width()
            + Rv32BranchAdapterCols::<F>::width();
        let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE);

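        // Copy the raw records to device memory and allocate the device-side trace matrix.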
        let d_records = records.to_device().unwrap();
        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, trace_width);

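        // Launch the GPU trace-generation kernel (`tracegen` from
        // `cuda_abi::beq_cuda`) over the copied records.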
        unsafe {
            tracegen(
                d_trace.buffer(),
                trace_height,
                &d_records,
                &self.range_checker.count,
                self.timestamp_max_bits as u32,
            )
            .unwrap();
        }
        AirProvingContext::simple_no_pis(d_trace)
    }
}