openvm_circuit/arch/testing/execution/cuda.rs

1use std::slice::from_raw_parts;
2
3use openvm_circuit::{
4    arch::{
5        testing::{execution::air::DummyExecutionInteractionCols, ExecutionTester},
6        ExecutionBus, ExecutionState,
7    },
8    utils::next_power_of_two_or_zero,
9};
10use openvm_cuda_backend::{
11    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
12};
13use openvm_cuda_common::copy::MemCopyH2D;
14use openvm_stark_backend::{prover::types::AirProvingContext, Chip, ChipUsageGetter};
15
16use crate::cuda_abi::execution_testing;
17
18pub struct DeviceExecutionTester(pub(crate) ExecutionTester<F>);
19
20impl DeviceExecutionTester {
21    pub fn new(bus: ExecutionBus) -> Self {
22        Self(ExecutionTester::new(bus))
23    }
24
25    pub fn bus(&self) -> ExecutionBus {
26        self.0.bus
27    }
28
29    pub fn execute(
30        &mut self,
31        initial_state: ExecutionState<u32>,
32        final_state: ExecutionState<u32>,
33    ) {
34        self.0.execute(initial_state, final_state);
35    }
36}
37
38impl ChipUsageGetter for DeviceExecutionTester {
39    fn air_name(&self) -> String {
40        self.0.air_name()
41    }
42
43    fn current_trace_height(&self) -> usize {
44        self.0.current_trace_height()
45    }
46
47    fn trace_width(&self) -> usize {
48        self.0.trace_width()
49    }
50}
51
52impl<RA> Chip<RA, GpuBackend> for DeviceExecutionTester {
53    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
54        let height = next_power_of_two_or_zero(self.0.current_trace_height());
55        let width = self.0.trace_width();
56
57        if height == 0 {
58            return get_empty_air_proving_ctx();
59        }
60        let trace = DeviceMatrix::<F>::with_capacity(height, width);
61
62        let records = &self.0.records;
63        let num_records = records.len();
64
65        unsafe {
66            let bytes_size = num_records * size_of::<DummyExecutionInteractionCols<F>>();
67            let records_bytes = from_raw_parts(records.as_ptr() as *const u8, bytes_size);
68            let records = records_bytes.to_device().unwrap();
69            execution_testing::tracegen(trace.buffer(), height, width, &records).unwrap();
70        }
71        AirProvingContext::simple_no_pis(trace)
72    }
73}