openvm_cuda_backend/
engine.rs

1use openvm_cuda_common::memory_manager::MemTracker;
2#[cfg(feature = "touchemall")]
3use openvm_stark_backend::prover::types::AirProvingContext;
4use openvm_stark_backend::{
5    config::StarkGenericConfig,
6    keygen::MultiStarkKeygenBuilder,
7    proof::Proof,
8    prover::{
9        coordinator::Coordinator,
10        types::{DeviceMultiStarkProvingKey, ProvingContext},
11        Prover,
12    },
13};
14use openvm_stark_sdk::{
15    config::{
16        baby_bear_poseidon2::{config_from_perm, default_perm, BabyBearPoseidon2Config},
17        fri_params::{
18            SecurityParameters, MAX_BATCH_SIZE_LOG_BLOWUP_1, MAX_BATCH_SIZE_LOG_BLOWUP_2,
19            MAX_NUM_CONSTRAINTS,
20        },
21        FriParameters,
22    },
23    engine::{StarkEngine, StarkFriEngine},
24};
25use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
26use p3_field::Field;
27
28use crate::{
29    fri_log_up::FriLogUpPhaseGpu,
30    gpu_device::{GpuConfig, GpuDevice},
31    prelude::{SC, WIDTH},
32    prover_backend::GpuBackend,
33};
34
35pub type MultiTraceStarkProverGPU = Coordinator<SC, GpuBackend, GpuDevice>;
36
37pub struct GpuBabyBearPoseidon2Engine {
38    device: GpuDevice,
39    config: BabyBearPoseidon2Config,
40    perm: Poseidon2BabyBear<WIDTH>,
41}
42
43impl StarkFriEngine for GpuBabyBearPoseidon2Engine {
44    fn new(fri_params: FriParameters) -> Self {
45        let perm = default_perm();
46        let security_params = SecurityParameters::new_baby_bear_100_bits(fri_params);
47        Self {
48            device: GpuDevice::new(
49                GpuConfig::new(
50                    fri_params,
51                    BabyBear::GENERATOR,
52                    security_params.deep_ali_params,
53                ),
54                Some(FriLogUpPhaseGpu::new(security_params.log_up_params.clone())),
55            ),
56            config: config_from_perm(&perm, security_params),
57            perm,
58        }
59    }
60    fn fri_params(&self) -> FriParameters {
61        self.device.config.fri
62    }
63}
64
65impl StarkEngine for GpuBabyBearPoseidon2Engine {
66    type SC = BabyBearPoseidon2Config;
67    type PB = GpuBackend;
68    type PD = GpuDevice;
69
70    fn config(&self) -> &SC {
71        &self.config
72    }
73
74    fn max_constraint_degree(&self) -> Option<usize> {
75        Some(self.device.config.fri.max_constraint_degree())
76    }
77
78    fn new_challenger(&self) -> <SC as StarkGenericConfig>::Challenger {
79        <SC as StarkGenericConfig>::Challenger::new(self.perm.clone())
80    }
81
82    fn device(&self) -> &Self::PD {
83        &self.device
84    }
85
86    fn keygen_builder(&self) -> MultiStarkKeygenBuilder<'_, Self::SC> {
87        let mut builder = MultiStarkKeygenBuilder::new(self.config());
88        builder.set_max_constraint_degree(self.max_constraint_degree().unwrap());
89        let max_batch_size = if self.fri_params().log_blowup == 1 {
90            MAX_BATCH_SIZE_LOG_BLOWUP_1
91        } else {
92            MAX_BATCH_SIZE_LOG_BLOWUP_2
93        };
94        builder.max_batch_size = Some(max_batch_size);
95        builder.max_num_constraints = Some(MAX_NUM_CONSTRAINTS);
96
97        builder
98    }
99
100    fn prover(&self) -> MultiTraceStarkProverGPU {
101        MultiTraceStarkProverGPU::new(
102            GpuBackend::default(),
103            self.device.clone(),
104            self.new_challenger(),
105        )
106    }
107
108    fn prove(
109        &self,
110        pk: &DeviceMultiStarkProvingKey<Self::PB>,
111        ctx: ProvingContext<Self::PB>,
112    ) -> Proof<Self::SC> {
113        let mut mem = MemTracker::start("prove");
114        mem.reset_peak();
115
116        let mpk_view = pk.view(ctx.air_ids());
117        #[cfg(feature = "touchemall")]
118        {
119            for (air_id, air_ctx) in ctx.per_air.iter() {
120                check_trace_validity(air_ctx, &pk.per_air[*air_id].air_name);
121            }
122        }
123        let mut prover = self.prover();
124        let proof = prover.prove(mpk_view, ctx);
125        proof.into()
126    }
127}
128
/// Debug check that scans the common main trace of one AIR for cells that look unwritten.
///
/// The trace is copied to the host and indexed as `trace[c * height + r]`, i.e. one column
/// after another. A cell whose raw 32-bit pattern is `0xffffffff` is reported as
/// "potentially untouched".
/// NOTE(review): presumably `0xffffffff` is the fill pattern of freshly allocated device
/// buffers when `touchemall` is enabled — confirm against the allocator.
///
/// AIR contexts without a common main trace are skipped, since there is nothing to inspect.
///
/// # Panics
///
/// Panics if the device-to-host copy fails, or if any cell holds the marker pattern; the
/// message names the cell coordinates, the trace dimensions and `name`.
#[cfg(feature = "touchemall")]
pub fn check_trace_validity(proving_ctx: &AirProvingContext<GpuBackend>, name: &str) {
    use openvm_cuda_common::copy::MemCopyD2H;
    use openvm_stark_backend::prover::hal::MatrixDimensions;

    use crate::types::F;

    // `common_main` is optional; a missing trace is not an error for this check.
    let trace = match proving_ctx.common_main.as_ref() {
        Some(trace) => trace,
        None => return,
    };
    let height = trace.height();
    let width = trace.width();
    let trace = trace
        .to_host()
        .expect("failed to copy the common main trace from device to host");
    for r in 0..height {
        for c in 0..width {
            let value = trace[c * height + r];
            // The raw bits are inspected on purpose: the marker is a bit pattern, not a
            // field value.
            // SAFETY: `transmute` only compiles when `F` and `u32` have the same size, and
            // every bit pattern is a valid `u32`, so reinterpreting the bits is sound.
            let value_u32 = unsafe { std::mem::transmute::<F, u32>(value) };
            assert!(
                value_u32 != 0xffffffff,
                "potentially untouched value at ({r}, {c}) of a trace of size {height}x{width} for air {name}"
            );
        }
    }
}