openvm_cuda_backend/chip.rs

use std::marker::PhantomData;

use derive_new::new;
use openvm_stark_backend::{
    prover::{
        cpu::CpuBackend,
        hal::{MatrixDimensions, ProverBackend},
        types::AirProvingContext,
    },
    Chip,
};

use crate::{data_transporter::transport_matrix_to_device, prover_backend::GpuBackend, types::SC};

/// A struct that has the same memory layout as CUDA's `uint2`, for use in FFI functions.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, new)]
pub struct UInt2 {
    pub x: u32,
    pub y: u32,
}

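/// Returns an [AirProvingContext] with no cached mains, no common main trace, and no public values.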
pub fn get_empty_air_proving_ctx<PB: ProverBackend>() -> AirProvingContext<PB> {
    AirProvingContext {
        cached_mains: vec![],
        common_main: None,
        public_values: vec![],
    }
}

/// Wraps a CPU chip for use with [GpuBackend]: the wrapped chip generates its trace on
/// the host, and the resulting proving context is transferred to the device.
pub struct HybridChip<RA, C: Chip<RA, CpuBackend<SC>>> {
    pub cpu_chip: C,
    _marker: PhantomData<RA>,
}

impl<RA, C: Chip<RA, CpuBackend<SC>>> HybridChip<RA, C> {
    pub fn new(cpu_chip: C) -> Self {
        Self {
            cpu_chip,
            _marker: PhantomData,
        }
    }
}

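// Usage sketch for `HybridChip` (note: `my_cpu_chip` and `arena` are illustrative
// placeholders, not items defined in this crate):
//
//     let hybrid = HybridChip::new(my_cpu_chip);
//     let gpu_ctx: AirProvingContext<GpuBackend> = hybrid.generate_proving_ctx(arena);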
impl<RA, C: Chip<RA, CpuBackend<SC>>> Chip<RA, GpuBackend> for HybridChip<RA, C> {
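    /// Generates the proving context with the wrapped CPU chip, then transfers it to the GPU.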
    fn generate_proving_ctx(&self, arena: RA) -> AirProvingContext<GpuBackend> {
        let ctx = self.cpu_chip.generate_proving_ctx(arena);
        cpu_proving_ctx_to_gpu(ctx)
    }
}

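/// Transfers a CPU-generated [AirProvingContext] to the GPU. Cached main traces are not
/// supported; a non-empty common main trace is copied to device memory.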
pub fn cpu_proving_ctx_to_gpu(
    cpu_ctx: AirProvingContext<CpuBackend<SC>>,
) -> AirProvingContext<GpuBackend> {
    assert!(
        cpu_ctx.cached_mains.is_empty(),
        "CPU to GPU transfer of cached traces not supported"
    );
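    // Skip empty traces; otherwise copy the common main matrix into device memory.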
    let trace = cpu_ctx
        .common_main
        .filter(|trace| trace.height() > 0)
        .map(transport_matrix_to_device);
    AirProvingContext {
        cached_mains: vec![],
        common_main: trace,
        public_values: cpu_ctx.public_values,
    }
}