// openvm_cuda_backend/engine.rs

1use openvm_cuda_common::memory_manager::MemTracker;
2#[cfg(feature = "touchemall")]
3use openvm_stark_backend::prover::types::AirProvingContext;
4use openvm_stark_backend::{
5    config::StarkGenericConfig,
6    proof::Proof,
7    prover::{
8        coordinator::Coordinator,
9        types::{DeviceMultiStarkProvingKey, ProvingContext},
10        Prover,
11    },
12};
13use openvm_stark_sdk::{
14    config::{
15        baby_bear_poseidon2::{config_from_perm, default_perm, BabyBearPoseidon2Config},
16        fri_params::SecurityParameters,
17        log_up_params::log_up_security_params_baby_bear_100_bits,
18        FriParameters,
19    },
20    engine::{StarkEngine, StarkFriEngine},
21};
22use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
23use p3_field::Field;
24
25use crate::{
26    fri_log_up::FriLogUpPhaseGpu,
27    gpu_device::{GpuConfig, GpuDevice},
28    prelude::{SC, WIDTH},
29    prover_backend::GpuBackend,
30};
31
32pub type MultiTraceStarkProverGPU = Coordinator<SC, GpuBackend, GpuDevice>;
33
34pub struct GpuBabyBearPoseidon2Engine {
35    device: GpuDevice,
36    config: BabyBearPoseidon2Config,
37    perm: Poseidon2BabyBear<WIDTH>,
38}
39
40impl StarkFriEngine for GpuBabyBearPoseidon2Engine {
41    fn new(fri_params: FriParameters) -> Self {
42        let perm = default_perm();
43        let log_up_params = log_up_security_params_baby_bear_100_bits();
44        Self {
45            device: GpuDevice::new(
46                GpuConfig::new(fri_params, BabyBear::GENERATOR),
47                Some(FriLogUpPhaseGpu::new(log_up_params.clone())),
48            ),
49            config: config_from_perm(
50                &perm,
51                SecurityParameters {
52                    fri_params,
53                    log_up_params,
54                },
55            ),
56            perm,
57        }
58    }
59    fn fri_params(&self) -> FriParameters {
60        self.device.config.fri
61    }
62}
63
64impl StarkEngine for GpuBabyBearPoseidon2Engine {
65    type SC = BabyBearPoseidon2Config;
66    type PB = GpuBackend;
67    type PD = GpuDevice;
68
69    fn config(&self) -> &SC {
70        &self.config
71    }
72
73    fn max_constraint_degree(&self) -> Option<usize> {
74        Some(self.device.config.fri.max_constraint_degree())
75    }
76
77    fn new_challenger(&self) -> <SC as StarkGenericConfig>::Challenger {
78        <SC as StarkGenericConfig>::Challenger::new(self.perm.clone())
79    }
80
81    fn device(&self) -> &Self::PD {
82        &self.device
83    }
84
85    fn prover(&self) -> MultiTraceStarkProverGPU {
86        MultiTraceStarkProverGPU::new(
87            GpuBackend::default(),
88            self.device.clone(),
89            self.new_challenger(),
90        )
91    }
92
93    fn prove(
94        &self,
95        pk: &DeviceMultiStarkProvingKey<Self::PB>,
96        ctx: ProvingContext<Self::PB>,
97    ) -> Proof<Self::SC> {
98        let mut mem = MemTracker::start("prove");
99        mem.reset_peak();
100
101        let mpk_view = pk.view(ctx.air_ids());
102        #[cfg(feature = "touchemall")]
103        {
104            for (air_id, air_ctx) in ctx.per_air.iter() {
105                check_trace_validity(air_ctx, &pk.per_air[*air_id].air_name);
106            }
107        }
108        let mut prover = self.prover();
109        let proof = prover.prove(mpk_view, ctx);
110        proof.into()
111    }
112}
113
/// Debug-only sanity check (enabled by the `touchemall` feature) that scans an AIR's common
/// main trace for cells that look as if they were never written.
///
/// The trace is copied from device to host. Each cell's raw 32-bit representation is compared
/// against `0xffffffff`, which is treated as the "untouched" marker. The host copy is indexed
/// column-major (`c * height + r`), so flat index `i` maps to row `i % height`, column
/// `i / height`.
///
/// AIRs without a common main trace have nothing to check and are skipped.
///
/// # Panics
///
/// - If copying the trace from device to host fails.
/// - On the first cell (in column-major scan order) whose raw bits are `0xffffffff`. The
///   message names the cell, the trace dimensions and the AIR `name`.
#[cfg(feature = "touchemall")]
pub fn check_trace_validity(proving_ctx: &AirProvingContext<GpuBackend>, name: &str) {
    use openvm_cuda_common::copy::MemCopyD2H;
    use openvm_stark_backend::prover::hal::MatrixDimensions;

    use crate::types::F;

    // Raw bit pattern treated as "never written".
    // NOTE(review): assumes the `touchemall` feature pre-fills device buffers with 0xFF
    // bytes — confirm against the allocator.
    const UNTOUCHED: u32 = 0xffff_ffff;
    // Reinterpreting `F` as a `u32` below is only meaningful (and sound) if the sizes match.
    const _: () = assert!(std::mem::size_of::<F>() == std::mem::size_of::<u32>());

    // This used to `unwrap()`, which panicked for AIRs that have no common main trace.
    let trace = match proving_ctx.common_main.as_ref() {
        Some(trace) => trace,
        None => return,
    };
    let height = trace.height();
    let width = trace.width();
    let host = trace
        .to_host()
        .expect("failed to copy common main trace from device to host");

    // Walk the column-major buffer contiguously instead of striding by `height` per access.
    for (idx, value) in host.iter().take(height * width).enumerate() {
        // SAFETY: `F` has the same size as `u32` (asserted at compile time above), and every
        // 4-byte pattern is a valid `u32`. `transmute_copy` performs an unaligned read, so it
        // places no alignment requirement on `F`.
        // NOTE(review): assumes `F` has no padding or uninitialized bytes (e.g. it is a
        // transparent `u32` wrapper) — confirm.
        let raw = unsafe { std::mem::transmute_copy::<F, u32>(value) };
        if raw == UNTOUCHED {
            let (r, c) = (idx % height, idx / height);
            panic!(
                "potentially untouched value at ({r}, {c}) of a trace of size {height}x{width} for air {name}"
            );
        }
    }
}