openvm_rv32im_circuit/hintstore/
cuda.rs1use std::sync::Arc;
2
3use derive_new::new;
4use openvm_circuit::{
5 arch::{DenseRecordArena, RecordSeeker},
6 utils::next_power_of_two_or_zero,
7};
8use openvm_circuit_primitives::{
9 bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU,
10};
11use openvm_cuda_backend::{
12 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
13};
14use openvm_cuda_common::copy::MemCopyH2D;
15use openvm_instructions::riscv::RV32_CELL_BITS;
16use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
17
18use crate::{
19 cuda_abi::hintstore_cuda::tracegen, Rv32HintStoreCols, Rv32HintStoreLayout,
20 Rv32HintStoreRecordMut,
21};
22
23#[derive(new)]
24pub struct Rv32HintStoreChipGpu {
25 pub range_checker: Arc<VariableRangeCheckerChipGPU>,
26 pub bitwise_lookup: Arc<BitwiseOperationLookupChipGPU<RV32_CELL_BITS>>,
27 pub pointer_max_bits: usize,
28 pub timestamp_max_bits: usize,
29}
30
31#[repr(C)]
33#[derive(new)]
34pub struct OffsetInfo {
35 pub record_offset: u32,
36 pub local_idx: u32,
37}
38
39impl Chip<DenseRecordArena, GpuBackend> for Rv32HintStoreChipGpu {
40 fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
41 let width = Rv32HintStoreCols::<u8>::width();
42 let records = arena.allocated_mut();
43 if records.is_empty() {
44 return get_empty_air_proving_ctx::<GpuBackend>();
45 }
46
47 let mut offsets = Vec::<OffsetInfo>::new();
48 let mut offset = 0;
49
50 while offset < records.len() {
51 let prev_offset = offset;
52 let record = RecordSeeker::<
53 DenseRecordArena,
54 Rv32HintStoreRecordMut,
55 Rv32HintStoreLayout,
56 >::get_record_at(&mut offset, records);
57 for idx in 0..record.inner.num_words {
58 offsets.push(OffsetInfo::new(prev_offset as u32, idx));
59 }
60 }
61
62 let d_records = records.to_device().unwrap();
63 let d_record_offsets = offsets.to_device().unwrap();
64
65 let trace_height = next_power_of_two_or_zero(offsets.len());
66 let d_trace = DeviceMatrix::<F>::with_capacity(trace_height, width);
67
68 unsafe {
69 tracegen(
70 d_trace.buffer(),
71 trace_height,
72 &d_records,
73 offsets.len(),
74 &d_record_offsets,
75 self.pointer_max_bits as u32,
76 &self.range_checker.count,
77 &self.bitwise_lookup.count,
78 RV32_CELL_BITS as u32,
79 self.timestamp_max_bits as u32,
80 )
81 .unwrap();
82 }
83
84 AirProvingContext::simple_no_pis(d_trace)
85 }
86}