1use std::{iter::zip, marker::PhantomData, ops::Deref, sync::Arc};
2
3use derivative::Derivative;
4use itertools::{izip, zip_eq, Itertools};
5use opener::OpeningProver;
6use p3_challenger::FieldChallenger;
7use p3_commit::{Pcs, PolynomialSpace};
8use p3_field::FieldExtensionAlgebra;
9use p3_matrix::{dense::RowMajorMatrix, Matrix};
10use p3_util::log2_strict_usize;
11use quotient::QuotientCommitter;
12
13use super::{
14 hal::{self, DeviceDataTransporter, MatrixDimensions, ProverBackend, ProverDevice},
15 types::{
16 DeviceMultiStarkProvingKey, DeviceStarkProvingKey, PairView, ProverDataAfterRapPhases,
17 RapView, SingleCommitPreimage,
18 },
19};
20use crate::{
21 air_builders::symbolic::SymbolicConstraints,
22 config::{
23 Com, PcsProof, PcsProverData, RapPartialProvingKey, RapPhaseSeqPartialProof,
24 StarkGenericConfig, Val,
25 },
26 interaction::RapPhaseSeq,
27 keygen::types::MultiStarkProvingKey,
28 proof::OpeningProof,
29 prover::{hal::TraceCommitter, types::RapSinglePhaseView},
30 utils::metrics_span,
31};
32
33pub mod opener;
35pub mod quotient;
37
38pub struct MultiTraceStarkProver<'c, SC: StarkGenericConfig> {
41 pub config: &'c SC,
42}
43
/// Marker type selecting the CPU prover backend for the STARK configuration `SC`.
///
/// It is zero-sized and only carries `SC` at the type level through `PhantomData`.
/// The concrete types it uses are fixed by its `ProverBackend` impl.
#[derive(Derivative)]
// `bound = ""` makes derivative emit these impls without requiring `SC: Clone`,
// `SC: Copy` or `SC: Default`. The only field is a `PhantomData`.
#[derivative(Clone(bound = ""), Copy(bound = ""), Default(bound = ""))]
pub struct CpuBackend<SC> {
    phantom: PhantomData<SC>,
}
50
/// Prover device for [`CpuBackend`].
///
/// The proving stages (trace commitment, RAP partial proving, quotient evaluation and
/// commitment, and opening) are implemented for it in this module using the configuration
/// it borrows. `derive_new::new` generates the constructor `CpuDevice::new(config)`.
#[derive(Derivative, derive_new::new)]
// `bound = ""` avoids requiring `SC: Clone` / `SC: Copy`. The only field is a shared
// reference, which is always `Copy`.
#[derivative(Clone(bound = ""), Copy(bound = ""))]
pub struct CpuDevice<'a, SC> {
    /// STARK configuration. The PCS and the RAP phase sequence are taken from it.
    config: &'a SC,
}
56
57impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
58 const CHALLENGE_EXT_DEGREE: u8 = <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D as u8;
59
60 type Val = Val<SC>;
61 type Challenge = SC::Challenge;
62 type OpeningProof = OpeningProof<PcsProof<SC>, SC::Challenge>;
63 type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
64 type Commitment = Com<SC>;
65 type Challenger = SC::Challenger;
66 type Matrix = Arc<RowMajorMatrix<Val<SC>>>;
67 type PcsData = PcsData<SC>;
68 type RapPartialProvingKey = RapPartialProvingKey<SC>;
69}
70
/// Prover-side data for one PCS commitment, together with the heights of the matrices it
/// commits to.
#[derive(Derivative)]
// Derive `Clone` without requiring `SC: Clone`.
#[derivative(Clone(bound = ""))]
pub struct PcsData<SC: StarkGenericConfig> {
    /// Prover data returned by the PCS when committing. It is kept behind an `Arc`, so
    /// cloning a `PcsData` shares this data instead of copying it.
    pub data: Arc<PcsProverData<SC>>,
    /// `log2` of the height of each committed matrix, in commitment order.
    /// Position `i` here corresponds to matrix index `i` within `data`.
    pub log_trace_heights: Vec<u8>,
}
80
81impl<T: Send + Sync + Clone> MatrixDimensions for Arc<RowMajorMatrix<T>> {
82 fn height(&self) -> usize {
83 self.deref().height()
84 }
85 fn width(&self) -> usize {
86 self.deref().width()
87 }
88}
89
90impl<SC> CpuDevice<'_, SC> {
91 pub fn config(&self) -> &SC {
92 self.config
93 }
94}
95
96impl<SC: StarkGenericConfig> CpuDevice<'_, SC> {
97 pub fn pcs(&self) -> &SC::Pcs {
98 self.config.pcs()
99 }
100}
101
102impl<SC: StarkGenericConfig> ProverDevice<CpuBackend<SC>> for CpuDevice<'_, SC> {}
103
104impl<SC: StarkGenericConfig> TraceCommitter<CpuBackend<SC>> for CpuDevice<'_, SC> {
105 fn commit(&self, traces: &[Arc<RowMajorMatrix<Val<SC>>>]) -> (Com<SC>, PcsData<SC>) {
106 let pcs = self.pcs();
107 let (log_trace_heights, traces_with_domains): (Vec<_>, Vec<_>) = traces
108 .iter()
109 .map(|matrix| {
110 let height = matrix.height();
111 let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
112 let domain = pcs.natural_domain_for_degree(height);
114 (log_height, (domain, matrix.as_ref().clone()))
115 })
116 .unzip();
117 let (commit, data) = pcs.commit(traces_with_domains);
118 (
119 commit,
120 PcsData {
121 data: Arc::new(data),
122 log_trace_heights,
123 },
124 )
125 }
126}
127
128impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice<'_, SC> {
129 fn partially_prove<'a>(
130 &self,
131 challenger: &mut SC::Challenger,
132 mpk: &DeviceMultiStarkProvingKey<'a, CpuBackend<SC>>,
133 trace_views: Vec<PairView<&'a Arc<RowMajorMatrix<Val<SC>>>, Val<SC>>>,
134 ) -> (
135 Option<RapPhaseSeqPartialProof<SC>>,
136 ProverDataAfterRapPhases<CpuBackend<SC>>,
137 ) {
138 let num_airs = mpk.per_air.len();
139 assert_eq!(num_airs, trace_views.len());
140
141 let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
142 .per_air
143 .iter()
144 .map(|pk| {
145 (
146 SymbolicConstraints::from(&pk.vk.symbolic_constraints),
147 &pk.rap_partial_pk,
148 )
149 })
150 .unzip();
151
152 let trace_views = trace_views
153 .iter()
154 .map(|v| PairView {
155 log_trace_height: v.log_trace_height,
156 preprocessed: v.preprocessed.as_ref().map(|p| p.as_ref()),
157 partitioned_main: v.partitioned_main.iter().map(|m| m.as_ref()).collect(),
158 public_values: v.public_values.clone(),
159 })
160 .collect_vec();
161 let (rap_phase_seq_proof, rap_phase_seq_data) = self
162 .config()
163 .rap_phase_seq()
164 .partially_prove(
165 challenger,
166 &constraints_per_air.iter().collect_vec(),
167 &rap_pk_per_air,
168 &trace_views,
169 )
170 .map_or((None, None), |(p, d)| (Some(p), Some(d)));
171
172 let mvk_view = mpk.vk_view();
173
174 let mut perm_matrix_idx = 0usize;
175 let rap_views_per_phase;
176 let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
177 assert_eq!(mvk_view.num_phases(), 1);
178 assert_eq!(
179 mvk_view.num_challenges_in_phase(0),
180 phase_data.challenges.len()
181 );
182 let perm_views = zip_eq(
183 &phase_data.after_challenge_trace_per_air,
184 phase_data.exposed_values_per_air,
185 )
186 .map(|(perm_trace, exposed_values)| {
187 let mut matrix_idx = None;
188 if perm_trace.is_some() {
189 matrix_idx = Some(perm_matrix_idx);
190 perm_matrix_idx += 1;
191 }
192 RapSinglePhaseView {
193 inner: matrix_idx,
194 challenges: phase_data.challenges.clone(),
195 exposed_values: exposed_values.unwrap_or_default(),
196 }
197 })
198 .collect_vec();
199 rap_views_per_phase = vec![perm_views]; phase_data.after_challenge_trace_per_air
201 } else {
202 assert_eq!(mvk_view.num_phases(), 0);
203 rap_views_per_phase = vec![];
204 vec![None; num_airs]
205 };
206
207 let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> =
210 metrics_span("perm_trace_commit_time_ms", || {
211 let flattened_traces: Vec<_> = perm_trace_per_air
212 .into_iter()
213 .flat_map(|perm_trace| {
214 perm_trace.map(|trace| Arc::new(trace.flatten_to_base()))
215 })
216 .collect();
217 if !flattened_traces.is_empty() {
219 let (commit, data) = self.commit(&flattened_traces);
220 Some((commit, data))
221 } else {
222 None
223 }
224 })
225 .into_iter()
226 .collect();
227 let prover_view = ProverDataAfterRapPhases {
228 committed_pcs_data_per_phase,
229 rap_views_per_phase,
230 };
231 (rap_phase_seq_proof, prover_view)
232 }
233}
234
235impl<SC: StarkGenericConfig> hal::QuotientCommitter<CpuBackend<SC>> for CpuDevice<'_, SC> {
236 fn eval_and_commit_quotient(
237 &self,
238 challenger: &mut SC::Challenger,
239 pk_views: &[DeviceStarkProvingKey<CpuBackend<SC>>],
240 public_values: &[Vec<Val<SC>>],
241 cached_views_per_air: &[Vec<
242 SingleCommitPreimage<&Arc<RowMajorMatrix<Val<SC>>>, &PcsData<SC>>,
243 >],
244 common_main_pcs_data: &PcsData<SC>,
245 prover_data_after: &ProverDataAfterRapPhases<CpuBackend<SC>>,
246 ) -> (Com<SC>, PcsData<SC>) {
247 let pcs = self.pcs();
248 let alpha: SC::Challenge = challenger.sample_ext_element();
250 tracing::debug!("alpha: {alpha:?}");
251 let mut common_main_idx = 0;
253 let extended_views = izip!(pk_views, cached_views_per_air, public_values)
254 .enumerate()
255 .map(|(i, (pk, cached_views, pvs))| {
256 let quotient_degree = pk.vk.quotient_degree;
257 let log_trace_height = if pk.vk.has_common_main() {
258 common_main_pcs_data.log_trace_heights[common_main_idx]
259 } else {
260 log2_strict_usize(cached_views[0].trace.height()) as u8
261 };
262 let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
263 let quotient_domain = trace_domain
264 .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
265 let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
267 pcs.get_evaluations_on_domain(
268 &cv.data.data,
269 cv.matrix_idx as usize,
270 quotient_domain,
271 )
272 });
273 let mut partitioned_main: Vec<_> = cached_views
274 .iter()
275 .map(|cv| {
276 pcs.get_evaluations_on_domain(
277 &cv.data.data,
278 cv.matrix_idx as usize,
279 quotient_domain,
280 )
281 })
282 .collect();
283 if pk.vk.has_common_main() {
284 partitioned_main.push(pcs.get_evaluations_on_domain(
285 &common_main_pcs_data.data,
286 common_main_idx,
287 quotient_domain,
288 ));
289 common_main_idx += 1;
290 }
291 let pair = PairView {
292 log_trace_height,
293 preprocessed,
294 partitioned_main,
295 public_values: pvs.to_vec(),
296 };
297 let mut per_phase = zip(
298 &prover_data_after.committed_pcs_data_per_phase,
299 &prover_data_after.rap_views_per_phase,
300 )
301 .map(|((_, pcs_data), rap_views)| -> Option<_> {
302 let rap_view = rap_views.get(i)?;
303 let matrix_idx = rap_view.inner?;
304 let extended_matrix =
305 pcs.get_evaluations_on_domain(&pcs_data.data, matrix_idx, quotient_domain);
306 Some(RapSinglePhaseView {
307 inner: Some(extended_matrix),
308 challenges: rap_view.challenges.clone(),
309 exposed_values: rap_view.exposed_values.clone(),
310 })
311 })
312 .collect_vec();
313 while let Some(last) = per_phase.last() {
314 if last.is_none() {
315 per_phase.pop();
316 } else {
317 break;
318 }
319 }
320 let per_phase = per_phase
321 .into_iter()
322 .map(|v| v.unwrap_or_default())
323 .collect();
324
325 RapView { pair, per_phase }
326 })
327 .collect_vec();
328
329 let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = pk_views
330 .iter()
331 .map(|pk| {
332 (
333 &pk.vk.symbolic_constraints.constraints,
334 pk.vk.quotient_degree,
335 )
336 })
337 .unzip();
338 let qc = QuotientCommitter::new(self.pcs(), alpha);
339 let quotient_values = metrics_span("quotient_poly_compute_time_ms", || {
340 qc.quotient_values(&constraints, extended_views, "ient_degrees)
341 });
342
343 metrics_span("quotient_poly_commit_time_ms", || {
345 qc.commit(quotient_values)
346 })
347 }
348}
349
350impl<SC: StarkGenericConfig> hal::OpeningProver<CpuBackend<SC>> for CpuDevice<'_, SC> {
351 fn open(
352 &self,
353 challenger: &mut SC::Challenger,
354 preprocessed: Vec<&PcsData<SC>>,
357 main: Vec<&PcsData<SC>>,
361 after_phase: Vec<PcsData<SC>>,
363 quotient_data: PcsData<SC>,
365 quotient_degrees: &[u8],
367 ) -> OpeningProof<PcsProof<SC>, SC::Challenge> {
368 let zeta: SC::Challenge = challenger.sample_ext_element();
370 tracing::debug!("zeta: {zeta:?}");
371
372 let pcs = self.pcs();
373 let domain = |log_height| pcs.natural_domain_for_degree(1usize << log_height);
374 let opener = OpeningProver::<SC>::new(pcs, zeta);
375 let preprocessed = preprocessed
376 .iter()
377 .map(|v| {
378 assert_eq!(v.log_trace_heights.len(), 1);
379 (v.data.as_ref(), domain(v.log_trace_heights[0]))
380 })
381 .collect();
382 let main = main
383 .iter()
384 .map(|v| {
385 let domains = v.log_trace_heights.iter().copied().map(domain).collect();
386 (v.data.as_ref(), domains)
387 })
388 .collect();
389 let after_phase: Vec<_> = after_phase
390 .iter()
391 .map(|v| {
392 let domains = v.log_trace_heights.iter().copied().map(domain).collect();
393 (v.data.as_ref(), domains)
394 })
395 .collect();
396 opener.open(
397 challenger,
398 preprocessed,
399 main,
400 after_phase,
401 "ient_data.data,
402 quotient_degrees,
403 )
404 }
405}
406
407impl<SC> DeviceDataTransporter<SC, CpuBackend<SC>> for CpuBackend<SC>
408where
409 SC: StarkGenericConfig,
410{
411 fn transport_pk_to_device<'a>(
412 &self,
413 mpk: &'a MultiStarkProvingKey<SC>,
414 air_ids: Vec<usize>,
415 ) -> DeviceMultiStarkProvingKey<'a, CpuBackend<SC>>
416 where
417 SC: 'a,
418 {
419 assert!(
420 air_ids.len() <= mpk.per_air.len(),
421 "filtering more AIRs than available"
422 );
423 let per_air = air_ids
424 .iter()
425 .map(|&air_idx| {
426 let pk = &mpk.per_air[air_idx];
427 let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
428 let pcs_data_view = PcsData {
429 data: pd.data.clone(),
430 log_trace_heights: vec![log2_strict_usize(pd.trace.height()) as u8],
431 };
432 SingleCommitPreimage {
433 trace: pd.trace.clone(),
434 data: pcs_data_view,
435 matrix_idx: 0,
436 }
437 });
438 DeviceStarkProvingKey {
439 air_name: &pk.air_name,
440 vk: &pk.vk,
441 preprocessed_data,
442 rap_partial_pk: pk.rap_partial_pk.clone(),
443 }
444 })
445 .collect();
446 DeviceMultiStarkProvingKey::new(
447 air_ids,
448 per_air,
449 mpk.trace_height_constraints.clone(),
450 mpk.vk_pre_hash.clone(),
451 )
452 }
453 fn transport_matrix_to_device(
454 &self,
455 matrix: &Arc<RowMajorMatrix<Val<SC>>>,
456 ) -> Arc<RowMajorMatrix<Val<SC>>> {
457 matrix.clone()
458 }
459
460 fn transport_pcs_data_to_device(&self, data: &PcsData<SC>) -> PcsData<SC> {
461 data.clone()
462 }
463}