openvm_stark_backend/prover/cpu/
mod.rs

1use std::{iter::zip, marker::PhantomData, mem::ManuallyDrop, ops::Deref, sync::Arc};
2
3use derivative::Derivative;
4use itertools::{izip, zip_eq, Itertools};
5use opener::OpeningProver;
6use p3_challenger::FieldChallenger;
7use p3_commit::{Pcs, PolynomialSpace};
8use p3_field::{ExtensionField, Field, FieldExtensionAlgebra};
9use p3_matrix::{dense::RowMajorMatrix, Matrix};
10use p3_util::log2_strict_usize;
11use quotient::QuotientCommitter;
12use tracing::info_span;
13
14use super::{
15    hal::{self, DeviceDataTransporter, MatrixDimensions, ProverBackend, ProverDevice},
16    types::{
17        AirView, DeviceMultiStarkProvingKey, DeviceStarkProvingKey, ProverDataAfterRapPhases,
18        RapView, SingleCommitPreimage,
19    },
20};
21use crate::{
22    air_builders::symbolic::SymbolicConstraints,
23    config::{
24        Com, PcsProof, PcsProverData, RapPartialProvingKey, RapPhaseSeqPartialProof,
25        StarkGenericConfig, Val,
26    },
27    interaction::RapPhaseSeq,
28    keygen::types::MultiStarkProvingKey,
29    proof::OpeningProof,
30    prover::{
31        hal::TraceCommitter,
32        types::{CommittedTraceData, DeviceMultiStarkProvingKeyView, PairView, RapSinglePhaseView},
33    },
34};
35
36/// Polynomial opening proofs
37pub mod opener;
38/// Computation of DEEP quotient polynomial and commitment
39pub mod quotient;
40
/// CPU backend using Plonky3 traits.
///
/// # Safety
/// For performance optimization of extension field operations, we assumes that `SC::Challenge` is
/// an extension field of `F = Val<SC>` that is `repr(C)` or `repr(transparent)` with
/// internal memory layout `[F; SC::Challenge::D]`.
/// This ensures `SC::Challenge` and `F` have the same alignment and
/// `size_of::<SC::Challenge>() == size_of::<F>() * SC::Challenge::D`.
/// We assume that `<SC::Challenge as ExtensionField<F>::as_base_slice` is the same as
/// transmuting `SC::Challenge` to `[F; SC::Challenge::D]`.
#[derive(Derivative)]
#[derivative(Clone(bound = ""), Copy(bound = ""), Default(bound = ""))]
pub struct CpuBackend<SC> {
    // Zero-sized: the backend is a type-level tag tying `SC` to the CPU trait
    // impls below; no runtime state is stored.
    phantom: PhantomData<SC>,
}
56
/// Prover device for [`CpuBackend`], holding the STARK configuration.
///
/// # Safety
/// See [`CpuBackend`].
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct CpuDevice<SC> {
    // Use Arc to get around Clone-ing SC
    pub config: Arc<SC>,
    /// When committing a matrix, the matrix is cloned into newly allocated memory.
    /// The size of the newly allocated memory will be `matrix.size() << log_blowup_factor`.
    log_blowup_factor: usize,
}
68
impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
    // Degree of the challenge extension field over the base field `Val<SC>`.
    const CHALLENGE_EXT_DEGREE: u8 = <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D as u8;

    type Val = Val<SC>;
    type Challenge = SC::Challenge;
    type OpeningProof = OpeningProof<PcsProof<SC>, SC::Challenge>;
    type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
    type Commitment = Com<SC>;
    type Challenger = SC::Challenger;
    // Matrices are shared via `Arc`, so "transporting" them is a refcount bump.
    type Matrix = Arc<RowMajorMatrix<Val<SC>>>;
    type PcsData = PcsData<SC>;
    type RapPartialProvingKey = RapPartialProvingKey<SC>;
}
82
/// Prover data for a single PCS commitment on the CPU backend.
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct PcsData<SC: StarkGenericConfig> {
    /// The preimage of a single commitment.
    pub data: Arc<PcsProverData<SC>>,
    /// A mixed matrix commitment scheme commits to multiple trace matrices within a single
    /// commitment. This is the ordered list of log2 heights of all committed trace matrices.
    pub log_trace_heights: Vec<u8>,
}
92
93impl<T: Send + Sync + Clone> MatrixDimensions for Arc<RowMajorMatrix<T>> {
94    fn height(&self) -> usize {
95        self.deref().height()
96    }
97    fn width(&self) -> usize {
98        self.deref().width()
99    }
100}
101
102impl<SC> CpuDevice<SC> {
103    pub fn config(&self) -> &SC {
104        &self.config
105    }
106}
107
108impl<SC: StarkGenericConfig> CpuDevice<SC> {
109    pub fn pcs(&self) -> &SC::Pcs {
110        self.config.pcs()
111    }
112}
113
114impl<SC: StarkGenericConfig> ProverDevice<CpuBackend<SC>> for CpuDevice<SC> {}
115
116impl<SC: StarkGenericConfig> TraceCommitter<CpuBackend<SC>> for CpuDevice<SC> {
117    fn commit(&self, traces: &[Arc<RowMajorMatrix<Val<SC>>>]) -> (Com<SC>, PcsData<SC>) {
118        let log_blowup_factor = self.log_blowup_factor;
119        let pcs = self.pcs();
120        let (log_trace_heights, traces_with_domains): (Vec<_>, Vec<_>) = traces
121            .iter()
122            .map(|matrix| {
123                let height = matrix.height();
124                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
125                // Recomputing the domain is lightweight
126                let domain = pcs.natural_domain_for_degree(height);
127                // pcs.commit takes the trace matrix and in the case of FRI, does in-place cosetDFT
128                // which requires resizing to a larger buffer size. Since we are cloning anyways,
129                // we should just allocate the larger size to avoid memory-reallocation
130                // ref: https://github.com/Plonky3/Plonky3/blob/8c8bbb4c17bd2b7ef2404338ab8f9036d5f08337/dft/src/traits.rs#L116
131                let trace_slice = &matrix.as_ref().values;
132                let new_buffer_size = trace_slice
133                    .len()
134                    .checked_shl(log_blowup_factor.try_into().unwrap())
135                    .unwrap();
136                let mut new_buffer = Vec::with_capacity(new_buffer_size);
137                // SAFETY:
138                // - `trace_slice` is allocated for `trace_slice.len() * size_of::<F>` bytes,
139                //   obviously
140                // - we just allocated `new_buffer` for at least `trace_slice.len() * size_of::<F>`
141                //   bytes above (more if there's blowup)
142                // - both are slices of &[F] so alignment is guaranteed
143                // - `new_buffer` is newly allocated so non-overlapping with `trace_slice`
144                unsafe {
145                    std::ptr::copy_nonoverlapping(
146                        trace_slice.as_ptr(),
147                        new_buffer.as_mut_ptr(),
148                        trace_slice.len(),
149                    );
150                    new_buffer.set_len(trace_slice.len());
151                }
152                (
153                    log_height,
154                    (domain, RowMajorMatrix::new(new_buffer, matrix.width)),
155                )
156            })
157            .unzip();
158        let (commit, data) = pcs.commit(traces_with_domains);
159        (
160            commit,
161            PcsData {
162                data: Arc::new(data),
163                log_trace_heights,
164            },
165        )
166    }
167}
168
impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice<SC> {
    /// Runs the RAP challenge phase(s) on the main traces: delegates to the configured
    /// `RapPhaseSeq` to sample challenges via `challenger` and generate any
    /// after-challenge (permutation) traces, then commits to all of those traces in
    /// one shared commitment.
    ///
    /// Returns the optional partial proof from the RAP phase sequence together with
    /// the committed per-phase data and per-AIR RAP views for later proving stages.
    /// At most one challenge phase is currently supported (asserted below).
    fn partially_prove(
        &self,
        challenger: &mut SC::Challenger,
        mpk: &DeviceMultiStarkProvingKeyView<CpuBackend<SC>>,
        trace_views: Vec<AirView<Arc<RowMajorMatrix<Val<SC>>>, Val<SC>>>,
    ) -> (
        Option<RapPhaseSeqPartialProof<SC>>,
        ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) {
        // Exactly one trace view per AIR is required.
        let num_airs = mpk.per_air.len();
        assert_eq!(num_airs, trace_views.len());

        // Gather, per AIR, the symbolic constraints and the RAP partial proving key.
        let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
            .per_air
            .iter()
            .map(|pk| {
                (
                    SymbolicConstraints::from(&pk.vk.symbolic_constraints),
                    &pk.rap_partial_pk,
                )
            })
            .unzip();

        // Pair each AIR's main traces with its preprocessed trace (if any) and
        // public values. Log height is taken from the first main partition.
        let trace_views = zip(&mpk.per_air, trace_views)
            .map(|(pk, v)| PairView {
                log_trace_height: log2_strict_usize(v.partitioned_main.first().unwrap().height())
                    as u8,
                preprocessed: pk.preprocessed_data.as_ref().map(|p| p.trace.clone()), // Arc::clone for now
                partitioned_main: v.partitioned_main,
                public_values: v.public_values,
            })
            .collect_vec();
        // Delegate the interactive phase; `(None, None)` when the phase sequence
        // produces nothing (presumably no interactions — see `RapPhaseSeq::partially_prove`).
        let (rap_phase_seq_proof, rap_phase_seq_data) = self
            .config()
            .rap_phase_seq()
            .partially_prove(
                challenger,
                &constraints_per_air.iter().collect_vec(),
                &rap_pk_per_air,
                trace_views,
            )
            .map_or((None, None), |(p, d)| (Some(p), Some(d)));

        let mvk_view = mpk.vk_view();

        // Index of the next permutation matrix inside the shared commitment made below.
        let mut perm_matrix_idx = 0usize;
        let rap_views_per_phase;
        let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
            // Only a single challenge phase is supported at the moment.
            assert_eq!(mvk_view.num_phases(), 1);
            assert_eq!(
                mvk_view.num_challenges_in_phase(0),
                phase_data.challenges.len()
            );
            // Build a per-AIR view; AIRs without a permutation trace get `inner: None`
            // and do not consume a matrix index.
            let perm_views = zip_eq(
                &phase_data.after_challenge_trace_per_air,
                phase_data.exposed_values_per_air,
            )
            .map(|(perm_trace, exposed_values)| {
                let mut matrix_idx = None;
                if perm_trace.is_some() {
                    matrix_idx = Some(perm_matrix_idx);
                    perm_matrix_idx += 1;
                }
                RapSinglePhaseView {
                    inner: matrix_idx,
                    challenges: phase_data.challenges.clone(),
                    exposed_values: exposed_values.unwrap_or_default(),
                }
            })
            .collect_vec();
            rap_views_per_phase = vec![perm_views]; // 1 challenge phase
            phase_data.after_challenge_trace_per_air
        } else {
            // No challenge phase: no permutation traces to commit.
            assert_eq!(mvk_view.num_phases(), 0);
            rap_views_per_phase = vec![];
            vec![None; num_airs]
        };

        // Commit to permutation traces: this means only 1 challenge round right now
        // One shared commit for all permutation traces
        // NOTE: the `in_scope` closure returns `Option<(Com, PcsData)>`, so collecting
        // its iterator yields a Vec with 0 or 1 entries (0 when no AIR has a perm trace).
        let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> =
            info_span!("perm_trace_commit")
                .in_scope(|| {
                    let (log_trace_heights, flattened_traces): (Vec<_>, Vec<_>) =
                        perm_trace_per_air
                            .into_iter()
                            .flatten()
                            .map(|perm_trace| {
                                // SAFETY: `Challenge` is assumed to be extension field of `F`
                                // with memory layout `[F; Challenge::D]`
                                let trace = unsafe { transmute_to_base(perm_trace) };
                                let height = trace.height();
                                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
                                let domain = self.pcs().natural_domain_for_degree(height);
                                (log_height, (domain, trace))
                            })
                            .collect();
                    // Only commit if there are permutation traces
                    if !flattened_traces.is_empty() {
                        let (commit, data) = self.pcs().commit(flattened_traces);
                        Some((commit, PcsData::new(Arc::new(data), log_trace_heights)))
                    } else {
                        None
                    }
                })
                .into_iter()
                .collect();
        let prover_view = ProverDataAfterRapPhases {
            committed_pcs_data_per_phase,
            rap_views_per_phase,
        };
        (rap_phase_seq_proof, prover_view)
    }
}
284
impl<SC: StarkGenericConfig> hal::QuotientCommitter<CpuBackend<SC>> for CpuDevice<SC> {
    /// Evaluates the quotient polynomial of every RAP on its quotient domain and
    /// commits to all quotient values in a single shared commitment.
    ///
    /// Samples the constraint-folding challenge `alpha` from `challenger` first,
    /// then assembles, per AIR, the low-degree-extended views of preprocessed,
    /// main, and after-challenge trace matrices needed for quotient evaluation.
    fn eval_and_commit_quotient(
        &self,
        challenger: &mut SC::Challenger,
        pk_views: &[&DeviceStarkProvingKey<CpuBackend<SC>>],
        public_values: &[Vec<Val<SC>>],
        cached_pcs_datas_per_air: &[Vec<PcsData<SC>>],
        common_main_pcs_data: &PcsData<SC>,
        prover_data_after: &ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) -> (Com<SC>, PcsData<SC>) {
        let pcs = self.pcs();
        // Generate `alpha` challenge
        let alpha: SC::Challenge = challenger.sample_ext_element();
        tracing::debug!("alpha: {alpha:?}");
        // Prepare extended views:
        // Tracks this AIR's matrix index within the shared common-main commitment;
        // only AIRs with a common main trace advance it.
        let mut common_main_idx = 0;
        let extended_views = izip!(pk_views, cached_pcs_datas_per_air, public_values)
            .enumerate()
            .map(|(i, (pk, cached_pcs_datas, pvs))| {
                let quotient_degree = pk.vk.quotient_degree;
                // Read the log trace height from whichever commitment holds this
                // AIR's main trace (common main or its first cached commitment).
                let log_trace_height = if pk.vk.has_common_main() {
                    common_main_pcs_data.log_trace_heights[common_main_idx]
                } else {
                    cached_pcs_datas[0].log_trace_heights[0]
                };
                let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
                // Quotient domain is a disjoint domain of size `height * quotient_degree`.
                let quotient_domain = trace_domain
                    .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
                // **IMPORTANT**: the return type of `get_evaluations_on_domain` is a matrix view.
                // DO NOT call to_row_major_matrix as this will allocate new memory
                let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
                    pcs.get_evaluations_on_domain(
                        &cv.data.data,
                        cv.matrix_idx as usize,
                        quotient_domain,
                    )
                });
                // Each cached pcs data is commitment of a single matrix, so matrix_idx=0
                let mut partitioned_main: Vec<_> = cached_pcs_datas
                    .iter()
                    .map(|cv| pcs.get_evaluations_on_domain(&cv.data, 0, quotient_domain))
                    .collect();
                if pk.vk.has_common_main() {
                    partitioned_main.push(pcs.get_evaluations_on_domain(
                        &common_main_pcs_data.data,
                        common_main_idx,
                        quotient_domain,
                    ));
                    common_main_idx += 1;
                }
                // For each challenge phase, extend this AIR's after-challenge matrix
                // onto the quotient domain; `None` when the AIR has no view/matrix
                // in that phase.
                let mut per_phase = zip(
                    &prover_data_after.committed_pcs_data_per_phase,
                    &prover_data_after.rap_views_per_phase,
                )
                .map(|((_, pcs_data), rap_views)| -> Option<_> {
                    let rap_view = rap_views.get(i)?;
                    let matrix_idx = rap_view.inner?;
                    let extended_matrix =
                        pcs.get_evaluations_on_domain(&pcs_data.data, matrix_idx, quotient_domain);
                    Some(RapSinglePhaseView {
                        inner: Some(extended_matrix),
                        challenges: rap_view.challenges.clone(),
                        exposed_values: rap_view.exposed_values.clone(),
                    })
                })
                .collect_vec();
                // Drop trailing phases where this AIR has nothing, then default-fill
                // any interior `None`s so phase indices stay aligned.
                while let Some(last) = per_phase.last() {
                    if last.is_none() {
                        per_phase.pop();
                    } else {
                        break;
                    }
                }
                let per_phase = per_phase
                    .into_iter()
                    .map(|v| v.unwrap_or_default())
                    .collect();

                RapView {
                    log_trace_height,
                    preprocessed,
                    partitioned_main,
                    public_values: pvs.to_vec(),
                    per_phase,
                }
            })
            .collect_vec();

        // Per-AIR constraints and quotient degrees, in the same order as the views.
        let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = pk_views
            .iter()
            .map(|pk| {
                (
                    &pk.vk.symbolic_constraints.constraints,
                    pk.vk.quotient_degree,
                )
            })
            .unzip();
        let qc = QuotientCommitter::new(self.pcs(), alpha, self.log_blowup_factor);
        let quotient_values = qc.quotient_values(&constraints, extended_views, &quotient_degrees);

        // Commit to quotient polynomials. One shared commit for all quotient polynomials
        qc.commit(quotient_values)
    }
}
389
impl<SC: StarkGenericConfig> hal::OpeningProver<CpuBackend<SC>> for CpuDevice<SC> {
    /// Produces the PCS opening proof for all committed polynomials at the
    /// out-of-domain challenge point `zeta` sampled from `challenger`.
    fn open(
        &self,
        challenger: &mut SC::Challenger,
        // For each preprocessed trace commitment, the prover data and
        // the log height of the matrix, in order
        preprocessed: Vec<&PcsData<SC>>,
        // For each main trace commitment, the prover data and
        // the log height of each matrix, in order
        // Note: this is all one challenge phase.
        main: Vec<PcsData<SC>>,
        // `after_phase[i]` has shared commitment prover data for all matrices in phase `i + 1`.
        after_phase: Vec<PcsData<SC>>,
        // Quotient poly commitment prover data
        quotient_data: PcsData<SC>,
        // Quotient degree for each RAP committed in quotient_data, in order
        quotient_degrees: &[u8],
    ) -> OpeningProof<PcsProof<SC>, SC::Challenge> {
        // Draw `zeta` challenge
        let zeta: SC::Challenge = challenger.sample_ext_element();
        tracing::debug!("zeta: {zeta:?}");

        let pcs = self.pcs();
        // Reconstruct a matrix's natural domain from its stored log height.
        let domain = |log_height| pcs.natural_domain_for_degree(1usize << log_height);
        let opener = OpeningProver::<SC>::new(pcs, zeta);
        // Each preprocessed commitment holds exactly one matrix.
        let preprocessed = preprocessed
            .iter()
            .map(|v| {
                assert_eq!(v.log_trace_heights.len(), 1);
                (v.data.as_ref(), domain(v.log_trace_heights[0]))
            })
            .collect();
        // Main and after-phase commitments may hold several matrices each;
        // pair the prover data with the domain of every contained matrix.
        let main = main
            .iter()
            .map(|v| {
                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
                (v.data.as_ref(), domains)
            })
            .collect();
        let after_phase: Vec<_> = after_phase
            .iter()
            .map(|v| {
                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
                (v.data.as_ref(), domains)
            })
            .collect();
        opener.open(
            challenger,
            preprocessed,
            main,
            after_phase,
            &quotient_data.data,
            quotient_degrees,
        )
    }
}
446
447impl<SC> DeviceDataTransporter<SC, CpuBackend<SC>> for CpuDevice<SC>
448where
449    SC: StarkGenericConfig,
450{
451    fn transport_pk_to_device(
452        &self,
453        mpk: &MultiStarkProvingKey<SC>,
454    ) -> DeviceMultiStarkProvingKey<CpuBackend<SC>> {
455        let per_air = mpk
456            .per_air
457            .iter()
458            .map(|pk| {
459                let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
460                    let pcs_data_view = PcsData {
461                        data: pd.data.clone(),
462                        log_trace_heights: vec![log2_strict_usize(pd.trace.height()) as u8],
463                    };
464                    SingleCommitPreimage {
465                        trace: pd.trace.clone(),
466                        data: pcs_data_view,
467                        matrix_idx: 0,
468                    }
469                });
470                DeviceStarkProvingKey {
471                    air_name: pk.air_name.clone(),
472                    vk: pk.vk.clone(),
473                    preprocessed_data,
474                    rap_partial_pk: pk.rap_partial_pk.clone(),
475                }
476            })
477            .collect();
478        DeviceMultiStarkProvingKey::new(
479            per_air,
480            mpk.trace_height_constraints.clone(),
481            mpk.vk_pre_hash.clone(),
482        )
483    }
484    fn transport_matrix_to_device(
485        &self,
486        matrix: &Arc<RowMajorMatrix<Val<SC>>>,
487    ) -> Arc<RowMajorMatrix<Val<SC>>> {
488        matrix.clone()
489    }
490
491    fn transport_committed_trace_to_device(
492        &self,
493        commitment: Com<SC>,
494        trace: &Arc<RowMajorMatrix<Val<SC>>>,
495        prover_data: &Arc<PcsProverData<SC>>,
496    ) -> CommittedTraceData<CpuBackend<SC>> {
497        let log_trace_height: u8 = log2_strict_usize(trace.height()).try_into().unwrap();
498        let data = PcsData::new(prover_data.clone(), vec![log_trace_height]);
499        CommittedTraceData {
500            commitment,
501            trace: trace.clone(),
502            data,
503        }
504    }
505
506    fn transport_matrix_from_device_to_host(
507        &self,
508        matrix: &Arc<RowMajorMatrix<Val<SC>>>,
509    ) -> Arc<RowMajorMatrix<Val<SC>>> {
510        matrix.clone()
511    }
512}
513
// TODO[jpw]: Avoid using this after switching to new plonky3 commit with <https://github.com/Plonky3/Plonky3/pull/796>
/// Reinterprets a matrix of extension-field elements as a matrix of base-field
/// elements without copying: each row widens by a factor of `EF::D`.
///
/// # Safety
/// Assumes that `EF` is `repr(C)` or `repr(transparent)` with internal memory layout `[F; EF::D]`.
/// This ensures `EF` and `F` have the same alignment and `size_of::<EF>() == size_of::<F>() *
/// EF::D`. We assume that `EF::as_base_slice` is the same as transmuting `EF` to `[F; EF::D]`.
unsafe fn transmute_to_base<F: Field, EF: ExtensionField<F>>(
    ext_matrix: RowMajorMatrix<EF>,
) -> RowMajorMatrix<F> {
    // Cheap sanity checks of the layout assumptions (debug builds only).
    debug_assert_eq!(align_of::<EF>(), align_of::<F>());
    debug_assert_eq!(size_of::<EF>(), size_of::<F>() * EF::D);
    let width = ext_matrix.width * EF::D;
    // Prevent ptr from deallocating
    let mut values = ManuallyDrop::new(ext_matrix.values);
    let mut len = values.len();
    let mut cap = values.capacity();
    let ptr = values.as_mut_ptr();
    // Each `EF` element contributes `EF::D` base-field elements.
    len *= EF::D;
    cap *= EF::D;
    // SAFETY:
    // - We know that `ptr` is from `Vec` so it is allocated by global allocator,
    // - Based on assumptions, `EF` and `F` have the same alignment
    // - Based on memory layout assumptions, length and capacity is correct
    let base_values = Vec::from_raw_parts(ptr as *mut F, len, cap);
    RowMajorMatrix::new(base_values, width)
}