openvm_rv32im_circuit/hintstore/
cuda.rs

1use std::sync::Arc;
2
3use derive_new::new;
4use openvm_circuit::{
5    arch::{DenseRecordArena, RecordSeeker},
6    utils::next_power_of_two_or_zero,
7};
8use openvm_circuit_primitives::{
9    bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
10};
11use openvm_cuda_backend::{
12    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
13};
14use openvm_cuda_common::copy::MemCopyH2D;
15use openvm_instructions::riscv::RV32_CELL_BITS;
16use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
17
18use crate::{
19    cuda_abi::hintstore_cuda::tracegen, Rv32HintStoreCols, Rv32HintStoreLayout,
20    Rv32HintStoreRecordMut,
21};
22
23#[derive(new)]
24pub struct Rv32HintStoreChipGpu {
25    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
26    pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
27    pub pointer_max_bits: usize,
28    pub timestamp_max_bits: usize,
29}
30
31// This is the info needed by each row to do parallel tracegen
32#[repr(C)]
33#[derive(new)]
34pub struct OffsetInfo {
35    pub record_offset: u32,
36    pub local_idx: u32,
37}
38
39impl Chip<DenseRecordArena, GpuBackend> for Rv32HintStoreChipGpu {
40    fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
41        let width = Rv32HintStoreCols::<u8>::width();
42        let records = arena.allocated_mut();
43        if records.is_empty() {
44            return get_empty_air_proving_ctx::<GpuBackend>();
45        }
46
47        let mut offsets = Vec::<OffsetInfo>::new();
48        let mut offset = 0;
49
50        while offset < records.len() {
51            let prev_offset = offset;
52            let record = RecordSeeker::<
53                DenseRecordArena,
54                Rv32HintStoreRecordMut,
55                Rv32HintStoreLayout,
56            >::get_record_at(&mut offset, records);
57            for idx in 0..record.inner.num_words {
58                offsets.push(OffsetInfo::new(prev_offset as u32, idx));
59            }
60        }
61
62        let d_records = records.to_device().unwrap();
63        let d_record_offsets = offsets.to_device().unwrap();
64
65        let trace_height = next_power_of_two_or_zero(offsets.len());
66        let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, width);
67
68        unsafe {
69            tracegen(
70                d_trace.buffer(),
71                trace_height,
72                &d_records,
73                offsets.len(),
74                &d_record_offsets,
75                self.pointer_max_bits as u32,
76                &self.range_checker.count,
77                &self.bitwise_lookup.count,
78                RV32_CELL_BITS as u32,
79                self.timestamp_max_bits as u32,
80            )
81            .unwrap();
82        }
83
84        AirProvingContext::simple_no_pis(d_trace)
85    }
86}