openvm_cuda_backend/
prover_backend.rs

1use std::iter::zip;
2
3use itertools::{izip, zip_eq, Itertools};
4use openvm_cuda_common::{memory_manager::MemTracker, stream::gpu_metrics_span};
5use openvm_stark_backend::{
6    air_builders::symbolic::SymbolicConstraints,
7    config::{Com, PcsProof, RapPartialProvingKey, RapPhaseSeqPartialProof},
8    keygen::view::MultiStarkVerifyingKeyView,
9    p3_challenger::{DuplexChallenger, FieldChallenger},
10    proof::{OpenedValues, OpeningProof},
11    prover::{
12        hal::{
13            MatrixDimensions, OpeningProver, ProverBackend, ProverDevice, QuotientCommitter,
14            RapPartialProver, TraceCommitter,
15        },
16        types::{
17            AirView, DeviceMultiStarkProvingKeyView, DeviceStarkProvingKey, PairView,
18            ProverDataAfterRapPhases, RapSinglePhaseView, RapView,
19        },
20    },
21};
22use p3_baby_bear::Poseidon2BabyBear;
23use p3_commit::PolynomialSpace;
24use p3_util::log2_strict_usize;
25use tracing::{info_span, instrument};
26
27use crate::{
28    base::DeviceMatrix,
29    gpu_device::GpuDevice,
30    lde::{GpuLde, GpuLdeImpl},
31    merkle_tree::GpuMerkleTree,
32    opener::OpeningProverGpu,
33    prelude::*,
34    quotient::{QuotientCommitterGpu, QuotientDataGpu},
35};
36
/// Field-less marker type that identifies the GPU implementation of the STARK
/// prover backend.
#[derive(Debug, Default, Clone, Copy)]
pub struct GpuBackend {}
40
41impl ProverBackend for GpuBackend {
42    const CHALLENGE_EXT_DEGREE: u8 = 4;
43
44    // Host Types
45    type Val = F;
46    type Challenge = EF;
47    type OpeningProof = OpeningProof<PcsProof<SC>, Self::Challenge>;
48    type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
49    type Commitment = Com<SC>; // From<[BabyBear; DIGEST_WIDTH]>
50    type Challenger = DuplexChallenger<F, Poseidon2BabyBear<WIDTH>, WIDTH, RATE>;
51
52    // Device Types
53    type Matrix = DeviceMatrix<F>;
54    type PcsData = GpuPcsData;
55    type RapPartialProvingKey = RapPartialProvingKey<SC>;
56}
57
58#[derive(Clone)]
59pub struct GpuPcsData {
60    pub data: GpuMerkleTree<GpuLdeImpl>,
61    pub log_trace_heights: Vec<u8>,
62}
63
64impl ProverDevice<GpuBackend> for GpuDevice {}
65
66impl TraceCommitter<GpuBackend> for GpuDevice {
67    #[instrument(level = "debug", skip_all)]
68    fn commit(&self, traces: &[DeviceMatrix<F>]) -> (Com<SC>, GpuPcsData) {
69        let _mem = MemTracker::start("commit");
70        tracing::debug!(
71            "trace (size,strong_count): {:?}",
72            traces
73                .iter()
74                .map(|t| (t.buffer().len(), t.strong_count()))
75                .collect::<Vec<_>>()
76        );
77        let traces_with_shifts = traces
78            .iter()
79            .map(|trace| (trace.clone(), self.config.shift))
80            .collect_vec();
81        // We drop the trace in Lde because `traces` is passed by reference
82        let (log_trace_heights, merkle_tree) =
83            self.commit_traces_with_lde(traces_with_shifts, self.config.fri.log_blowup);
84        let root = merkle_tree.root();
85        let pcs_data = GpuPcsData {
86            data: merkle_tree,
87            log_trace_heights,
88        };
89
90        (root, pcs_data)
91    }
92}
93
94type GB = GpuBackend;
95type GBChallenger = <GB as ProverBackend>::Challenger;
96type GBMatrix = <GB as ProverBackend>::Matrix;
97type GBVal = <GB as ProverBackend>::Val;
98type GBPcsData = <GB as ProverBackend>::PcsData;
99type GBCommitment = <GB as ProverBackend>::Commitment;
100
101impl RapPartialProver<GB> for GpuDevice {
102    #[instrument(skip_all)]
103    fn partially_prove(
104        &self,
105        challenger: &mut GBChallenger,
106        mpk: &DeviceMultiStarkProvingKeyView<'_, GB>,
107        trace_views: Vec<AirView<GBMatrix, GBVal>>,
108    ) -> (
109        Option<RapPhaseSeqPartialProof<SC>>,
110        ProverDataAfterRapPhases<GB>,
111    ) {
112        let mem = MemTracker::start("partially_prove");
113        let num_airs = mpk.per_air.len();
114        assert_eq!(num_airs, trace_views.len());
115
116        let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
117            .per_air
118            .iter()
119            .map(|pk| {
120                (
121                    SymbolicConstraints::from(&pk.vk.symbolic_constraints),
122                    &pk.rap_partial_pk,
123                )
124            })
125            .unzip();
126
127        let trace_views = zip(&mpk.per_air, trace_views)
128            .map(|(pk, v)| PairView {
129                log_trace_height: log2_strict_usize(v.partitioned_main.first().unwrap().height())
130                    as u8,
131                preprocessed: pk.preprocessed_data.as_ref().map(|p| p.trace.clone()), // DeviceMatrix is smart pointer clone for now
132                partitioned_main: v.partitioned_main,
133                public_values: v.public_values,
134            })
135            .collect_vec();
136
137        let (rap_phase_seq_proof, rap_phase_seq_data) =
138            info_span!("generate_perm_trace").in_scope(|| {
139                self.rap_phase_seq()
140                    .partially_prove_gpu(
141                        challenger,
142                        &constraints_per_air.iter().collect_vec(),
143                        &rap_pk_per_air,
144                        trace_views,
145                    )
146                    .map_or((None, None), |(p, d)| (Some(p), Some(d)))
147            });
148        mem.tracing_info("after perm trace generation");
149
150        // Set up for the final output
151        let mvk_view = MultiStarkVerifyingKeyView::new(
152            mpk.per_air.iter().map(|pk| &pk.vk).collect(),
153            mpk.trace_height_constraints,
154            *mpk.vk_pre_hash,
155        );
156
157        let mut perm_matrix_idx = 0usize;
158        let rap_views_per_phase;
159        let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
160            assert_eq!(mvk_view.num_phases(), 1);
161            assert_eq!(
162                mvk_view.num_challenges_in_phase(0),
163                phase_data.challenges.len()
164            );
165            let perm_views = zip_eq(
166                &phase_data.after_challenge_trace_per_air,
167                phase_data.exposed_values_per_air,
168            )
169            .map(|(perm_trace, exposed_values)| {
170                let mut matrix_idx = None;
171                if perm_trace.is_some() {
172                    matrix_idx = Some(perm_matrix_idx);
173                    perm_matrix_idx += 1;
174                }
175                RapSinglePhaseView {
176                    inner: matrix_idx,
177                    challenges: phase_data.challenges.clone(),
178                    exposed_values: exposed_values.unwrap_or_default(),
179                }
180            })
181            .collect_vec();
182            rap_views_per_phase = vec![perm_views]; // 1 challenge phase
183            phase_data.after_challenge_trace_per_air
184        } else {
185            assert_eq!(mvk_view.num_phases(), 0);
186            rap_views_per_phase = vec![];
187            vec![None; num_airs]
188        };
189
190        // Commit to permutation traces: this means only 1 challenge round right now
191        // One shared commit for all permutation traces (done on GPU)
192        let committed_pcs_data_per_phase: Vec<(Com<SC>, GpuPcsData)> =
193            gpu_metrics_span("perm_trace_commit_time_ms", || {
194                let flattened_traces_with_shifts = perm_trace_per_air
195                    .into_iter()
196                    .flatten()
197                    .map(|trace| (trace, self.config.shift))
198                    .collect_vec();
199                // Only commit if there are permutation traces
200                if !flattened_traces_with_shifts.is_empty() {
201                    let (log_trace_heights, merkle_tree) = self.commit_traces_with_lde(
202                        flattened_traces_with_shifts,
203                        self.config.fri.log_blowup,
204                    );
205                    let root = merkle_tree.root();
206                    let pcs_data = GpuPcsData {
207                        data: merkle_tree,
208                        log_trace_heights,
209                    };
210
211                    Some((root, pcs_data))
212                } else {
213                    None
214                }
215            })
216            .unwrap()
217            .into_iter()
218            .collect();
219        let prover_view = ProverDataAfterRapPhases {
220            committed_pcs_data_per_phase,
221            rap_views_per_phase,
222        };
223        (rap_phase_seq_proof, prover_view)
224    }
225}
226
227impl QuotientCommitter<GB> for GpuDevice {
228    #[instrument(skip_all)]
229    fn eval_and_commit_quotient(
230        &self,
231        challenger: &mut GBChallenger,
232        pk_views: &[&DeviceStarkProvingKey<GB>],
233        public_values: &[Vec<GBVal>],
234        cached_pcs_datas_per_air: &[Vec<GBPcsData>],
235        common_main_pcs_data: &GBPcsData,
236        prover_data_after: &ProverDataAfterRapPhases<GB>,
237    ) -> (GBCommitment, GBPcsData) {
238        let mem = MemTracker::start("quotient");
239        let alpha: EF = challenger.sample_ext_element();
240        tracing::debug!("alpha: {alpha:?}");
241        let qc = QuotientCommitterGpu::new(alpha);
242
243        let mut common_main_idx = 0;
244        let per_rap_quotient = gpu_metrics_span("quotient_poly_compute_time_ms", || {
245            izip!(pk_views, cached_pcs_datas_per_air, public_values)
246                .enumerate()
247                .map(|(i, (pk, cached_pcs_datas, pvs))| {
248                    // Prepare extended views(for GPU):
249                    let quotient_degree = pk.vk.quotient_degree;
250                    let log_trace_height = if pk.vk.has_common_main() {
251                        common_main_pcs_data.log_trace_heights[common_main_idx]
252                    } else {
253                        cached_pcs_datas[0].log_trace_heights[0]
254                    };
255                    let trace_domain = self.natural_domain_for_degree(1usize << log_trace_height);
256                    let quotient_domain = trace_domain
257                        .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
258                    tracing::debug!("quotient_domain: {:?}", quotient_domain);
259                    let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
260                        cv.data.data.leaves[cv.matrix_idx as usize].take_lde(quotient_domain.size())
261                    });
262                    let mut partitioned_main: Vec<DeviceMatrix<F>> = cached_pcs_datas
263                        .iter()
264                        .map(|cv| cv.data.leaves[0].take_lde(quotient_domain.size()))
265                        .collect();
266                    if pk.vk.has_common_main() {
267                        partitioned_main.push(
268                            common_main_pcs_data.data.leaves[common_main_idx]
269                                .take_lde(quotient_domain.size()),
270                        );
271                        common_main_idx += 1;
272                    }
273                    let mut per_phase = zip(
274                        &prover_data_after.committed_pcs_data_per_phase,
275                        &prover_data_after.rap_views_per_phase,
276                    )
277                    .map(|((_, pcs_data), rap_views)| -> Option<_> {
278                        let rap_view = rap_views.get(i)?;
279                        let matrix_idx = rap_view.inner?;
280                        let extended_matrix =
281                            pcs_data.data.leaves[matrix_idx].take_lde(quotient_domain.size());
282                        Some(RapSinglePhaseView {
283                            inner: Some(extended_matrix),
284                            challenges: rap_view.challenges.clone(),
285                            exposed_values: rap_view.exposed_values.clone(),
286                        })
287                    })
288                    .collect_vec();
289                    while let Some(last) = per_phase.last() {
290                        if last.is_none() {
291                            per_phase.pop();
292                        } else {
293                            break;
294                        }
295                    }
296                    let per_phase = per_phase
297                        .into_iter()
298                        .map(|v| v.unwrap_or_default())
299                        .collect();
300
301                    // Compute quotient values
302                    let extended_view_gpu = RapView {
303                        log_trace_height,
304                        preprocessed,
305                        partitioned_main,
306                        public_values: pvs.to_vec(),
307                        per_phase,
308                    };
309                    let constraints = &pk.vk.symbolic_constraints;
310                    let quotient_degree = pk.vk.quotient_degree;
311                    qc.single_rap_quotient_values(
312                        self,
313                        constraints,
314                        extended_view_gpu,
315                        quotient_degree,
316                    )
317                })
318                .collect()
319        })
320        .unwrap();
321
322        let quotient_data = QuotientDataGpu {
323            inner: per_rap_quotient,
324        };
325
326        let quotient_values = quotient_data
327            .split()
328            .into_iter()
329            .map(|q| (q.chunk, self.config.shift / q.domain.shift))
330            .collect_vec();
331        mem.tracing_info("before commit");
332
333        // Commit to quotient polynomials. One shared commit for all quotient polynomials
334        gpu_metrics_span("quotient_poly_commit_time_ms", || {
335            let (log_trace_heights, merkle_tree) =
336                self.commit_traces_with_lde(quotient_values, self.config.fri.log_blowup);
337            let root = merkle_tree.root();
338            let pcs_data = GpuPcsData {
339                data: merkle_tree,
340                log_trace_heights,
341            };
342
343            (root, pcs_data)
344        })
345        .unwrap()
346    }
347}
348
349impl OpeningProver<GB> for GpuDevice {
350    #[instrument(skip_all)]
351    fn open(
352        &self,
353        challenger: &mut GBChallenger,
354        preprocessed: Vec<&GBPcsData>,
355        main: Vec<GBPcsData>,
356        after_phase: Vec<GBPcsData>,
357        quotient_data: GBPcsData,
358        quotient_degrees: &[u8],
359    ) -> OpeningProof<PcsProof<SC>, EF> {
360        let zeta: EF = challenger.sample_ext_element();
361        tracing::debug!("zeta: {zeta:?}");
362
363        let domain = |log_height| self.natural_domain_for_degree(1usize << log_height);
364
365        let preprocessed_iter = preprocessed.iter().map(|v| {
366            assert_eq!(v.log_trace_heights.len(), 1);
367            let domain = domain(v.log_trace_heights[0]);
368            (&v.data, vec![domain])
369        });
370        let main_iter = main.iter().map(|v| {
371            let domains = v
372                .log_trace_heights
373                .iter()
374                .copied()
375                .map(domain)
376                .collect_vec();
377            (&v.data, domains)
378        });
379        let after_phase_iter = after_phase.iter().map(|v| {
380            let domains = v
381                .log_trace_heights
382                .iter()
383                .copied()
384                .map(domain)
385                .collect_vec();
386            (&v.data, domains)
387        });
388        let mut rounds = preprocessed_iter
389            .chain(main_iter)
390            .chain(after_phase_iter)
391            .map(|(data, domains)| {
392                let points_per_mat = domains
393                    .iter()
394                    .map(|domain| vec![zeta, domain.next_point(zeta).unwrap()])
395                    .collect_vec();
396                (data, points_per_mat)
397            })
398            .collect_vec();
399        let num_chunks = quotient_degrees.iter().sum::<u8>() as usize;
400        let quotient_opening_points = vec![vec![zeta]; num_chunks];
401        rounds.push((&quotient_data.data, quotient_opening_points));
402
403        let opener = OpeningProverGpu {};
404        let (mut opening_values, opening_proof) =
405            info_span!("OpeningProverGpu::open").in_scope(|| opener.open(self, rounds, challenger));
406
407        // Unflatten opening_values
408        let mut quotient_openings = opening_values.pop().expect("Should have quotient opening");
409
410        let num_after_challenge = after_phase.len();
411        let after_challenge_openings = opening_values
412            .split_off(opening_values.len() - num_after_challenge)
413            .into_iter()
414            .map(|values| opener.collect_trace_openings(values))
415            .collect_vec();
416        assert_eq!(
417            after_challenge_openings.len(),
418            num_after_challenge,
419            "Incorrect number of after challenge trace openings"
420        );
421
422        let main_openings = opening_values
423            .split_off(preprocessed.len())
424            .into_iter()
425            .map(|values| opener.collect_trace_openings(values))
426            .collect_vec();
427        assert_eq!(
428            main_openings.len(),
429            main.len(),
430            "Incorrect number of main trace openings"
431        );
432
433        let preprocessed_openings = opening_values
434            .into_iter()
435            .map(|values| {
436                let mut openings = opener.collect_trace_openings(values);
437                openings
438                    .pop()
439                    .expect("Preprocessed trace should be opened at 1 point")
440            })
441            .collect_vec();
442        assert_eq!(
443            preprocessed_openings.len(),
444            preprocessed.len(),
445            "Incorrect number of preprocessed trace openings"
446        );
447
448        // Unflatten quotient openings
449        let quotient_openings = quotient_degrees
450            .iter()
451            .map(|&chunk_size| {
452                quotient_openings
453                    .drain(..chunk_size as usize)
454                    .map(|mut op| {
455                        op.pop()
456                            .expect("quotient chunk should be opened at 1 point")
457                    })
458                    .collect_vec()
459            })
460            .collect_vec();
461
462        OpeningProof {
463            proof: opening_proof,
464            values: OpenedValues {
465                preprocessed: preprocessed_openings,
466                main: main_openings,
467                after_challenge: after_challenge_openings,
468                quotient: quotient_openings,
469            },
470        }
471    }
472}