// openvm_circuit/arch/testing/program/cuda.rs
use std::slice::from_raw_parts;

use openvm_circuit::{
    arch::{
        instructions::instruction::Instruction, testing::program::ProgramTester, ExecutionState,
    },
    system::program::{ProgramBus, ProgramExecutionCols},
    utils::next_power_of_two_or_zero,
};
use openvm_cuda_backend::{
    base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F,
};
use openvm_cuda_common::copy::MemCopyH2D;
use openvm_stark_backend::{prover::types::AirProvingContext, Chip, ChipUsageGetter};

use crate::cuda_abi::program_testing;

18pub struct DeviceProgramTester(ProgramTester<F>);
19
20impl DeviceProgramTester {
21 pub fn new(bus: ProgramBus) -> Self {
22 Self(ProgramTester::new(bus))
23 }
24
25 pub fn bus(&self) -> ProgramBus {
26 self.0.bus
27 }
28
29 pub fn execute(&mut self, instruction: &Instruction<F>, initial_state: &ExecutionState<u32>) {
30 self.0.execute(instruction, initial_state);
31 }
32}
33
34impl ChipUsageGetter for DeviceProgramTester {
35 fn air_name(&self) -> String {
36 self.0.air_name()
37 }
38
39 fn current_trace_height(&self) -> usize {
40 self.0.current_trace_height()
41 }
42
43 fn trace_width(&self) -> usize {
44 self.0.trace_width()
45 }
46}
47
48impl<RA> Chip<RA, GpuBackend> for DeviceProgramTester {
49 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
50 let height = next_power_of_two_or_zero(self.0.current_trace_height());
51 let width = self.0.trace_width();
52
53 if height == 0 {
54 return get_empty_air_proving_ctx();
55 }
56 let trace = DeviceMatrix::<F>::with_capacity(height, width);
57
58 let records = &self.0.records;
59 let num_records = records.len();
60
61 unsafe {
62 let bytes_size = num_records * size_of::<ProgramExecutionCols<F>>();
63 let records_bytes = from_raw_parts(records.as_ptr() as *const u8, bytes_size);
64 let records = records_bytes.to_device().unwrap();
65 program_testing::tracegen(trace.buffer(), height, width, &records, num_records)
66 .unwrap();
67 }
68 AirProvingContext::simple_no_pis(trace)
69 }
70}