openvm_circuit/system/cuda/
public_values.rs1use std::{mem::size_of, sync::Arc};
2
3use openvm_circuit::{
4 arch::DenseRecordArena,
5 system::{
6 native_adapter::{NativeAdapterCols, NativeAdapterRecord},
7 public_values::PublicValuesRecord,
8 },
9 utils::next_power_of_two_or_zero,
10};
11use openvm_circuit_primitives::{encoder::Encoder, var_range::VariableRangeCheckerChipGPU};
12use openvm_cuda_backend::{
13 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
14};
15use openvm_cuda_common::copy::MemCopyH2D;
16use openvm_stark_backend::{
17 prover::{hal::MatrixDimensions, types::AirProvingContext},
18 Chip,
19};
20
21use crate::cuda_abi::public_values;
22
23#[repr(C)]
24struct FullPublicValuesRecord {
25 #[allow(unused)]
26 adapter: NativeAdapterRecord<F, 2, 0>,
27 #[allow(unused)]
28 core: PublicValuesRecord<F>,
29}
30
31pub struct PublicValuesChipGPU {
32 pub range_checker: Arc<VariableRangeCheckerChipGPU>,
33 pub public_values: Vec<F>,
34 pub num_custom_pvs: usize,
35 pub max_degree: u32,
36 encoder: Encoder,
38 pub timestamp_max_bits: u32,
39}
40
41impl PublicValuesChipGPU {
42 pub fn new(
43 range_checker: Arc<VariableRangeCheckerChipGPU>,
44 num_custom_pvs: usize,
45 max_degree: u32,
46 timestamp_max_bits: u32,
47 ) -> Self {
48 Self {
49 range_checker,
50 public_values: Vec::new(),
51 num_custom_pvs,
52 max_degree,
53 encoder: Encoder::new(num_custom_pvs, max_degree, true),
54 timestamp_max_bits,
55 }
56 }
57}
58
59impl PublicValuesChipGPU {
60 pub fn trace_height(arena: &DenseRecordArena) -> usize {
61 let record_size = size_of::<FullPublicValuesRecord>();
62 let records_len = arena.allocated().len();
63 assert_eq!(records_len % record_size, 0);
64 records_len / record_size
65 }
66
67 pub fn trace_width(&self) -> usize {
68 NativeAdapterCols::<u8, 2, 0>::width() + 3 + self.encoder.width()
69 }
70}
71
72impl Chip<DenseRecordArena, GpuBackend> for PublicValuesChipGPU {
73 fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
74 let num_records = Self::trace_height(&arena);
75 if num_records == 0 {
76 return get_empty_air_proving_ctx();
77 }
78 let trace_height = next_power_of_two_or_zero(num_records);
79 let trace = DeviceMatrix::<F>::with_capacity(trace_height, self.trace_width());
80 unsafe {
81 public_values::tracegen(
82 trace.buffer(),
83 trace.height(),
84 trace.width(),
85 &arena.allocated().to_device().unwrap(),
86 &self.range_checker.count,
87 self.timestamp_max_bits,
88 self.num_custom_pvs,
89 self.max_degree,
90 )
91 .expect("Failed to generate trace");
92 }
93 AirProvingContext::simple(trace, self.public_values.clone())
94 }
95}