openvm_circuit/system/cuda/
phantom.rs1use std::mem::size_of;
2
3use derive_new::new;
4use openvm_circuit::{
5 arch::DenseRecordArena,
6 system::phantom::{PhantomCols, PhantomRecord},
7 utils::next_power_of_two_or_zero,
8};
9use openvm_cuda_backend::{
10 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
11};
12use openvm_cuda_common::copy::MemCopyH2D;
13use openvm_stark_backend::{
14 prover::{hal::MatrixDimensions, types::AirProvingContext},
15 Chip,
16};
17
18use crate::cuda_abi::phantom;
19
20#[derive(new)]
21pub struct PhantomChipGPU;
22
23impl PhantomChipGPU {
24 pub fn trace_height(arena: &DenseRecordArena) -> usize {
25 let record_size = size_of::<PhantomRecord>();
26 let records_len = arena.allocated().len();
27 assert_eq!(records_len % record_size, 0);
28 records_len / record_size
29 }
30
31 pub fn trace_width() -> usize {
32 PhantomCols::<F>::width()
33 }
34}
35
36impl Chip<DenseRecordArena, GpuBackend> for PhantomChipGPU {
37 fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
38 let num_records = Self::trace_height(&arena);
39 if num_records == 0 {
40 return get_empty_air_proving_ctx();
41 }
42 let trace_height = next_power_of_two_or_zero(num_records);
43 let trace = DeviceMatrix::<F>::with_capacity(trace_height, Self::trace_width());
44 unsafe {
45 phantom::tracegen(
46 trace.buffer(),
47 trace.height(),
48 trace.width(),
49 &arena.allocated().to_device().unwrap(),
50 )
51 .expect("Failed to generate trace");
52 }
53 AirProvingContext::simple_no_pis(trace)
54 }
55}