openvm_cuda_backend/
chip.rs1use std::marker::PhantomData;
2
3use derive_new::new;
4use openvm_stark_backend::{
5 prover::{
6 cpu::CpuBackend,
7 hal::{MatrixDimensions, ProverBackend},
8 types::AirProvingContext,
9 },
10 Chip,
11};
12
13use crate::{data_transporter::transport_matrix_to_device, prover_backend::GpuBackend, types::SC};
14
15#[repr(C)]
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, new)]
18pub struct UInt2 {
19 pub x: u32,
20 pub y: u32,
21}
22
23pub fn get_empty_air_proving_ctx<PB: ProverBackend>() -> AirProvingContext<PB> {
24 AirProvingContext {
25 cached_mains: vec![],
26 common_main: None,
27 public_values: vec![],
28 }
29}
30
31pub struct HybridChip<RA, C: Chip<RA, CpuBackend<SC>>> {
33 pub cpu_chip: C,
34 _marker: PhantomData<RA>,
35}
36
37impl<RA, C: Chip<RA, CpuBackend<SC>>> HybridChip<RA, C> {
38 pub fn new(cpu_chip: C) -> Self {
39 Self {
40 cpu_chip,
41 _marker: PhantomData,
42 }
43 }
44}
45
46impl<RA, C: Chip<RA, CpuBackend<SC>>> Chip<RA, GpuBackend> for HybridChip<RA, C> {
47 fn generate_proving_ctx(&self, arena: RA) -> AirProvingContext<GpuBackend> {
48 let ctx = self.cpu_chip.generate_proving_ctx(arena);
49 cpu_proving_ctx_to_gpu(ctx)
50 }
51}
52
53pub fn cpu_proving_ctx_to_gpu(
54 cpu_ctx: AirProvingContext<CpuBackend<SC>>,
55) -> AirProvingContext<GpuBackend> {
56 assert!(
57 cpu_ctx.cached_mains.is_empty(),
58 "CPU to GPU transfer of cached traces not supported"
59 );
60 let trace = cpu_ctx
61 .common_main
62 .filter(|trace| trace.height() > 0)
63 .map(transport_matrix_to_device);
64 AirProvingContext {
65 cached_mains: vec![],
66 common_main: trace,
67 public_values: cpu_ctx.public_values,
68 }
69}