// openvm_stark_backend/prover/cpu/mod.rs

1use std::{iter::zip, marker::PhantomData, mem::ManuallyDrop, ops::Deref, sync::Arc};
2
3use derivative::Derivative;
4use itertools::{izip, zip_eq, Itertools};
5use opener::OpeningProver;
6use p3_challenger::{FieldChallenger, GrindingChallenger};
7use p3_commit::{Pcs, PolynomialSpace};
8use p3_field::{BasedVectorSpace, ExtensionField, Field};
9use p3_matrix::{dense::RowMajorMatrix, Matrix};
10use p3_util::log2_strict_usize;
11use quotient::QuotientCommitter;
12use tracing::info_span;
13
14use super::{
15    hal::{self, DeviceDataTransporter, MatrixDimensions, ProverBackend, ProverDevice},
16    types::{
17        AirView, DeviceMultiStarkProvingKey, DeviceStarkProvingKey, ProverDataAfterRapPhases,
18        RapView, SingleCommitPreimage,
19    },
20};
21use crate::{
22    air_builders::symbolic::SymbolicConstraints,
23    config::{
24        Com, PcsProverData, RapPartialProvingKey, RapPhaseSeqPartialProof, StarkGenericConfig, Val,
25    },
26    interaction::RapPhaseSeq,
27    keygen::types::MultiStarkProvingKey,
28    proof::OpeningProof,
29    prover::{
30        hal::TraceCommitter,
31        types::{CommittedTraceData, DeviceMultiStarkProvingKeyView, PairView, RapSinglePhaseView},
32    },
33};
34
35/// Polynomial opening proofs
36pub mod opener;
37/// Computation of DEEP quotient polynomial and commitment
38pub mod quotient;
39
/// CPU backend using Plonky3 traits.
///
/// # Safety
/// For performance optimization of extension field operations, we assume that `SC::Challenge` is
/// an extension field of `F = Val<SC>` that is `repr(C)` or `repr(transparent)` with
/// internal memory layout `[F; SC::Challenge::DIMENSION]`.
/// This ensures `SC::Challenge` and `F` have the same alignment and
/// `size_of::<SC::Challenge>() == size_of::<F>() * SC::Challenge::DIMENSION`.
/// We assume that `<SC::Challenge as ExtensionField<F>::as_basis_coefficients_slice` is the same as
/// transmuting `SC::Challenge` to `[F; SC::Challenge::DIMENSION]`.
#[derive(Derivative)]
#[derivative(Clone(bound = ""), Copy(bound = ""), Default(bound = ""))]
pub struct CpuBackend<SC> {
    // Zero-sized marker so the backend can be parameterized by `SC` without storing it.
    phantom: PhantomData<SC>,
}
55
/// CPU proving device driving the Plonky3 PCS; pairs with [`CpuBackend`].
///
/// # Safety
/// See [`CpuBackend`].
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct CpuDevice<SC> {
    // Use Arc to get around Clone-ing SC
    pub config: Arc<SC>,
    /// When committing a matrix, the matrix is cloned into newly allocated memory.
    /// The size of the newly allocated memory will be `matrix.size() << log_blowup_factor`.
    log_blowup_factor: usize,
}
67
impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
    // Extension degree of the challenge field over the base field, truncated to u8.
    const CHALLENGE_EXT_DEGREE: u8 = <SC::Challenge as BasedVectorSpace<Val<SC>>>::DIMENSION as u8;

    type Val = Val<SC>;
    type Challenge = SC::Challenge;
    type OpeningProof = OpeningProof<SC>;
    // `None` when the RAP phase sequence produces no partial proof (no challenge phase).
    type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
    type Commitment = Com<SC>;
    type Challenger = SC::Challenger;
    // Host-memory row-major matrix behind `Arc` so "device transport" is a cheap refcount bump.
    type Matrix = Arc<RowMajorMatrix<Val<SC>>>;
    type PcsData = PcsData<SC>;
    type RapPartialProvingKey = RapPartialProvingKey<SC>;
}
81
/// Prover data for one PCS commitment, together with the shape of what was committed.
/// Cloning is cheap: the underlying prover data is behind an `Arc`.
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct PcsData<SC: StarkGenericConfig> {
    /// The preimage of a single commitment.
    pub data: Arc<PcsProverData<SC>>,
    /// A mixed matrix commitment scheme commits to multiple trace matrices within a single
    /// commitment. This is the ordered list of log2 heights of all committed trace matrices.
    pub log_trace_heights: Vec<u8>,
}
91
92impl<T: Send + Sync + Clone> MatrixDimensions for Arc<RowMajorMatrix<T>> {
93    fn height(&self) -> usize {
94        self.deref().height()
95    }
96    fn width(&self) -> usize {
97        self.deref().width()
98    }
99}
100
101impl<SC> CpuDevice<SC> {
102    pub fn config(&self) -> &SC {
103        &self.config
104    }
105}
106
107impl<SC: StarkGenericConfig> CpuDevice<SC> {
108    pub fn pcs(&self) -> &SC::Pcs {
109        self.config.pcs()
110    }
111}
112
113impl<SC: StarkGenericConfig> ProverDevice<CpuBackend<SC>> for CpuDevice<SC> {}
114
115impl<SC: StarkGenericConfig> TraceCommitter<CpuBackend<SC>> for CpuDevice<SC> {
116    fn commit(&self, traces: &[Arc<RowMajorMatrix<Val<SC>>>]) -> (Com<SC>, PcsData<SC>) {
117        let log_blowup_factor = self.log_blowup_factor;
118        let pcs = self.pcs();
119        let (log_trace_heights, traces_with_domains): (Vec<_>, Vec<_>) = traces
120            .iter()
121            .map(|matrix| {
122                let height = matrix.height();
123                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
124                // Recomputing the domain is lightweight
125                let domain = pcs.natural_domain_for_degree(height);
126                // pcs.commit takes the trace matrix and in the case of FRI, does in-place cosetDFT
127                // which requires resizing to a larger buffer size. Since we are cloning anyways,
128                // we should just allocate the larger size to avoid memory-reallocation
129                // ref: https://github.com/Plonky3/Plonky3/blob/8c8bbb4c17bd2b7ef2404338ab8f9036d5f08337/dft/src/traits.rs#L116
130                let trace_slice = &matrix.as_ref().values;
131                let new_buffer_size = trace_slice
132                    .len()
133                    .checked_shl(log_blowup_factor.try_into().unwrap())
134                    .unwrap();
135                let mut new_buffer = Vec::with_capacity(new_buffer_size);
136                // SAFETY:
137                // - `trace_slice` is allocated for `trace_slice.len() * size_of::<F>` bytes,
138                //   obviously
139                // - we just allocated `new_buffer` for at least `trace_slice.len() * size_of::<F>`
140                //   bytes above (more if there's blowup)
141                // - both are slices of &[F] so alignment is guaranteed
142                // - `new_buffer` is newly allocated so non-overlapping with `trace_slice`
143                unsafe {
144                    std::ptr::copy_nonoverlapping(
145                        trace_slice.as_ptr(),
146                        new_buffer.as_mut_ptr(),
147                        trace_slice.len(),
148                    );
149                    new_buffer.set_len(trace_slice.len());
150                }
151                (
152                    log_height,
153                    (domain, RowMajorMatrix::new(new_buffer, matrix.width)),
154                )
155            })
156            .unzip();
157        let (commit, data) = pcs.commit(traces_with_domains);
158        (
159            commit,
160            PcsData {
161                data: Arc::new(data),
162                log_trace_heights,
163            },
164        )
165    }
166}
167
impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice<SC> {
    /// Runs the RAP challenge phase (currently at most one) and commits all resulting
    /// permutation ("after-challenge") traces in a single shared commitment.
    ///
    /// Returns the optional partial proof produced by the RAP phase sequence together with the
    /// per-phase prover data (phase commitments plus per-AIR RAP views) that later proving
    /// stages consume. Challenger interactions happen inside
    /// `rap_phase_seq().partially_prove`, so call order here is load-bearing.
    fn partially_prove(
        &self,
        challenger: &mut SC::Challenger,
        mpk: &DeviceMultiStarkProvingKeyView<CpuBackend<SC>>,
        trace_views: Vec<AirView<Arc<RowMajorMatrix<Val<SC>>>, Val<SC>>>,
    ) -> (
        Option<RapPhaseSeqPartialProof<SC>>,
        ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) {
        let num_airs = mpk.per_air.len();
        assert_eq!(num_airs, trace_views.len());

        // Split each AIR's proving key into its symbolic constraints and RAP partial key.
        let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
            .per_air
            .iter()
            .map(|pk| {
                (
                    SymbolicConstraints::from(&pk.vk.symbolic_constraints),
                    &pk.rap_partial_pk,
                )
            })
            .unzip();

        // Pair each AIR's partitioned main trace(s) with its optional preprocessed trace.
        // The log height is taken from the first main partition; all partitions of one AIR
        // presumably share a height — TODO confirm against the trace generation invariants.
        let trace_views = zip(&mpk.per_air, trace_views)
            .map(|(pk, v)| PairView {
                log_trace_height: log2_strict_usize(v.partitioned_main.first().unwrap().height())
                    as u8,
                preprocessed: pk.preprocessed_data.as_ref().map(|p| p.trace.clone()), // Arc::clone for now
                partitioned_main: v.partitioned_main,
                public_values: v.public_values,
            })
            .collect_vec();
        // `partially_prove` returns `None` when there is no challenge phase; split the
        // Option<(proof, data)> into two parallel Options.
        let (rap_phase_seq_proof, rap_phase_seq_data) = self
            .config()
            .rap_phase_seq()
            .partially_prove(
                challenger,
                &constraints_per_air.iter().collect_vec(),
                &rap_pk_per_air,
                trace_views,
            )
            .map_or((None, None), |(p, d)| (Some(p), Some(d)));

        let mvk_view = mpk.vk_view();

        // Running index of the next permutation trace inside the shared phase-1 commitment.
        let mut perm_matrix_idx = 0usize;
        let rap_views_per_phase;
        let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
            // The CPU device currently supports exactly one challenge phase.
            assert_eq!(mvk_view.num_phases(), 1);
            assert_eq!(
                mvk_view.num_challenges_in_phase(0),
                phase_data.challenges.len()
            );
            // Build the per-AIR phase-1 view: AIRs without a permutation trace get
            // `inner: None` but still carry the sampled challenges and exposed values.
            let perm_views = zip_eq(
                &phase_data.after_challenge_trace_per_air,
                phase_data.exposed_values_per_air,
            )
            .map(|(perm_trace, exposed_values)| {
                let mut matrix_idx = None;
                if perm_trace.is_some() {
                    matrix_idx = Some(perm_matrix_idx);
                    perm_matrix_idx += 1;
                }
                RapSinglePhaseView {
                    inner: matrix_idx,
                    challenges: phase_data.challenges.clone(),
                    exposed_values: exposed_values.unwrap_or_default(),
                }
            })
            .collect_vec();
            rap_views_per_phase = vec![perm_views]; // 1 challenge phase
            phase_data.after_challenge_trace_per_air
        } else {
            // No phase data: the verifying key must not expect any challenge phase either.
            assert_eq!(mvk_view.num_phases(), 0);
            rap_views_per_phase = vec![];
            vec![None; num_airs]
        };

        // Commit to permutation traces: this means only 1 challenge round right now
        // One shared commit for all permutation traces
        let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> =
            info_span!("perm_trace_commit")
                .in_scope(|| {
                    // Flatten away AIRs without a permutation trace, reinterpret each
                    // extension-field trace as a base-field matrix, and attach its domain.
                    let (log_trace_heights, flattened_traces): (Vec<_>, Vec<_>) =
                        perm_trace_per_air
                            .into_iter()
                            .flatten()
                            .map(|perm_trace| {
                                // SAFETY: `Challenge` is assumed to be extension field of `F`
                                // with memory layout `[F; Challenge::DIMENSION]`
                                let trace = unsafe { transmute_to_base(perm_trace) };
                                let height = trace.height();
                                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
                                let domain = self.pcs().natural_domain_for_degree(height);
                                (log_height, (domain, trace))
                            })
                            .collect();
                    // Only commit if there are permutation traces
                    if !flattened_traces.is_empty() {
                        let (commit, data) = self.pcs().commit(flattened_traces);
                        Some((commit, PcsData::new(Arc::new(data), log_trace_heights)))
                    } else {
                        None
                    }
                })
                // Option -> Vec of 0 or 1 phase commitments.
                .into_iter()
                .collect();
        let prover_view = ProverDataAfterRapPhases {
            committed_pcs_data_per_phase,
            rap_views_per_phase,
        };
        (rap_phase_seq_proof, prover_view)
    }
}
283
impl<SC: StarkGenericConfig> hal::QuotientCommitter<CpuBackend<SC>> for CpuDevice<SC> {
    /// Evaluates each RAP's quotient polynomial on its quotient domain and commits all
    /// quotient chunks in one shared commitment.
    ///
    /// Samples the algebraic-batching challenge `alpha` from `challenger`, assembles
    /// evaluation views of every committed trace (preprocessed, cached and common main,
    /// after-challenge), then delegates evaluation and commitment to [`QuotientCommitter`].
    fn eval_and_commit_quotient(
        &self,
        challenger: &mut SC::Challenger,
        pk_views: &[&DeviceStarkProvingKey<CpuBackend<SC>>],
        public_values: &[Vec<Val<SC>>],
        cached_pcs_datas_per_air: &[Vec<PcsData<SC>>],
        common_main_pcs_data: &PcsData<SC>,
        prover_data_after: &ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) -> (Com<SC>, PcsData<SC>) {
        let pcs = self.pcs();
        // Generate `alpha` challenge for algebraic batching
        let alpha: SC::Challenge = challenger.sample_algebra_element();
        tracing::debug!("alpha: {alpha:?}");
        // Prepare extended views:
        // `common_main_idx` tracks which matrix of the shared common-main commitment belongs
        // to the current AIR; it only advances for AIRs that have a common main trace.
        let mut common_main_idx = 0;
        let extended_views = izip!(pk_views, cached_pcs_datas_per_air, public_values)
            .enumerate()
            .map(|(i, (pk, cached_pcs_datas, pvs))| {
                let quotient_degree = pk.vk.quotient_degree;
                // Read the trace height from whichever commitment this AIR participates in.
                // NOTE(review): `cached_pcs_datas[0]` assumes an AIR without common main has
                // at least one cached commitment — confirm against keygen invariants.
                let log_trace_height = if pk.vk.has_common_main() {
                    common_main_pcs_data.log_trace_heights[common_main_idx]
                } else {
                    cached_pcs_datas[0].log_trace_heights[0]
                };
                let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
                let quotient_domain = trace_domain
                    .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
                // **IMPORTANT**: the return type of `get_evaluations_on_domain` is a matrix view.
                // DO NOT call to_row_major_matrix as this will allocate new memory
                let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
                    pcs.get_evaluations_on_domain(
                        &cv.data.data,
                        cv.matrix_idx as usize,
                        quotient_domain,
                    )
                });
                // Each cached pcs data is commitment of a single matrix, so matrix_idx=0
                let mut partitioned_main: Vec<_> = cached_pcs_datas
                    .iter()
                    .map(|cv| pcs.get_evaluations_on_domain(&cv.data, 0, quotient_domain))
                    .collect();
                if pk.vk.has_common_main() {
                    partitioned_main.push(pcs.get_evaluations_on_domain(
                        &common_main_pcs_data.data,
                        common_main_idx,
                        quotient_domain,
                    ));
                    common_main_idx += 1;
                }
                // For each challenge phase, fetch this AIR's permutation trace evaluations,
                // if the AIR participates in that phase (`None` otherwise).
                let mut per_phase = zip(
                    &prover_data_after.committed_pcs_data_per_phase,
                    &prover_data_after.rap_views_per_phase,
                )
                .map(|((_, pcs_data), rap_views)| -> Option<_> {
                    let rap_view = rap_views.get(i)?;
                    let matrix_idx = rap_view.inner?;
                    let extended_matrix =
                        pcs.get_evaluations_on_domain(&pcs_data.data, matrix_idx, quotient_domain);
                    Some(RapSinglePhaseView {
                        inner: Some(extended_matrix),
                        challenges: rap_view.challenges.clone(),
                        exposed_values: rap_view.exposed_values.clone(),
                    })
                })
                .collect_vec();
                // Trim trailing phases in which this AIR does not participate, then replace
                // any interior `None` with a default (empty) view.
                while let Some(last) = per_phase.last() {
                    if last.is_none() {
                        per_phase.pop();
                    } else {
                        break;
                    }
                }
                let per_phase = per_phase
                    .into_iter()
                    .map(|v| v.unwrap_or_default())
                    .collect();

                RapView {
                    log_trace_height,
                    preprocessed,
                    partitioned_main,
                    public_values: pvs.to_vec(),
                    per_phase,
                }
            })
            .collect_vec();

        // Gather the symbolic constraints and quotient degrees for all AIRs, in order.
        let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = pk_views
            .iter()
            .map(|pk| {
                (
                    &pk.vk.symbolic_constraints.constraints,
                    pk.vk.quotient_degree,
                )
            })
            .unzip();
        let qc = QuotientCommitter::new(self.pcs(), alpha, self.log_blowup_factor);
        let quotient_values = qc.quotient_values(&constraints, extended_views, &quotient_degrees);

        // Commit to quotient polynomials. One shared commit for all quotient polynomials
        qc.commit(quotient_values)
    }
}
388
389impl<SC: StarkGenericConfig> hal::OpeningProver<CpuBackend<SC>> for CpuDevice<SC> {
390    fn open(
391        &self,
392        challenger: &mut SC::Challenger,
393        // For each preprocessed trace commitment, the prover data and
394        // the log height of the matrix, in order
395        preprocessed: Vec<&PcsData<SC>>,
396        // For each main trace commitment, the prover data and
397        // the log height of each matrix, in order
398        // Note: this is all one challenge phase.
399        main: Vec<PcsData<SC>>,
400        // `after_phase[i]` has shared commitment prover data for all matrices in phase `i + 1`.
401        after_phase: Vec<PcsData<SC>>,
402        // Quotient poly commitment prover data
403        quotient_data: PcsData<SC>,
404        // Quotient degree for each RAP committed in quotient_data, in order
405        quotient_degrees: &[u8],
406    ) -> OpeningProof<SC> {
407        // Grind before non-interactive Fiat-Shamir sampling of the out-of-domain point
408        let deep_pow_witness = challenger.grind(self.config().deep_ali_params().deep_pow_bits);
409        // Draw `zeta` challenge
410        let zeta: SC::Challenge = challenger.sample_algebra_element();
411        tracing::debug!("zeta: {zeta:?}");
412
413        let pcs = self.pcs();
414        let domain = |log_height| pcs.natural_domain_for_degree(1usize << log_height);
415        let opener = OpeningProver::<SC>::new(pcs, zeta, deep_pow_witness);
416        let preprocessed = preprocessed
417            .iter()
418            .map(|v| {
419                assert_eq!(v.log_trace_heights.len(), 1);
420                (v.data.as_ref(), domain(v.log_trace_heights[0]))
421            })
422            .collect();
423        let main = main
424            .iter()
425            .map(|v| {
426                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
427                (v.data.as_ref(), domains)
428            })
429            .collect();
430        let after_phase: Vec<_> = after_phase
431            .iter()
432            .map(|v| {
433                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
434                (v.data.as_ref(), domains)
435            })
436            .collect();
437        opener.open(
438            challenger,
439            preprocessed,
440            main,
441            after_phase,
442            &quotient_data.data,
443            quotient_degrees,
444        )
445    }
446}
447
448impl<SC> DeviceDataTransporter<SC, CpuBackend<SC>> for CpuDevice<SC>
449where
450    SC: StarkGenericConfig,
451{
452    fn transport_pk_to_device(
453        &self,
454        mpk: &MultiStarkProvingKey<SC>,
455    ) -> DeviceMultiStarkProvingKey<CpuBackend<SC>> {
456        let per_air = mpk
457            .per_air
458            .iter()
459            .map(|pk| {
460                let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
461                    let pcs_data_view = PcsData {
462                        data: pd.data.clone(),
463                        log_trace_heights: vec![log2_strict_usize(pd.trace.height()) as u8],
464                    };
465                    SingleCommitPreimage {
466                        trace: pd.trace.clone(),
467                        data: pcs_data_view,
468                        matrix_idx: 0,
469                    }
470                });
471                DeviceStarkProvingKey {
472                    air_name: pk.air_name.clone(),
473                    vk: pk.vk.clone(),
474                    preprocessed_data,
475                    rap_partial_pk: pk.rap_partial_pk.clone(),
476                }
477            })
478            .collect();
479        DeviceMultiStarkProvingKey::new(
480            per_air,
481            mpk.trace_height_constraints.clone(),
482            mpk.vk_pre_hash.clone(),
483        )
484    }
485    fn transport_matrix_to_device(
486        &self,
487        matrix: &Arc<RowMajorMatrix<Val<SC>>>,
488    ) -> Arc<RowMajorMatrix<Val<SC>>> {
489        matrix.clone()
490    }
491
492    fn transport_committed_trace_to_device(
493        &self,
494        commitment: Com<SC>,
495        trace: &Arc<RowMajorMatrix<Val<SC>>>,
496        prover_data: &Arc<PcsProverData<SC>>,
497    ) -> CommittedTraceData<CpuBackend<SC>> {
498        let log_trace_height: u8 = log2_strict_usize(trace.height()).try_into().unwrap();
499        let data = PcsData::new(prover_data.clone(), vec![log_trace_height]);
500        CommittedTraceData {
501            commitment,
502            trace: trace.clone(),
503            data,
504        }
505    }
506
507    fn transport_matrix_from_device_to_host(
508        &self,
509        matrix: &Arc<RowMajorMatrix<Val<SC>>>,
510    ) -> Arc<RowMajorMatrix<Val<SC>>> {
511        matrix.clone()
512    }
513}
514
// TODO[jpw]: Avoid using this after switching to new plonky3 commit with <https://github.com/Plonky3/Plonky3/pull/796>
/// Reinterprets a row-major matrix over the extension field `EF` as one over the base field
/// `F`, widening each row by a factor of `EF::DIMENSION` without copying the buffer.
///
/// # Safety
/// Assumes that `EF` is `repr(C)` or `repr(transparent)` with internal memory layout `[F;
/// EF::DIMENSION]`. This ensures `EF` and `F` have the same alignment and `size_of::<EF>() ==
/// size_of::<F>() * EF::DIMENSION`. We assume that `EF::as_basis_coefficients_slice` is the same as
/// transmuting `EF` to `[F; EF::DIMENSION]`.
unsafe fn transmute_to_base<F: Field, EF: ExtensionField<F>>(
    ext_matrix: RowMajorMatrix<EF>,
) -> RowMajorMatrix<F> {
    // Debug-only sanity checks of the layout assumptions stated in the safety contract.
    debug_assert_eq!(align_of::<EF>(), align_of::<F>());
    debug_assert_eq!(size_of::<EF>(), size_of::<F>() * EF::DIMENSION);
    // Each EF column expands into EF::DIMENSION base-field columns.
    let width = ext_matrix.width * EF::DIMENSION;
    // Prevent ptr from deallocating
    let mut values = ManuallyDrop::new(ext_matrix.values);
    let mut len = values.len();
    let mut cap = values.capacity();
    let ptr = values.as_mut_ptr();
    // Rescale length/capacity from EF-element units to F-element units.
    len *= EF::DIMENSION;
    cap *= EF::DIMENSION;
    // SAFETY:
    // - We know that `ptr` is from `Vec` so it is allocated by global allocator,
    // - Based on assumptions, `EF` and `F` have the same alignment
    // - Based on memory layout assumptions, length and capacity is correct
    let base_values = Vec::from_raw_parts(ptr as *mut F, len, cap);
    RowMajorMatrix::new(base_values, width)
}