openvm_stark_backend/prover/cpu/mod.rs

1use std::{iter::zip, marker::PhantomData, ops::Deref, sync::Arc};
2
3use derivative::Derivative;
4use itertools::{izip, zip_eq, Itertools};
5use opener::OpeningProver;
6use p3_challenger::FieldChallenger;
7use p3_commit::{Pcs, PolynomialSpace};
8use p3_field::FieldExtensionAlgebra;
9use p3_matrix::{dense::RowMajorMatrix, Matrix};
10use p3_util::log2_strict_usize;
11use quotient::QuotientCommitter;
12
13use super::{
14    hal::{self, DeviceDataTransporter, MatrixDimensions, ProverBackend, ProverDevice},
15    types::{
16        DeviceMultiStarkProvingKey, DeviceStarkProvingKey, PairView, ProverDataAfterRapPhases,
17        RapView, SingleCommitPreimage,
18    },
19};
20use crate::{
21    air_builders::symbolic::SymbolicConstraints,
22    config::{
23        Com, PcsProof, PcsProverData, RapPartialProvingKey, RapPhaseSeqPartialProof,
24        StarkGenericConfig, Val,
25    },
26    interaction::RapPhaseSeq,
27    keygen::types::MultiStarkProvingKey,
28    proof::OpeningProof,
29    prover::{hal::TraceCommitter, types::RapSinglePhaseView},
30    utils::metrics_span,
31};
32
33/// Polynomial opening proofs
34pub mod opener;
35/// Computation of DEEP quotient polynomial and commitment
36pub mod quotient;
37
38/// Proves multiple chips with interactions together.
39/// This prover implementation is specialized for Interactive AIRs.
40pub struct MultiTraceStarkProver<'c, SC: StarkGenericConfig> {
41    pub config: &'c SC,
42}
43
44/// CPU backend using Plonky3 traits.
45#[derive(Derivative)]
46#[derivative(Clone(bound = ""), Copy(bound = ""), Default(bound = ""))]
47pub struct CpuBackend<SC> {
48    phantom: PhantomData<SC>,
49}
50
51#[derive(Derivative, derive_new::new)]
52#[derivative(Clone(bound = ""), Copy(bound = ""))]
53pub struct CpuDevice<'a, SC> {
54    config: &'a SC,
55}
56
57impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
58    const CHALLENGE_EXT_DEGREE: u8 = <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D as u8;
59
60    type Val = Val<SC>;
61    type Challenge = SC::Challenge;
62    type OpeningProof = OpeningProof<PcsProof<SC>, SC::Challenge>;
63    type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
64    type Commitment = Com<SC>;
65    type Challenger = SC::Challenger;
66    type Matrix = Arc<RowMajorMatrix<Val<SC>>>;
67    type PcsData = PcsData<SC>;
68    type RapPartialProvingKey = RapPartialProvingKey<SC>;
69}
70
71#[derive(Derivative)]
72#[derivative(Clone(bound = ""))]
73pub struct PcsData<SC: StarkGenericConfig> {
74    /// The preimage of a single commitment.
75    pub data: Arc<PcsProverData<SC>>,
76    /// A mixed matrix commitment scheme commits to multiple trace matrices within a single commitment.
77    /// This is the ordered list of log2 heights of all committed trace matrices.
78    pub log_trace_heights: Vec<u8>,
79}
80
81impl<T: Send + Sync + Clone> MatrixDimensions for Arc<RowMajorMatrix<T>> {
82    fn height(&self) -> usize {
83        self.deref().height()
84    }
85    fn width(&self) -> usize {
86        self.deref().width()
87    }
88}
89
90impl<SC> CpuDevice<'_, SC> {
91    pub fn config(&self) -> &SC {
92        self.config
93    }
94}
95
96impl<SC: StarkGenericConfig> CpuDevice<'_, SC> {
97    pub fn pcs(&self) -> &SC::Pcs {
98        self.config.pcs()
99    }
100}
101
102impl<SC: StarkGenericConfig> ProverDevice<CpuBackend<SC>> for CpuDevice<'_, SC> {}
103
104impl<SC: StarkGenericConfig> TraceCommitter<CpuBackend<SC>> for CpuDevice<'_, SC> {
105    fn commit(&self, traces: &[Arc<RowMajorMatrix<Val<SC>>>]) -> (Com<SC>, PcsData<SC>) {
106        let pcs = self.pcs();
107        let (log_trace_heights, traces_with_domains): (Vec<_>, Vec<_>) = traces
108            .iter()
109            .map(|matrix| {
110                let height = matrix.height();
111                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
112                // Recomputing the domain is lightweight
113                let domain = pcs.natural_domain_for_degree(height);
114                (log_height, (domain, matrix.as_ref().clone()))
115            })
116            .unzip();
117        let (commit, data) = pcs.commit(traces_with_domains);
118        (
119            commit,
120            PcsData {
121                data: Arc::new(data),
122                log_trace_heights,
123            },
124        )
125    }
126}
127
128impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice<'_, SC> {
129    fn partially_prove<'a>(
130        &self,
131        challenger: &mut SC::Challenger,
132        mpk: &DeviceMultiStarkProvingKey<'a, CpuBackend<SC>>,
133        trace_views: Vec<PairView<&'a Arc<RowMajorMatrix<Val<SC>>>, Val<SC>>>,
134    ) -> (
135        Option<RapPhaseSeqPartialProof<SC>>,
136        ProverDataAfterRapPhases<CpuBackend<SC>>,
137    ) {
138        let num_airs = mpk.per_air.len();
139        assert_eq!(num_airs, trace_views.len());
140
141        let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
142            .per_air
143            .iter()
144            .map(|pk| {
145                (
146                    SymbolicConstraints::from(&pk.vk.symbolic_constraints),
147                    &pk.rap_partial_pk,
148                )
149            })
150            .unzip();
151
152        let trace_views = trace_views
153            .iter()
154            .map(|v| PairView {
155                log_trace_height: v.log_trace_height,
156                preprocessed: v.preprocessed.as_ref().map(|p| p.as_ref()),
157                partitioned_main: v.partitioned_main.iter().map(|m| m.as_ref()).collect(),
158                public_values: v.public_values.clone(),
159            })
160            .collect_vec();
161        let (rap_phase_seq_proof, rap_phase_seq_data) = self
162            .config()
163            .rap_phase_seq()
164            .partially_prove(
165                challenger,
166                &constraints_per_air.iter().collect_vec(),
167                &rap_pk_per_air,
168                &trace_views,
169            )
170            .map_or((None, None), |(p, d)| (Some(p), Some(d)));
171
172        let mvk_view = mpk.vk_view();
173
174        let mut perm_matrix_idx = 0usize;
175        let rap_views_per_phase;
176        let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
177            assert_eq!(mvk_view.num_phases(), 1);
178            assert_eq!(
179                mvk_view.num_challenges_in_phase(0),
180                phase_data.challenges.len()
181            );
182            let perm_views = zip_eq(
183                &phase_data.after_challenge_trace_per_air,
184                phase_data.exposed_values_per_air,
185            )
186            .map(|(perm_trace, exposed_values)| {
187                let mut matrix_idx = None;
188                if perm_trace.is_some() {
189                    matrix_idx = Some(perm_matrix_idx);
190                    perm_matrix_idx += 1;
191                }
192                RapSinglePhaseView {
193                    inner: matrix_idx,
194                    challenges: phase_data.challenges.clone(),
195                    exposed_values: exposed_values.unwrap_or_default(),
196                }
197            })
198            .collect_vec();
199            rap_views_per_phase = vec![perm_views]; // 1 challenge phase
200            phase_data.after_challenge_trace_per_air
201        } else {
202            assert_eq!(mvk_view.num_phases(), 0);
203            rap_views_per_phase = vec![];
204            vec![None; num_airs]
205        };
206
207        // Commit to permutation traces: this means only 1 challenge round right now
208        // One shared commit for all permutation traces
209        let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> =
210            metrics_span("perm_trace_commit_time_ms", || {
211                let flattened_traces: Vec<_> = perm_trace_per_air
212                    .into_iter()
213                    .flat_map(|perm_trace| {
214                        perm_trace.map(|trace| Arc::new(trace.flatten_to_base()))
215                    })
216                    .collect();
217                // Only commit if there are permutation traces
218                if !flattened_traces.is_empty() {
219                    let (commit, data) = self.commit(&flattened_traces);
220                    Some((commit, data))
221                } else {
222                    None
223                }
224            })
225            .into_iter()
226            .collect();
227        let prover_view = ProverDataAfterRapPhases {
228            committed_pcs_data_per_phase,
229            rap_views_per_phase,
230        };
231        (rap_phase_seq_proof, prover_view)
232    }
233}
234
235impl<SC: StarkGenericConfig> hal::QuotientCommitter<CpuBackend<SC>> for CpuDevice<'_, SC> {
236    fn eval_and_commit_quotient(
237        &self,
238        challenger: &mut SC::Challenger,
239        pk_views: &[DeviceStarkProvingKey<CpuBackend<SC>>],
240        public_values: &[Vec<Val<SC>>],
241        cached_views_per_air: &[Vec<
242            SingleCommitPreimage<&Arc<RowMajorMatrix<Val<SC>>>, &PcsData<SC>>,
243        >],
244        common_main_pcs_data: &PcsData<SC>,
245        prover_data_after: &ProverDataAfterRapPhases<CpuBackend<SC>>,
246    ) -> (Com<SC>, PcsData<SC>) {
247        let pcs = self.pcs();
248        // Generate `alpha` challenge
249        let alpha: SC::Challenge = challenger.sample_ext_element();
250        tracing::debug!("alpha: {alpha:?}");
251        // Prepare extended views:
252        let mut common_main_idx = 0;
253        let extended_views = izip!(pk_views, cached_views_per_air, public_values)
254            .enumerate()
255            .map(|(i, (pk, cached_views, pvs))| {
256                let quotient_degree = pk.vk.quotient_degree;
257                let log_trace_height = if pk.vk.has_common_main() {
258                    common_main_pcs_data.log_trace_heights[common_main_idx]
259                } else {
260                    log2_strict_usize(cached_views[0].trace.height()) as u8
261                };
262                let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
263                let quotient_domain = trace_domain
264                    .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
265                // **IMPORTANT**: the return type of `get_evaluations_on_domain` is a matrix view. DO NOT call to_row_major_matrix as this will allocate new memory
266                let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
267                    pcs.get_evaluations_on_domain(
268                        &cv.data.data,
269                        cv.matrix_idx as usize,
270                        quotient_domain,
271                    )
272                });
273                let mut partitioned_main: Vec<_> = cached_views
274                    .iter()
275                    .map(|cv| {
276                        pcs.get_evaluations_on_domain(
277                            &cv.data.data,
278                            cv.matrix_idx as usize,
279                            quotient_domain,
280                        )
281                    })
282                    .collect();
283                if pk.vk.has_common_main() {
284                    partitioned_main.push(pcs.get_evaluations_on_domain(
285                        &common_main_pcs_data.data,
286                        common_main_idx,
287                        quotient_domain,
288                    ));
289                    common_main_idx += 1;
290                }
291                let pair = PairView {
292                    log_trace_height,
293                    preprocessed,
294                    partitioned_main,
295                    public_values: pvs.to_vec(),
296                };
297                let mut per_phase = zip(
298                    &prover_data_after.committed_pcs_data_per_phase,
299                    &prover_data_after.rap_views_per_phase,
300                )
301                .map(|((_, pcs_data), rap_views)| -> Option<_> {
302                    let rap_view = rap_views.get(i)?;
303                    let matrix_idx = rap_view.inner?;
304                    let extended_matrix =
305                        pcs.get_evaluations_on_domain(&pcs_data.data, matrix_idx, quotient_domain);
306                    Some(RapSinglePhaseView {
307                        inner: Some(extended_matrix),
308                        challenges: rap_view.challenges.clone(),
309                        exposed_values: rap_view.exposed_values.clone(),
310                    })
311                })
312                .collect_vec();
313                while let Some(last) = per_phase.last() {
314                    if last.is_none() {
315                        per_phase.pop();
316                    } else {
317                        break;
318                    }
319                }
320                let per_phase = per_phase
321                    .into_iter()
322                    .map(|v| v.unwrap_or_default())
323                    .collect();
324
325                RapView { pair, per_phase }
326            })
327            .collect_vec();
328
329        let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = pk_views
330            .iter()
331            .map(|pk| {
332                (
333                    &pk.vk.symbolic_constraints.constraints,
334                    pk.vk.quotient_degree,
335                )
336            })
337            .unzip();
338        let qc = QuotientCommitter::new(self.pcs(), alpha);
339        let quotient_values = metrics_span("quotient_poly_compute_time_ms", || {
340            qc.quotient_values(&constraints, extended_views, &quotient_degrees)
341        });
342
343        // Commit to quotient polynomials. One shared commit for all quotient polynomials
344        metrics_span("quotient_poly_commit_time_ms", || {
345            qc.commit(quotient_values)
346        })
347    }
348}
349
350impl<SC: StarkGenericConfig> hal::OpeningProver<CpuBackend<SC>> for CpuDevice<'_, SC> {
351    fn open(
352        &self,
353        challenger: &mut SC::Challenger,
354        // For each preprocessed trace commitment, the prover data and
355        // the log height of the matrix, in order
356        preprocessed: Vec<&PcsData<SC>>,
357        // For each main trace commitment, the prover data and
358        // the log height of each matrix, in order
359        // Note: this is all one challenge phase.
360        main: Vec<&PcsData<SC>>,
361        // `after_phase[i]` has shared commitment prover data for all matrices in phase `i + 1`.
362        after_phase: Vec<PcsData<SC>>,
363        // Quotient poly commitment prover data
364        quotient_data: PcsData<SC>,
365        // Quotient degree for each RAP committed in quotient_data, in order
366        quotient_degrees: &[u8],
367    ) -> OpeningProof<PcsProof<SC>, SC::Challenge> {
368        // Draw `zeta` challenge
369        let zeta: SC::Challenge = challenger.sample_ext_element();
370        tracing::debug!("zeta: {zeta:?}");
371
372        let pcs = self.pcs();
373        let domain = |log_height| pcs.natural_domain_for_degree(1usize << log_height);
374        let opener = OpeningProver::<SC>::new(pcs, zeta);
375        let preprocessed = preprocessed
376            .iter()
377            .map(|v| {
378                assert_eq!(v.log_trace_heights.len(), 1);
379                (v.data.as_ref(), domain(v.log_trace_heights[0]))
380            })
381            .collect();
382        let main = main
383            .iter()
384            .map(|v| {
385                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
386                (v.data.as_ref(), domains)
387            })
388            .collect();
389        let after_phase: Vec<_> = after_phase
390            .iter()
391            .map(|v| {
392                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
393                (v.data.as_ref(), domains)
394            })
395            .collect();
396        opener.open(
397            challenger,
398            preprocessed,
399            main,
400            after_phase,
401            &quotient_data.data,
402            quotient_degrees,
403        )
404    }
405}
406
407impl<SC> DeviceDataTransporter<SC, CpuBackend<SC>> for CpuBackend<SC>
408where
409    SC: StarkGenericConfig,
410{
411    fn transport_pk_to_device<'a>(
412        &self,
413        mpk: &'a MultiStarkProvingKey<SC>,
414        air_ids: Vec<usize>,
415    ) -> DeviceMultiStarkProvingKey<'a, CpuBackend<SC>>
416    where
417        SC: 'a,
418    {
419        assert!(
420            air_ids.len() <= mpk.per_air.len(),
421            "filtering more AIRs than available"
422        );
423        let per_air = air_ids
424            .iter()
425            .map(|&air_idx| {
426                let pk = &mpk.per_air[air_idx];
427                let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
428                    let pcs_data_view = PcsData {
429                        data: pd.data.clone(),
430                        log_trace_heights: vec![log2_strict_usize(pd.trace.height()) as u8],
431                    };
432                    SingleCommitPreimage {
433                        trace: pd.trace.clone(),
434                        data: pcs_data_view,
435                        matrix_idx: 0,
436                    }
437                });
438                DeviceStarkProvingKey {
439                    air_name: &pk.air_name,
440                    vk: &pk.vk,
441                    preprocessed_data,
442                    rap_partial_pk: pk.rap_partial_pk.clone(),
443                }
444            })
445            .collect();
446        DeviceMultiStarkProvingKey::new(
447            air_ids,
448            per_air,
449            mpk.trace_height_constraints.clone(),
450            mpk.vk_pre_hash.clone(),
451        )
452    }
453    fn transport_matrix_to_device(
454        &self,
455        matrix: &Arc<RowMajorMatrix<Val<SC>>>,
456    ) -> Arc<RowMajorMatrix<Val<SC>>> {
457        matrix.clone()
458    }
459
460    fn transport_pcs_data_to_device(&self, data: &PcsData<SC>) -> PcsData<SC> {
461        data.clone()
462    }
463}