openvm_circuit/system/cuda/phantom.rs

1use std::mem::size_of;
2
3use derive_new::new;
4use openvm_circuit::{
5    arch::DenseRecordArena,
6    system::phantom::{PhantomCols, PhantomRecord},
7    utils::next_power_of_two_or_zero,
8};
9use openvm_cuda_backend::{
10    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
11};
12use openvm_cuda_common::copy::MemCopyH2D;
13use openvm_stark_backend::{
14    prover::{hal::MatrixDimensions, types::AirProvingContext},
15    Chip,
16};
17
18use crate::cuda_abi::phantom;
19
20#[derive(new)]
21pub struct PhantomChipGPU;
22
23impl PhantomChipGPU {
24    pub fn trace_height(arena: &DenseRecordArena) -> usize {
25        let record_size = size_of::<PhantomRecord>();
26        let records_len = arena.allocated().len();
27        assert_eq!(records_len % record_size, 0);
28        records_len / record_size
29    }
30
31    pub fn trace_width() -> usize {
32        PhantomCols::<F>::width()
33    }
34}
35
36impl Chip<DenseRecordArena, GpuBackend> for PhantomChipGPU {
37    fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
38        let num_records = Self::trace_height(&arena);
39        if num_records == 0 {
40            return get_empty_air_proving_ctx();
41        }
42        let trace_height = next_power_of_two_or_zero(num_records);
43        let trace = DeviceMatrix::<F>::with_capacity(trace_height, Self::trace_width());
44        unsafe {
45            phantom::tracegen(
46                trace.buffer(),
47                trace.height(),
48                trace.width(),
49                &arena.allocated().to_device().unwrap(),
50            )
51            .expect("Failed to generate trace");
52        }
53        AirProvingContext::simple_no_pis(trace)
54    }
55}