openvm_circuit/arch/testing/program/
cuda.rs

1use std::slice::from_raw_parts;
2
3use openvm_circuit::{
4    arch::{
5        instructions::instruction::Instruction, testing::program::ProgramTester, ExecutionState,
6    },
7    system::program::{ProgramBus, ProgramExecutionCols},
8    utils::next_power_of_two_or_zero,
9};
10use openvm_cuda_backend::{
11    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
12};
13use openvm_cuda_common::copy::MemCopyH2D;
14use openvm_stark_backend::{prover::types::AirProvingContext, Chip, ChipUsageGetter};
15
16use crate::cuda_abi::program_testing;
17
18pub struct DeviceProgramTester(ProgramTester<F>);
19
20impl DeviceProgramTester {
21    pub fn new(bus: ProgramBus) -> Self {
22        Self(ProgramTester::new(bus))
23    }
24
25    pub fn bus(&self) -> ProgramBus {
26        self.0.bus
27    }
28
29    pub fn execute(&mut self, instruction: &Instruction<F>, initial_state: &ExecutionState<u32>) {
30        self.0.execute(instruction, initial_state);
31    }
32}
33
34impl ChipUsageGetter for DeviceProgramTester {
35    fn air_name(&self) -> String {
36        self.0.air_name()
37    }
38
39    fn current_trace_height(&self) -> usize {
40        self.0.current_trace_height()
41    }
42
43    fn trace_width(&self) -> usize {
44        self.0.trace_width()
45    }
46}
47
48impl<RA> Chip<RA, GpuBackend> for DeviceProgramTester {
49    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
50        let height = next_power_of_two_or_zero(self.0.current_trace_height());
51        let width = self.0.trace_width();
52
53        if height == 0 {
54            return get_empty_air_proving_ctx();
55        }
56        let trace = DeviceMatrix::<F>::with_capacity(height, width);
57
58        let records = &self.0.records;
59        let num_records = records.len();
60
61        unsafe {
62            let bytes_size = num_records * size_of::<ProgramExecutionCols<F>>();
63            let records_bytes = from_raw_parts(records.as_ptr() as *const u8, bytes_size);
64            let records = records_bytes.to_device().unwrap();
65            program_testing::tracegen(trace.buffer(), height, width, &records, num_records)
66                .unwrap();
67        }
68        AirProvingContext::simple_no_pis(trace)
69    }
70}