1use std::iter::zip;
2
3use itertools::{izip, zip_eq, Itertools};
4use openvm_cuda_common::{memory_manager::MemTracker, stream::gpu_metrics_span};
5use openvm_stark_backend::{
6 air_builders::symbolic::SymbolicConstraints,
7 config::{Com, PcsProof, RapPartialProvingKey, RapPhaseSeqPartialProof},
8 keygen::view::MultiStarkVerifyingKeyView,
9 p3_challenger::{DuplexChallenger, FieldChallenger},
10 proof::{OpenedValues, OpeningProof},
11 prover::{
12 hal::{
13 MatrixDimensions, OpeningProver, ProverBackend, ProverDevice, QuotientCommitter,
14 RapPartialProver, TraceCommitter,
15 },
16 types::{
17 AirView, DeviceMultiStarkProvingKeyView, DeviceStarkProvingKey, PairView,
18 ProverDataAfterRapPhases, RapSinglePhaseView, RapView,
19 },
20 },
21};
22use p3_baby_bear::Poseidon2BabyBear;
23use p3_commit::PolynomialSpace;
24use p3_util::log2_strict_usize;
25use tracing::{info_span, instrument};
26
27use crate::{
28 base::DeviceMatrix,
29 gpu_device::GpuDevice,
30 lde::{GpuLde, GpuLdeImpl},
31 merkle_tree::GpuMerkleTree,
32 opener::OpeningProverGpu,
33 prelude::*,
34 quotient::{QuotientCommitterGpu, QuotientDataGpu},
35};
36
/// Marker type selecting the GPU implementation of [`ProverBackend`].
/// Carries no state; all device resources live on [`GpuDevice`].
#[derive(Clone, Copy, Default, Debug)]
pub struct GpuBackend {}
40
impl ProverBackend for GpuBackend {
    /// Degree of the challenge extension field over the base field.
    const CHALLENGE_EXT_DEGREE: u8 = 4;

    // Base field and extension (challenge) field, from the crate prelude.
    type Val = F;
    type Challenge = EF;
    type OpeningProof = OpeningProof<PcsProof<SC>, Self::Challenge>;
    type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
    type Commitment = Com<SC>;
    // Poseidon2-based duplex challenger matching the CPU-side config.
    type Challenger = DuplexChallenger<F, Poseidon2BabyBear<WIDTH>, WIDTH, RATE>;

    // Traces live in device memory; commitments are GPU Merkle trees.
    type Matrix = DeviceMatrix<F>;
    type PcsData = GpuPcsData;
    type RapPartialProvingKey = RapPartialProvingKey<SC>;
}
57
/// PCS commitment data for a batch of GPU-resident traces.
#[derive(Clone)]
pub struct GpuPcsData {
    /// Merkle tree built over the LDEs of the committed traces.
    pub data: GpuMerkleTree<GpuLdeImpl>,
    /// log2 of each committed trace height, in commitment order.
    pub log_trace_heights: Vec<u8>,
}
63
// `ProverDevice` is the umbrella trait; its constituent traits
// (`TraceCommitter`, `RapPartialProver`, `QuotientCommitter`, `OpeningProver`)
// are implemented for `GpuDevice` below, so this impl body is empty.
impl ProverDevice<GpuBackend> for GpuDevice {}
65
66impl TraceCommitter<GpuBackend> for GpuDevice {
67 #[instrument(level = "debug", skip_all)]
68 fn commit(&self, traces: &[DeviceMatrix<F>]) -> (Com<SC>, GpuPcsData) {
69 let _mem = MemTracker::start("commit");
70 tracing::debug!(
71 "trace (size,strong_count): {:?}",
72 traces
73 .iter()
74 .map(|t| (t.buffer().len(), t.strong_count()))
75 .collect::<Vec<_>>()
76 );
77 let traces_with_shifts = traces
78 .iter()
79 .map(|trace| (trace.clone(), self.config.shift))
80 .collect_vec();
81 let (log_trace_heights, merkle_tree) =
83 self.commit_traces_with_lde(traces_with_shifts, self.config.fri.log_blowup);
84 let root = merkle_tree.root();
85 let pcs_data = GpuPcsData {
86 data: merkle_tree,
87 log_trace_heights,
88 };
89
90 (root, pcs_data)
91 }
92}
93
// Shorthand aliases for the GPU backend and its associated types, used in the
// trait impl signatures below.
type GB = GpuBackend;
type GBChallenger = <GB as ProverBackend>::Challenger;
type GBMatrix = <GB as ProverBackend>::Matrix;
type GBVal = <GB as ProverBackend>::Val;
type GBPcsData = <GB as ProverBackend>::PcsData;
type GBCommitment = <GB as ProverBackend>::Commitment;
100
impl RapPartialProver<GB> for GpuDevice {
    /// Runs the RAP challenge phase on GPU: generates the after-challenge
    /// (permutation) traces, commits to them in one batched Merkle tree, and
    /// returns the partial proof plus per-phase views for the quotient step.
    /// Mutates `challenger` by observing/sampling, so call order matters.
    #[instrument(skip_all)]
    fn partially_prove(
        &self,
        challenger: &mut GBChallenger,
        mpk: &DeviceMultiStarkProvingKeyView<'_, GB>,
        trace_views: Vec<AirView<GBMatrix, GBVal>>,
    ) -> (
        Option<RapPhaseSeqPartialProof<SC>>,
        ProverDataAfterRapPhases<GB>,
    ) {
        let mem = MemTracker::start("partially_prove");
        let num_airs = mpk.per_air.len();
        // Exactly one trace view per AIR.
        assert_eq!(num_airs, trace_views.len());

        // Split each AIR's proving key into its symbolic constraints and its
        // RAP partial proving key.
        let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
            .per_air
            .iter()
            .map(|pk| {
                (
                    SymbolicConstraints::from(&pk.vk.symbolic_constraints),
                    &pk.rap_partial_pk,
                )
            })
            .unzip();

        // Attach each AIR's preprocessed trace (if any) to its main partitions.
        let trace_views = zip(&mpk.per_air, trace_views)
            .map(|(pk, v)| PairView {
                // Height taken from the first main partition; must be a power
                // of two (`log2_strict_usize` panics otherwise). Assumes all
                // partitions of one AIR share a height — TODO confirm.
                log_trace_height: log2_strict_usize(v.partitioned_main.first().unwrap().height())
                    as u8,
                preprocessed: pk.preprocessed_data.as_ref().map(|p| p.trace.clone()),
                partitioned_main: v.partitioned_main,
                public_values: v.public_values,
            })
            .collect_vec();

        // Generate the permutation traces; `(None, None)` when the RAP phase
        // sequence produces nothing (no challenge phases).
        let (rap_phase_seq_proof, rap_phase_seq_data) =
            info_span!("generate_perm_trace").in_scope(|| {
                self.rap_phase_seq()
                    .partially_prove_gpu(
                        challenger,
                        &constraints_per_air.iter().collect_vec(),
                        &rap_pk_per_air,
                        trace_views,
                    )
                    .map_or((None, None), |(p, d)| (Some(p), Some(d)))
            });
        mem.tracing_info("after perm trace generation");

        let mvk_view = MultiStarkVerifyingKeyView::new(
            mpk.per_air.iter().map(|pk| &pk.vk).collect(),
            mpk.trace_height_constraints,
            *mpk.vk_pre_hash,
        );

        // Running index into the flattened list of permutation matrices; AIRs
        // without a permutation trace do not consume a slot.
        let mut perm_matrix_idx = 0usize;
        let rap_views_per_phase;
        let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
            // Only a single challenge phase is supported here.
            assert_eq!(mvk_view.num_phases(), 1);
            assert_eq!(
                mvk_view.num_challenges_in_phase(0),
                phase_data.challenges.len()
            );
            let perm_views = zip_eq(
                &phase_data.after_challenge_trace_per_air,
                phase_data.exposed_values_per_air,
            )
            .map(|(perm_trace, exposed_values)| {
                // Assign a matrix index only when this AIR actually has a
                // permutation trace.
                let mut matrix_idx = None;
                if perm_trace.is_some() {
                    matrix_idx = Some(perm_matrix_idx);
                    perm_matrix_idx += 1;
                }
                RapSinglePhaseView {
                    inner: matrix_idx,
                    challenges: phase_data.challenges.clone(),
                    exposed_values: exposed_values.unwrap_or_default(),
                }
            })
            .collect_vec();
            rap_views_per_phase = vec![perm_views];
            phase_data.after_challenge_trace_per_air
        } else {
            // No phase data must mean no challenge phases at all.
            assert_eq!(mvk_view.num_phases(), 0);
            rap_views_per_phase = vec![];
            vec![None; num_airs]
        };

        // Commit all permutation traces in one batched Merkle tree; empty when
        // no AIR produced a permutation trace.
        let committed_pcs_data_per_phase: Vec<(Com<SC>, GpuPcsData)> =
            gpu_metrics_span("perm_trace_commit_time_ms", || {
                let flattened_traces_with_shifts = perm_trace_per_air
                    .into_iter()
                    .flatten()
                    .map(|trace| (trace, self.config.shift))
                    .collect_vec();
                if !flattened_traces_with_shifts.is_empty() {
                    let (log_trace_heights, merkle_tree) = self.commit_traces_with_lde(
                        flattened_traces_with_shifts,
                        self.config.fri.log_blowup,
                    );
                    let root = merkle_tree.root();
                    let pcs_data = GpuPcsData {
                        data: merkle_tree,
                        log_trace_heights,
                    };

                    Some((root, pcs_data))
                } else {
                    None
                }
            })
            .unwrap()
            .into_iter()
            .collect();
        let prover_view = ProverDataAfterRapPhases {
            committed_pcs_data_per_phase,
            rap_views_per_phase,
        };
        (rap_phase_seq_proof, prover_view)
    }
}
226
impl QuotientCommitter<GB> for GpuDevice {
    /// Evaluates each RAP's quotient polynomial on GPU and commits to all
    /// quotient chunks in a single Merkle tree. Samples `alpha` from the
    /// challenger first, so challenger call order matters.
    #[instrument(skip_all)]
    fn eval_and_commit_quotient(
        &self,
        challenger: &mut GBChallenger,
        pk_views: &[&DeviceStarkProvingKey<GB>],
        public_values: &[Vec<GBVal>],
        cached_pcs_datas_per_air: &[Vec<GBPcsData>],
        common_main_pcs_data: &GBPcsData,
        prover_data_after: &ProverDataAfterRapPhases<GB>,
    ) -> (GBCommitment, GBPcsData) {
        let mem = MemTracker::start("quotient");
        // Random challenge used to batch all constraints into one polynomial.
        let alpha: EF = challenger.sample_ext_element();
        tracing::debug!("alpha: {alpha:?}");
        let qc = QuotientCommitterGpu::new(alpha);

        // Running index into the common-main commitment: only AIRs with a
        // common main trace consume a slot, in iteration order.
        let mut common_main_idx = 0;
        let per_rap_quotient = gpu_metrics_span("quotient_poly_compute_time_ms", || {
            izip!(pk_views, cached_pcs_datas_per_air, public_values)
                .enumerate()
                .map(|(i, (pk, cached_pcs_datas, pvs))| {
                    let quotient_degree = pk.vk.quotient_degree;
                    // Trace height from the common main if present, otherwise
                    // from the AIR's first cached commitment.
                    let log_trace_height = if pk.vk.has_common_main() {
                        common_main_pcs_data.log_trace_heights[common_main_idx]
                    } else {
                        cached_pcs_datas[0].log_trace_heights[0]
                    };
                    let trace_domain = self.natural_domain_for_degree(1usize << log_trace_height);
                    // Quotient domain: disjoint coset of size
                    // trace_size * quotient_degree.
                    let quotient_domain = trace_domain
                        .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
                    tracing::debug!("quotient_domain: {:?}", quotient_domain);
                    // `take_lde` presumably extracts the cached LDE matrix at
                    // the quotient-domain size — TODO confirm ownership rules.
                    let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
                        cv.data.data.leaves[cv.matrix_idx as usize].take_lde(quotient_domain.size())
                    });
                    // Cached (per-AIR) main partitions, then the common main.
                    let mut partitioned_main: Vec<DeviceMatrix<F>> = cached_pcs_datas
                        .iter()
                        .map(|cv| cv.data.leaves[0].take_lde(quotient_domain.size()))
                        .collect();
                    if pk.vk.has_common_main() {
                        partitioned_main.push(
                            common_main_pcs_data.data.leaves[common_main_idx]
                                .take_lde(quotient_domain.size()),
                        );
                        common_main_idx += 1;
                    }
                    // Per-phase view for this AIR; `None` when the AIR has no
                    // matrix in that phase.
                    let mut per_phase = zip(
                        &prover_data_after.committed_pcs_data_per_phase,
                        &prover_data_after.rap_views_per_phase,
                    )
                    .map(|((_, pcs_data), rap_views)| -> Option<_> {
                        let rap_view = rap_views.get(i)?;
                        let matrix_idx = rap_view.inner?;
                        let extended_matrix =
                            pcs_data.data.leaves[matrix_idx].take_lde(quotient_domain.size());
                        Some(RapSinglePhaseView {
                            inner: Some(extended_matrix),
                            challenges: rap_view.challenges.clone(),
                            exposed_values: rap_view.exposed_values.clone(),
                        })
                    })
                    .collect_vec();
                    // Trim trailing phases this AIR does not participate in.
                    while let Some(last) = per_phase.last() {
                        if last.is_none() {
                            per_phase.pop();
                        } else {
                            break;
                        }
                    }
                    let per_phase = per_phase
                        .into_iter()
                        .map(|v| v.unwrap_or_default())
                        .collect();

                    let extended_view_gpu = RapView {
                        log_trace_height,
                        preprocessed,
                        partitioned_main,
                        public_values: pvs.to_vec(),
                        per_phase,
                    };
                    let constraints = &pk.vk.symbolic_constraints;
                    let quotient_degree = pk.vk.quotient_degree;
                    qc.single_rap_quotient_values(
                        self,
                        constraints,
                        extended_view_gpu,
                        quotient_degree,
                    )
                })
                .collect()
        })
        .unwrap();

        let quotient_data = QuotientDataGpu {
            inner: per_rap_quotient,
        };

        // Split quotient polynomials into degree-1 chunks; per-chunk shift is
        // adjusted so the subsequent LDE lands on the correct coset.
        let quotient_values = quotient_data
            .split()
            .into_iter()
            .map(|q| (q.chunk, self.config.shift / q.domain.shift))
            .collect_vec();
        mem.tracing_info("before commit");

        // Commit all quotient chunks in one batched Merkle tree.
        gpu_metrics_span("quotient_poly_commit_time_ms", || {
            let (log_trace_heights, merkle_tree) =
                self.commit_traces_with_lde(quotient_values, self.config.fri.log_blowup);
            let root = merkle_tree.root();
            let pcs_data = GpuPcsData {
                data: merkle_tree,
                log_trace_heights,
            };

            (root, pcs_data)
        })
        .unwrap()
    }
}
348
349impl OpeningProver<GB> for GpuDevice {
350 #[instrument(skip_all)]
351 fn open(
352 &self,
353 challenger: &mut GBChallenger,
354 preprocessed: Vec<&GBPcsData>,
355 main: Vec<GBPcsData>,
356 after_phase: Vec<GBPcsData>,
357 quotient_data: GBPcsData,
358 quotient_degrees: &[u8],
359 ) -> OpeningProof<PcsProof<SC>, EF> {
360 let zeta: EF = challenger.sample_ext_element();
361 tracing::debug!("zeta: {zeta:?}");
362
363 let domain = |log_height| self.natural_domain_for_degree(1usize << log_height);
364
365 let preprocessed_iter = preprocessed.iter().map(|v| {
366 assert_eq!(v.log_trace_heights.len(), 1);
367 let domain = domain(v.log_trace_heights[0]);
368 (&v.data, vec![domain])
369 });
370 let main_iter = main.iter().map(|v| {
371 let domains = v
372 .log_trace_heights
373 .iter()
374 .copied()
375 .map(domain)
376 .collect_vec();
377 (&v.data, domains)
378 });
379 let after_phase_iter = after_phase.iter().map(|v| {
380 let domains = v
381 .log_trace_heights
382 .iter()
383 .copied()
384 .map(domain)
385 .collect_vec();
386 (&v.data, domains)
387 });
388 let mut rounds = preprocessed_iter
389 .chain(main_iter)
390 .chain(after_phase_iter)
391 .map(|(data, domains)| {
392 let points_per_mat = domains
393 .iter()
394 .map(|domain| vec![zeta, domain.next_point(zeta).unwrap()])
395 .collect_vec();
396 (data, points_per_mat)
397 })
398 .collect_vec();
399 let num_chunks = quotient_degrees.iter().sum::<u8>() as usize;
400 let quotient_opening_points = vec![vec![zeta]; num_chunks];
401 rounds.push(("ient_data.data, quotient_opening_points));
402
403 let opener = OpeningProverGpu {};
404 let (mut opening_values, opening_proof) =
405 info_span!("OpeningProverGpu::open").in_scope(|| opener.open(self, rounds, challenger));
406
407 let mut quotient_openings = opening_values.pop().expect("Should have quotient opening");
409
410 let num_after_challenge = after_phase.len();
411 let after_challenge_openings = opening_values
412 .split_off(opening_values.len() - num_after_challenge)
413 .into_iter()
414 .map(|values| opener.collect_trace_openings(values))
415 .collect_vec();
416 assert_eq!(
417 after_challenge_openings.len(),
418 num_after_challenge,
419 "Incorrect number of after challenge trace openings"
420 );
421
422 let main_openings = opening_values
423 .split_off(preprocessed.len())
424 .into_iter()
425 .map(|values| opener.collect_trace_openings(values))
426 .collect_vec();
427 assert_eq!(
428 main_openings.len(),
429 main.len(),
430 "Incorrect number of main trace openings"
431 );
432
433 let preprocessed_openings = opening_values
434 .into_iter()
435 .map(|values| {
436 let mut openings = opener.collect_trace_openings(values);
437 openings
438 .pop()
439 .expect("Preprocessed trace should be opened at 1 point")
440 })
441 .collect_vec();
442 assert_eq!(
443 preprocessed_openings.len(),
444 preprocessed.len(),
445 "Incorrect number of preprocessed trace openings"
446 );
447
448 let quotient_openings = quotient_degrees
450 .iter()
451 .map(|&chunk_size| {
452 quotient_openings
453 .drain(..chunk_size as usize)
454 .map(|mut op| {
455 op.pop()
456 .expect("quotient chunk should be opened at 1 point")
457 })
458 .collect_vec()
459 })
460 .collect_vec();
461
462 OpeningProof {
463 proof: opening_proof,
464 values: OpenedValues {
465 preprocessed: preprocessed_openings,
466 main: main_openings,
467 after_challenge: after_challenge_openings,
468 quotient: quotient_openings,
469 },
470 }
471 }
472}