openvm_cuda_backend/
engine.rs1use openvm_cuda_common::memory_manager::MemTracker;
2#[cfg(feature = "touchemall")]
3use openvm_stark_backend::prover::types::AirProvingContext;
4use openvm_stark_backend::{
5 config::StarkGenericConfig,
6 keygen::MultiStarkKeygenBuilder,
7 proof::Proof,
8 prover::{
9 coordinator::Coordinator,
10 types::{DeviceMultiStarkProvingKey, ProvingContext},
11 Prover,
12 },
13};
14use openvm_stark_sdk::{
15 config::{
16 baby_bear_poseidon2::{config_from_perm, default_perm, BabyBearPoseidon2Config},
17 fri_params::{
18 SecurityParameters, MAX_BATCH_SIZE_LOG_BLOWUP_1, MAX_BATCH_SIZE_LOG_BLOWUP_2,
19 MAX_NUM_CONSTRAINTS,
20 },
21 FriParameters,
22 },
23 engine::{StarkEngine, StarkFriEngine},
24};
25use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
26use p3_field::Field;
27
28use crate::{
29 fri_log_up::FriLogUpPhaseGpu,
30 gpu_device::{GpuConfig, GpuDevice},
31 prelude::{SC, WIDTH},
32 prover_backend::GpuBackend,
33};
34
/// Proof coordinator specialised to the GPU backend and GPU device.
///
/// This is the prover type returned by `GpuBabyBearPoseidon2Engine::prover`.
pub type MultiTraceStarkProverGPU = Coordinator<SC, GpuBackend, GpuDevice>;
36
/// STARK engine over BabyBear with a Poseidon2 permutation, proving on the GPU.
pub struct GpuBabyBearPoseidon2Engine {
    /// GPU device handle together with its `GpuConfig` (which holds the FRI parameters).
    /// A clone of it is handed to every prover created by this engine.
    device: GpuDevice,
    /// STARK configuration derived from `perm` and the security parameters in `new`.
    config: BabyBearPoseidon2Config,
    /// Poseidon2 permutation; each new challenger is seeded with a clone of it.
    perm: Poseidon2BabyBear<WIDTH>,
}
42
43impl StarkFriEngine for GpuBabyBearPoseidon2Engine {
44 fn new(fri_params: FriParameters) -> Self {
45 let perm = default_perm();
46 let security_params = SecurityParameters::new_baby_bear_100_bits(fri_params);
47 Self {
48 device: GpuDevice::new(
49 GpuConfig::new(
50 fri_params,
51 BabyBear::GENERATOR,
52 security_params.deep_ali_params,
53 ),
54 Some(FriLogUpPhaseGpu::new(security_params.log_up_params.clone())),
55 ),
56 config: config_from_perm(&perm, security_params),
57 perm,
58 }
59 }
60 fn fri_params(&self) -> FriParameters {
61 self.device.config.fri
62 }
63}
64
65impl StarkEngine for GpuBabyBearPoseidon2Engine {
66 type SC = BabyBearPoseidon2Config;
67 type PB = GpuBackend;
68 type PD = GpuDevice;
69
70 fn config(&self) -> &SC {
71 &self.config
72 }
73
74 fn max_constraint_degree(&self) -> Option<usize> {
75 Some(self.device.config.fri.max_constraint_degree())
76 }
77
78 fn new_challenger(&self) -> <SC as StarkGenericConfig>::Challenger {
79 <SC as StarkGenericConfig>::Challenger::new(self.perm.clone())
80 }
81
82 fn device(&self) -> &Self::PD {
83 &self.device
84 }
85
86 fn keygen_builder(&self) -> MultiStarkKeygenBuilder<'_, Self::SC> {
87 let mut builder = MultiStarkKeygenBuilder::new(self.config());
88 builder.set_max_constraint_degree(self.max_constraint_degree().unwrap());
89 let max_batch_size = if self.fri_params().log_blowup == 1 {
90 MAX_BATCH_SIZE_LOG_BLOWUP_1
91 } else {
92 MAX_BATCH_SIZE_LOG_BLOWUP_2
93 };
94 builder.max_batch_size = Some(max_batch_size);
95 builder.max_num_constraints = Some(MAX_NUM_CONSTRAINTS);
96
97 builder
98 }
99
100 fn prover(&self) -> MultiTraceStarkProverGPU {
101 MultiTraceStarkProverGPU::new(
102 GpuBackend::default(),
103 self.device.clone(),
104 self.new_challenger(),
105 )
106 }
107
108 fn prove(
109 &self,
110 pk: &DeviceMultiStarkProvingKey<Self::PB>,
111 ctx: ProvingContext<Self::PB>,
112 ) -> Proof<Self::SC> {
113 let mut mem = MemTracker::start("prove");
114 mem.reset_peak();
115
116 let mpk_view = pk.view(ctx.air_ids());
117 #[cfg(feature = "touchemall")]
118 {
119 for (air_id, air_ctx) in ctx.per_air.iter() {
120 check_trace_validity(air_ctx, &pk.per_air[*air_id].air_name);
121 }
122 }
123 let mut prover = self.prover();
124 let proof = prover.prove(mpk_view, ctx);
125 proof.into()
126 }
127}
128
/// Debug check, compiled only with the `touchemall` feature.
///
/// Scans an AIR's common main trace for cells whose raw 32-bit representation is
/// `0xffffffff`.
///
/// NOTE(review): the sentinel presumably comes from device buffers being pre-filled
/// with `0xff` bytes under this feature. A surviving sentinel would then mark a cell
/// the trace generator never wrote. Confirm against the allocator's fill behaviour.
///
/// AIRs without a common main trace are skipped. Cached main traces are not inspected.
///
/// # Panics
/// - If copying the trace from device to host fails.
/// - On the first sentinel cell found, scanning column by column. The message names
///   the `(row, column)` position, the trace dimensions, and the AIR `name`.
#[cfg(feature = "touchemall")]
pub fn check_trace_validity(proving_ctx: &AirProvingContext<GpuBackend>, name: &str) {
    use openvm_cuda_common::copy::MemCopyD2H;
    use openvm_stark_backend::prover::hal::MatrixDimensions;

    use crate::types::F;

    // `common_main` is optional; an AIR without one has nothing to check here.
    let Some(trace) = proving_ctx.common_main.as_ref() else {
        return;
    };
    let height = trace.height();
    let width = trace.width();
    let host = trace
        .to_host()
        .expect("touchemall: failed to copy trace from device to host");

    // The host buffer is column-major: column `c` occupies `c * height..(c + 1) * height`.
    // Walking one column at a time reads memory contiguously.
    for c in 0..width {
        let column = &host[c * height..(c + 1) * height];
        for (r, value) in column.iter().enumerate() {
            // SAFETY: the field element is reinterpreted as its raw 32-bit storage, not
            // converted to its canonical value, because the sentinel is a raw bit pattern.
            // `transmute` refuses to compile unless `F` and `u32` have the same size.
            // Every 32-bit pattern is a valid `u32`.
            // NOTE(review): assumes `F` is a single-`u32` `#[repr(transparent)]` wrapper
            // (BabyBear's Montgomery representation).
            let value_u32: u32 = unsafe { std::mem::transmute::<F, u32>(*value) };
            assert!(
                value_u32 != 0xffffffff,
                "potentially untouched value at ({r}, {c}) of a trace of size {height}x{width} for air {name}"
            );
        }
    }
}