openvm_cuda_backend/
engine.rs1use openvm_cuda_common::memory_manager::MemTracker;
2#[cfg(feature = "touchemall")]
3use openvm_stark_backend::prover::types::AirProvingContext;
4use openvm_stark_backend::{
5 config::StarkGenericConfig,
6 proof::Proof,
7 prover::{
8 coordinator::Coordinator,
9 types::{DeviceMultiStarkProvingKey, ProvingContext},
10 Prover,
11 },
12};
13use openvm_stark_sdk::{
14 config::{
15 baby_bear_poseidon2::{config_from_perm, default_perm, BabyBearPoseidon2Config},
16 fri_params::SecurityParameters,
17 log_up_params::log_up_security_params_baby_bear_100_bits,
18 FriParameters,
19 },
20 engine::{StarkEngine, StarkFriEngine},
21};
22use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
23use p3_field::Field;
24
25use crate::{
26 fri_log_up::FriLogUpPhaseGpu,
27 gpu_device::{GpuConfig, GpuDevice},
28 prelude::{SC, WIDTH},
29 prover_backend::GpuBackend,
30};
31
32pub type MultiTraceStarkProverGPU = Coordinator<SC, GpuBackend, GpuDevice>;
33
34pub struct GpuBabyBearPoseidon2Engine {
35 device: GpuDevice,
36 config: BabyBearPoseidon2Config,
37 perm: Poseidon2BabyBear<WIDTH>,
38}
39
40impl StarkFriEngine for GpuBabyBearPoseidon2Engine {
41 fn new(fri_params: FriParameters) -> Self {
42 let perm = default_perm();
43 let log_up_params = log_up_security_params_baby_bear_100_bits();
44 Self {
45 device: GpuDevice::new(
46 GpuConfig::new(fri_params, BabyBear::GENERATOR),
47 Some(FriLogUpPhaseGpu::new(log_up_params.clone())),
48 ),
49 config: config_from_perm(
50 &perm,
51 SecurityParameters {
52 fri_params,
53 log_up_params,
54 },
55 ),
56 perm,
57 }
58 }
59 fn fri_params(&self) -> FriParameters {
60 self.device.config.fri
61 }
62}
63
64impl StarkEngine for GpuBabyBearPoseidon2Engine {
65 type SC = BabyBearPoseidon2Config;
66 type PB = GpuBackend;
67 type PD = GpuDevice;
68
69 fn config(&self) -> &SC {
70 &self.config
71 }
72
73 fn max_constraint_degree(&self) -> Option<usize> {
74 Some(self.device.config.fri.max_constraint_degree())
75 }
76
77 fn new_challenger(&self) -> <SC as StarkGenericConfig>::Challenger {
78 <SC as StarkGenericConfig>::Challenger::new(self.perm.clone())
79 }
80
81 fn device(&self) -> &Self::PD {
82 &self.device
83 }
84
85 fn prover(&self) -> MultiTraceStarkProverGPU {
86 MultiTraceStarkProverGPU::new(
87 GpuBackend::default(),
88 self.device.clone(),
89 self.new_challenger(),
90 )
91 }
92
93 fn prove(
94 &self,
95 pk: &DeviceMultiStarkProvingKey<Self::PB>,
96 ctx: ProvingContext<Self::PB>,
97 ) -> Proof<Self::SC> {
98 let mut mem = MemTracker::start("prove");
99 mem.reset_peak();
100
101 let mpk_view = pk.view(ctx.air_ids());
102 #[cfg(feature = "touchemall")]
103 {
104 for (air_id, air_ctx) in ctx.per_air.iter() {
105 check_trace_validity(air_ctx, &pk.per_air[*air_id].air_name);
106 }
107 }
108 let mut prover = self.prover();
109 let proof = prover.prove(mpk_view, ctx);
110 proof.into()
111 }
112}
113
/// Debug check (compiled only with the `touchemall` feature) that panics if any cell of
/// the AIR's common main trace holds the raw bit pattern `0xffffffff`.
///
/// The panic message calls such a cell "potentially untouched"; presumably trace buffers
/// are pre-filled with `0xff` bytes under this feature so that unwritten cells stand out
/// (NOTE(review): confirm against the allocation code). AIR contexts without a common
/// main trace have nothing to check and are skipped.
///
/// # Panics
/// - if the device-to-host copy of the trace fails;
/// - if a cell equal to `0xffffffff` is found, reporting its `(row, column)`, the trace
///   dimensions and the AIR `name`.
#[cfg(feature = "touchemall")]
pub fn check_trace_validity(proving_ctx: &AirProvingContext<GpuBackend>, name: &str) {
    use openvm_cuda_common::copy::MemCopyD2H;
    use openvm_stark_backend::prover::hal::MatrixDimensions;

    use crate::types::F;

    /// Raw bit pattern treated as "never written".
    const UNTOUCHED: u32 = 0xffff_ffff;
    // The scan reinterprets each `F` as a `u32`; reject any `F` of a different size at
    // compile time.
    const _: () = assert!(std::mem::size_of::<F>() == std::mem::size_of::<u32>());

    // `common_main` is optional; an AIR without one has no common trace to inspect.
    let Some(trace) = proving_ctx.common_main.as_ref() else {
        return;
    };
    let height = trace.height();
    let width = trace.width();
    let host = trace
        .to_host()
        .expect("touchemall: failed to copy common main trace to host");

    // Cells are laid out as `host[c * height + r]` (column-major). Scanning the buffer
    // front to back keeps memory access sequential; (r, c) is recovered only for the
    // offending cell. Slicing to exactly `height * width` checks the same cells as
    // indexing each (r, c) would.
    let cells = &host[..height * width];
    let untouched = cells.iter().position(|value| {
        // SAFETY: `F` and `u32` have the same size (asserted at compile time above) and
        // `value` is a valid reference to an initialized `F`, so reading its 4 bytes as a
        // `u32` is sound; `read_unaligned` removes any reliance on `F`'s alignment.
        let bits = unsafe { std::ptr::read_unaligned((value as *const F).cast::<u32>()) };
        bits == UNTOUCHED
    });

    if let Some(idx) = untouched {
        // `idx < height * width` implies `height > 0`, so the division is safe.
        let (r, c) = (idx % height, idx / height);
        panic!(
            "potentially untouched value at ({r}, {c}) of a trace of size {height}x{width} for air {name}"
        );
    }
}
135}