openvm_circuit/system/cuda/
public_values.rs

1use std::{mem::size_of, sync::Arc};
2
3use openvm_circuit::{
4    arch::DenseRecordArena,
5    system::{
6        native_adapter::{NativeAdapterCols, NativeAdapterRecord},
7        public_values::PublicValuesRecord,
8    },
9    utils::next_power_of_two_or_zero,
10};
11use openvm_circuit_primitives::{encoder::Encoder, var_range::VariableRangeCheckerChipGPU};
12use openvm_cuda_backend::{
13    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend,
14};
15use openvm_cuda_common::copy::MemCopyH2D;
16use openvm_stark_backend::{
17    prover::{hal::MatrixDimensions, types::AirProvingContext},
18    Chip,
19};
20
21use crate::cuda_abi::public_values;
22
23#[repr(C)]
24struct FullPublicValuesRecord {
25    #[allow(unused)]
26    adapter: NativeAdapterRecord<F, 2, 0>,
27    #[allow(unused)]
28    core: PublicValuesRecord<F>,
29}
30
31pub struct PublicValuesChipGPU {
32    pub range_checker: Arc<VariableRangeCheckerChipGPU>,
33    pub public_values: Vec<F>,
34    pub num_custom_pvs: usize,
35    pub max_degree: u32,
36    // needed to compute the width of the trace
37    encoder: Encoder,
38    pub timestamp_max_bits: u32,
39}
40
41impl PublicValuesChipGPU {
42    pub fn new(
43        range_checker: Arc<VariableRangeCheckerChipGPU>,
44        num_custom_pvs: usize,
45        max_degree: u32,
46        timestamp_max_bits: u32,
47    ) -> Self {
48        Self {
49            range_checker,
50            public_values: Vec::new(),
51            num_custom_pvs,
52            max_degree,
53            encoder: Encoder::new(num_custom_pvs, max_degree, true),
54            timestamp_max_bits,
55        }
56    }
57}
58
59impl PublicValuesChipGPU {
60    pub fn trace_height(arena: &DenseRecordArena) -> usize {
61        let record_size = size_of::<FullPublicValuesRecord>();
62        let records_len = arena.allocated().len();
63        assert_eq!(records_len % record_size, 0);
64        records_len / record_size
65    }
66
67    pub fn trace_width(&self) -> usize {
68        NativeAdapterCols::<u8, 2, 0>::width() + 3 + self.encoder.width()
69    }
70}
71
72impl Chip<DenseRecordArena, GpuBackend> for PublicValuesChipGPU {
73    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
74        let num_records = Self::trace_height(&arena);
75        if num_records == 0 {
76            return get_empty_air_proving_ctx();
77        }
78        let trace_height = next_power_of_two_or_zero(num_records);
79        let trace = DeviceMatrix::<F>::with_capacity(trace_height, self.trace_width());
80        unsafe {
81            public_values::tracegen(
82                trace.buffer(),
83                trace.height(),
84                trace.width(),
85                &arena.allocated().to_device().unwrap(),
86                &self.range_checker.count,
87                self.timestamp_max_bits,
88                self.num_custom_pvs,
89                self.max_degree,
90            )
91            .expect("Failed to generate trace");
92        }
93        AirProvingContext::simple(trace, self.public_values.clone())
94    }
95}