openvm_circuit/arch/testing/execution/
cuda.rs1use std::slice::from_raw_parts;
2
3use openvm_circuit::{
4 arch::{
5 testing::{execution::air::DummyExecutionInteractionCols, ExecutionTester},
6 ExecutionBus, ExecutionState,
7 },
8 utils::next_power_of_two_or_zero,
9};
10use openvm_cuda_backend::{
11 base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
12};
13use openvm_cuda_common::copy::MemCopyH2D;
14use openvm_stark_backend::{prover::types::AirProvingContext, Chip, ChipUsageGetter};
15
16use crate::cuda_abi::execution_testing;
17
18pub struct DeviceExecutionTester(pub(crate) ExecutionTester<F>);
19
20impl DeviceExecutionTester {
21 pub fn new(bus: ExecutionBus) -> Self {
22 Self(ExecutionTester::new(bus))
23 }
24
25 pub fn bus(&self) -> ExecutionBus {
26 self.0.bus
27 }
28
29 pub fn execute(
30 &mut self,
31 initial_state: ExecutionState<u32>,
32 final_state: ExecutionState<u32>,
33 ) {
34 self.0.execute(initial_state, final_state);
35 }
36}
37
38impl ChipUsageGetter for DeviceExecutionTester {
39 fn air_name(&self) -> String {
40 self.0.air_name()
41 }
42
43 fn current_trace_height(&self) -> usize {
44 self.0.current_trace_height()
45 }
46
47 fn trace_width(&self) -> usize {
48 self.0.trace_width()
49 }
50}
51
52impl<RA> Chip<RA, GpuBackend> for DeviceExecutionTester {
53 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
54 let height = next_power_of_two_or_zero(self.0.current_trace_height());
55 let width = self.0.trace_width();
56
57 if height == 0 {
58 return get_empty_air_proving_ctx();
59 }
60 let trace = DeviceMatrix::<F>::with_capacity(height, width);
61
62 let records = &self.0.records;
63 let num_records = records.len();
64
65 unsafe {
66 let bytes_size = num_records * size_of::<DummyExecutionInteractionCols<F>>();
67 let records_bytes = from_raw_parts(records.as_ptr() as *const u8, bytes_size);
68 let records = records_bytes.to_device().unwrap();
69 execution_testing::tracegen(trace.buffer(), height, width, &records).unwrap();
70 }
71 AirProvingContext::simple_no_pis(trace)
72 }
73}