1use std::{iter::zip, marker::PhantomData, mem::ManuallyDrop, ops::Deref, sync::Arc};
2
3use derivative::Derivative;
4use itertools::{izip, zip_eq, Itertools};
5use opener::OpeningProver;
6use p3_challenger::{FieldChallenger, GrindingChallenger};
7use p3_commit::{Pcs, PolynomialSpace};
8use p3_field::{BasedVectorSpace, ExtensionField, Field};
9use p3_matrix::{dense::RowMajorMatrix, Matrix};
10use p3_util::log2_strict_usize;
11use quotient::QuotientCommitter;
12use tracing::info_span;
13
14use super::{
15 hal::{self, DeviceDataTransporter, MatrixDimensions, ProverBackend, ProverDevice},
16 types::{
17 AirView, DeviceMultiStarkProvingKey, DeviceStarkProvingKey, ProverDataAfterRapPhases,
18 RapView, SingleCommitPreimage,
19 },
20};
21use crate::{
22 air_builders::symbolic::SymbolicConstraints,
23 config::{
24 Com, PcsProverData, RapPartialProvingKey, RapPhaseSeqPartialProof, StarkGenericConfig, Val,
25 },
26 interaction::RapPhaseSeq,
27 keygen::types::MultiStarkProvingKey,
28 proof::OpeningProof,
29 prover::{
30 hal::TraceCommitter,
31 types::{CommittedTraceData, DeviceMultiStarkProvingKeyView, PairView, RapSinglePhaseView},
32 },
33};
34
/// Polynomial opening proof construction (see [`OpeningProver`]).
pub mod opener;
/// Quotient polynomial evaluation and commitment (see [`QuotientCommitter`]).
pub mod quotient;
39
/// CPU implementation of [`ProverBackend`] for the STARK config `SC`.
///
/// Zero-sized marker type: all prover data lives in host memory, so the
/// backend itself carries no state (`PhantomData` only pins the `SC` type).
#[derive(Derivative)]
#[derivative(Clone(bound = ""), Copy(bound = ""), Default(bound = ""))]
pub struct CpuBackend<SC> {
    phantom: PhantomData<SC>,
}
55
/// CPU prover "device": a thin wrapper around the shared STARK configuration.
///
/// `derive_new` generates `CpuDevice::new(config, log_blowup_factor)`;
/// do not reorder the fields, as that would change the constructor signature.
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct CpuDevice<SC> {
    pub config: Arc<SC>,
    // log2 of the PCS blowup factor; used in `commit` to pre-reserve LDE capacity.
    log_blowup_factor: usize,
}
67
impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
    // Extension degree of the challenge field over the base field `Val<SC>`.
    const CHALLENGE_EXT_DEGREE: u8 = <SC::Challenge as BasedVectorSpace<Val<SC>>>::DIMENSION as u8;

    type Val = Val<SC>;
    type Challenge = SC::Challenge;
    type OpeningProof = OpeningProof<SC>;
    // `None` when the RAP phase sequence produces no partial proof.
    type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
    type Commitment = Com<SC>;
    type Challenger = SC::Challenger;
    // Host-memory row-major matrices, shared via `Arc` (no copies on "transport").
    type Matrix = Arc<RowMajorMatrix<Val<SC>>>;
    type PcsData = PcsData<SC>;
    type RapPartialProvingKey = RapPartialProvingKey<SC>;
}
81
/// Prover data of a single PCS commitment, together with the log2 heights of
/// the matrices it commits to (one entry per matrix, in commitment order).
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct PcsData<SC: StarkGenericConfig> {
    // Shared PCS prover data for the commitment.
    pub data: Arc<PcsProverData<SC>>,
    // log2 of each committed matrix's height; heights are powers of two.
    pub log_trace_heights: Vec<u8>,
}
91
92impl<T: Send + Sync + Clone> MatrixDimensions for Arc<RowMajorMatrix<T>> {
93 fn height(&self) -> usize {
94 self.deref().height()
95 }
96 fn width(&self) -> usize {
97 self.deref().width()
98 }
99}
100
101impl<SC> CpuDevice<SC> {
102 pub fn config(&self) -> &SC {
103 &self.config
104 }
105}
106
107impl<SC: StarkGenericConfig> CpuDevice<SC> {
108 pub fn pcs(&self) -> &SC::Pcs {
109 self.config.pcs()
110 }
111}
112
// Marker impl: `ProverDevice` bundles the trait impls below (trace committing,
// RAP partial proving, quotient committing, opening) — presumably a supertrait
// combination declared in `hal`; see that module for the exact bounds.
impl<SC: StarkGenericConfig> ProverDevice<CpuBackend<SC>> for CpuDevice<SC> {}
114
115impl<SC: StarkGenericConfig> TraceCommitter<CpuBackend<SC>> for CpuDevice<SC> {
116 fn commit(&self, traces: &[Arc<RowMajorMatrix<Val<SC>>>]) -> (Com<SC>, PcsData<SC>) {
117 let log_blowup_factor = self.log_blowup_factor;
118 let pcs = self.pcs();
119 let (log_trace_heights, traces_with_domains): (Vec<_>, Vec<_>) = traces
120 .iter()
121 .map(|matrix| {
122 let height = matrix.height();
123 let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
124 let domain = pcs.natural_domain_for_degree(height);
126 let trace_slice = &matrix.as_ref().values;
131 let new_buffer_size = trace_slice
132 .len()
133 .checked_shl(log_blowup_factor.try_into().unwrap())
134 .unwrap();
135 let mut new_buffer = Vec::with_capacity(new_buffer_size);
136 unsafe {
144 std::ptr::copy_nonoverlapping(
145 trace_slice.as_ptr(),
146 new_buffer.as_mut_ptr(),
147 trace_slice.len(),
148 );
149 new_buffer.set_len(trace_slice.len());
150 }
151 (
152 log_height,
153 (domain, RowMajorMatrix::new(new_buffer, matrix.width)),
154 )
155 })
156 .unzip();
157 let (commit, data) = pcs.commit(traces_with_domains);
158 (
159 commit,
160 PcsData {
161 data: Arc::new(data),
162 log_trace_heights,
163 },
164 )
165 }
166}
167
impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice<SC> {
    /// Runs the partial proving of the RAP phase sequence (challenge-dependent
    /// traces, e.g. a permutation/interaction argument) and commits to all
    /// after-challenge traces in a single PCS commitment.
    ///
    /// Returns the optional partial proof together with per-phase prover data
    /// (commitment + per-AIR views) consumed by the later proving stages.
    fn partially_prove(
        &self,
        challenger: &mut SC::Challenger,
        mpk: &DeviceMultiStarkProvingKeyView<CpuBackend<SC>>,
        trace_views: Vec<AirView<Arc<RowMajorMatrix<Val<SC>>>, Val<SC>>>,
    ) -> (
        Option<RapPhaseSeqPartialProof<SC>>,
        ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) {
        let num_airs = mpk.per_air.len();
        // Exactly one trace view per AIR.
        assert_eq!(num_airs, trace_views.len());

        let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
            .per_air
            .iter()
            .map(|pk| {
                (
                    SymbolicConstraints::from(&pk.vk.symbolic_constraints),
                    &pk.rap_partial_pk,
                )
            })
            .unzip();

        // Pair each AIR's preprocessed/main traces and public values for the RAP phase.
        let trace_views = zip(&mpk.per_air, trace_views)
            .map(|(pk, v)| PairView {
                // Height of the first main partition; all heights are powers of two.
                log_trace_height: log2_strict_usize(v.partitioned_main.first().unwrap().height())
                    as u8,
                preprocessed: pk.preprocessed_data.as_ref().map(|p| p.trace.clone()), partitioned_main: v.partitioned_main,
                public_values: v.public_values,
            })
            .collect_vec();
        // The phase sequence may produce nothing to prove; split the returned
        // `Option<(proof, data)>` into two parallel `Option`s.
        let (rap_phase_seq_proof, rap_phase_seq_data) = self
            .config()
            .rap_phase_seq()
            .partially_prove(
                challenger,
                &constraints_per_air.iter().collect_vec(),
                &rap_pk_per_air,
                trace_views,
            )
            .map_or((None, None), |(p, d)| (Some(p), Some(d)));

        let mvk_view = mpk.vk_view();

        // Running index into the flattened list of permutation matrices committed below.
        let mut perm_matrix_idx = 0usize;
        let rap_views_per_phase;
        let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
            // Only a single RAP phase is currently supported.
            assert_eq!(mvk_view.num_phases(), 1);
            assert_eq!(
                mvk_view.num_challenges_in_phase(0),
                phase_data.challenges.len()
            );
            let perm_views = zip_eq(
                &phase_data.after_challenge_trace_per_air,
                phase_data.exposed_values_per_air,
            )
            .map(|(perm_trace, exposed_values)| {
                // AIRs without a permutation trace keep `inner: None`; others
                // record their position in the flattened commitment.
                let mut matrix_idx = None;
                if perm_trace.is_some() {
                    matrix_idx = Some(perm_matrix_idx);
                    perm_matrix_idx += 1;
                }
                RapSinglePhaseView {
                    inner: matrix_idx,
                    challenges: phase_data.challenges.clone(),
                    exposed_values: exposed_values.unwrap_or_default(),
                }
            })
            .collect_vec();
            rap_views_per_phase = vec![perm_views]; phase_data.after_challenge_trace_per_air
        } else {
            // No RAP phases at all: nothing to view, nothing to commit.
            assert_eq!(mvk_view.num_phases(), 0);
            rap_views_per_phase = vec![];
            vec![None; num_airs]
        };

        // Commit to all permutation traces (if any) in one PCS commitment.
        // `Option::into_iter().collect()` yields an empty Vec when there were none.
        let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> =
            info_span!("perm_trace_commit")
                .in_scope(|| {
                    let (log_trace_heights, flattened_traces): (Vec<_>, Vec<_>) =
                        perm_trace_per_air
                            .into_iter()
                            .flatten()
                            .map(|perm_trace| {
                                // SAFETY: reinterprets extension-field elements as
                                // `EF::DIMENSION` base-field columns; the layout
                                // preconditions are debug-asserted in `transmute_to_base`.
                                let trace = unsafe { transmute_to_base(perm_trace) };
                                let height = trace.height();
                                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
                                let domain = self.pcs().natural_domain_for_degree(height);
                                (log_height, (domain, trace))
                            })
                            .collect();
                    if !flattened_traces.is_empty() {
                        let (commit, data) = self.pcs().commit(flattened_traces);
                        Some((commit, PcsData::new(Arc::new(data), log_trace_heights)))
                    } else {
                        None
                    }
                })
                .into_iter()
                .collect();
        let prover_view = ProverDataAfterRapPhases {
            committed_pcs_data_per_phase,
            rap_views_per_phase,
        };
        (rap_phase_seq_proof, prover_view)
    }
}
283
impl<SC: StarkGenericConfig> hal::QuotientCommitter<CpuBackend<SC>> for CpuDevice<SC> {
    /// Evaluates each AIR's constraints on its quotient domain, folds them with
    /// the challenge `alpha`, and commits to all quotient values in one PCS
    /// commitment.
    fn eval_and_commit_quotient(
        &self,
        challenger: &mut SC::Challenger,
        pk_views: &[&DeviceStarkProvingKey<CpuBackend<SC>>],
        public_values: &[Vec<Val<SC>>],
        cached_pcs_datas_per_air: &[Vec<PcsData<SC>>],
        common_main_pcs_data: &PcsData<SC>,
        prover_data_after: &ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) -> (Com<SC>, PcsData<SC>) {
        let pcs = self.pcs();
        // Random challenge for the linear combination of all constraints.
        let alpha: SC::Challenge = challenger.sample_algebra_element();
        tracing::debug!("alpha: {alpha:?}");
        // Position of the current AIR's matrix inside the common-main commitment;
        // advanced only for AIRs that actually have a common main trace.
        let mut common_main_idx = 0;
        let extended_views = izip!(pk_views, cached_pcs_datas_per_air, public_values)
            .enumerate()
            .map(|(i, (pk, cached_pcs_datas, pvs))| {
                let quotient_degree = pk.vk.quotient_degree;
                // Trace height comes from the common-main commitment if present,
                // otherwise from the AIR's first cached (single-matrix) commitment.
                let log_trace_height = if pk.vk.has_common_main() {
                    common_main_pcs_data.log_trace_heights[common_main_idx]
                } else {
                    cached_pcs_datas[0].log_trace_heights[0]
                };
                let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
                // All matrices are re-evaluated on a disjoint domain of size
                // `height * quotient_degree` for quotient computation.
                let quotient_domain = trace_domain
                    .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
                let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
                    pcs.get_evaluations_on_domain(
                        &cv.data.data,
                        cv.matrix_idx as usize,
                        quotient_domain,
                    )
                });
                // Each cached commitment holds exactly one matrix (index 0).
                let mut partitioned_main: Vec<_> = cached_pcs_datas
                    .iter()
                    .map(|cv| pcs.get_evaluations_on_domain(&cv.data, 0, quotient_domain))
                    .collect();
                if pk.vk.has_common_main() {
                    partitioned_main.push(pcs.get_evaluations_on_domain(
                        &common_main_pcs_data.data,
                        common_main_idx,
                        quotient_domain,
                    ));
                    common_main_idx += 1;
                }
                // Per-phase view for this AIR; `None` when the AIR has no matrix
                // in that phase (no view entry or `inner` unset).
                let mut per_phase = zip(
                    &prover_data_after.committed_pcs_data_per_phase,
                    &prover_data_after.rap_views_per_phase,
                )
                .map(|((_, pcs_data), rap_views)| -> Option<_> {
                    let rap_view = rap_views.get(i)?;
                    let matrix_idx = rap_view.inner?;
                    let extended_matrix =
                        pcs.get_evaluations_on_domain(&pcs_data.data, matrix_idx, quotient_domain);
                    Some(RapSinglePhaseView {
                        inner: Some(extended_matrix),
                        challenges: rap_view.challenges.clone(),
                        exposed_values: rap_view.exposed_values.clone(),
                    })
                })
                .collect_vec();
                // Trim trailing phases this AIR does not participate in, then
                // default-fill any interior gaps.
                while let Some(last) = per_phase.last() {
                    if last.is_none() {
                        per_phase.pop();
                    } else {
                        break;
                    }
                }
                let per_phase = per_phase
                    .into_iter()
                    .map(|v| v.unwrap_or_default())
                    .collect();

                RapView {
                    log_trace_height,
                    preprocessed,
                    partitioned_main,
                    public_values: pvs.to_vec(),
                    per_phase,
                }
            })
            .collect_vec();

        let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = pk_views
            .iter()
            .map(|pk| {
                (
                    &pk.vk.symbolic_constraints.constraints,
                    pk.vk.quotient_degree,
                )
            })
            .unzip();
        let qc = QuotientCommitter::new(self.pcs(), alpha, self.log_blowup_factor);
        let quotient_values = qc.quotient_values(&constraints, extended_views, &quotient_degrees);

        qc.commit(quotient_values)
    }
}
388
389impl<SC: StarkGenericConfig> hal::OpeningProver<CpuBackend<SC>> for CpuDevice<SC> {
390 fn open(
391 &self,
392 challenger: &mut SC::Challenger,
393 preprocessed: Vec<&PcsData<SC>>,
396 main: Vec<PcsData<SC>>,
400 after_phase: Vec<PcsData<SC>>,
402 quotient_data: PcsData<SC>,
404 quotient_degrees: &[u8],
406 ) -> OpeningProof<SC> {
407 let deep_pow_witness = challenger.grind(self.config().deep_ali_params().deep_pow_bits);
409 let zeta: SC::Challenge = challenger.sample_algebra_element();
411 tracing::debug!("zeta: {zeta:?}");
412
413 let pcs = self.pcs();
414 let domain = |log_height| pcs.natural_domain_for_degree(1usize << log_height);
415 let opener = OpeningProver::<SC>::new(pcs, zeta, deep_pow_witness);
416 let preprocessed = preprocessed
417 .iter()
418 .map(|v| {
419 assert_eq!(v.log_trace_heights.len(), 1);
420 (v.data.as_ref(), domain(v.log_trace_heights[0]))
421 })
422 .collect();
423 let main = main
424 .iter()
425 .map(|v| {
426 let domains = v.log_trace_heights.iter().copied().map(domain).collect();
427 (v.data.as_ref(), domains)
428 })
429 .collect();
430 let after_phase: Vec<_> = after_phase
431 .iter()
432 .map(|v| {
433 let domains = v.log_trace_heights.iter().copied().map(domain).collect();
434 (v.data.as_ref(), domains)
435 })
436 .collect();
437 opener.open(
438 challenger,
439 preprocessed,
440 main,
441 after_phase,
442 "ient_data.data,
443 quotient_degrees,
444 )
445 }
446}
447
448impl<SC> DeviceDataTransporter<SC, CpuBackend<SC>> for CpuDevice<SC>
449where
450 SC: StarkGenericConfig,
451{
452 fn transport_pk_to_device(
453 &self,
454 mpk: &MultiStarkProvingKey<SC>,
455 ) -> DeviceMultiStarkProvingKey<CpuBackend<SC>> {
456 let per_air = mpk
457 .per_air
458 .iter()
459 .map(|pk| {
460 let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
461 let pcs_data_view = PcsData {
462 data: pd.data.clone(),
463 log_trace_heights: vec![log2_strict_usize(pd.trace.height()) as u8],
464 };
465 SingleCommitPreimage {
466 trace: pd.trace.clone(),
467 data: pcs_data_view,
468 matrix_idx: 0,
469 }
470 });
471 DeviceStarkProvingKey {
472 air_name: pk.air_name.clone(),
473 vk: pk.vk.clone(),
474 preprocessed_data,
475 rap_partial_pk: pk.rap_partial_pk.clone(),
476 }
477 })
478 .collect();
479 DeviceMultiStarkProvingKey::new(
480 per_air,
481 mpk.trace_height_constraints.clone(),
482 mpk.vk_pre_hash.clone(),
483 )
484 }
485 fn transport_matrix_to_device(
486 &self,
487 matrix: &Arc<RowMajorMatrix<Val<SC>>>,
488 ) -> Arc<RowMajorMatrix<Val<SC>>> {
489 matrix.clone()
490 }
491
492 fn transport_committed_trace_to_device(
493 &self,
494 commitment: Com<SC>,
495 trace: &Arc<RowMajorMatrix<Val<SC>>>,
496 prover_data: &Arc<PcsProverData<SC>>,
497 ) -> CommittedTraceData<CpuBackend<SC>> {
498 let log_trace_height: u8 = log2_strict_usize(trace.height()).try_into().unwrap();
499 let data = PcsData::new(prover_data.clone(), vec![log_trace_height]);
500 CommittedTraceData {
501 commitment,
502 trace: trace.clone(),
503 data,
504 }
505 }
506
507 fn transport_matrix_from_device_to_host(
508 &self,
509 matrix: &Arc<RowMajorMatrix<Val<SC>>>,
510 ) -> Arc<RowMajorMatrix<Val<SC>>> {
511 matrix.clone()
512 }
513}
514
/// Reinterprets a matrix of extension-field elements as a matrix over the base
/// field without copying: each `EF` cell becomes `EF::DIMENSION` consecutive
/// `F` columns, so the width grows by that factor.
///
/// # Safety
/// The caller must ensure `EF` is laid out exactly as `[F; EF::DIMENSION]`
/// (same alignment as `F`, size `size_of::<F>() * EF::DIMENSION`, no padding,
/// coefficients stored contiguously). These layout preconditions are only
/// debug-asserted below; violating them makes the `Vec::from_raw_parts` call
/// undefined behavior.
unsafe fn transmute_to_base<F: Field, EF: ExtensionField<F>>(
    ext_matrix: RowMajorMatrix<EF>,
) -> RowMajorMatrix<F> {
    // Layout preconditions (checked in debug builds only).
    debug_assert_eq!(align_of::<EF>(), align_of::<F>());
    debug_assert_eq!(size_of::<EF>(), size_of::<F>() * EF::DIMENSION);
    let width = ext_matrix.width * EF::DIMENSION;
    // ManuallyDrop prevents a double-free: ownership of the allocation is
    // handed to the new Vec built below, so the original must not be dropped.
    let mut values = ManuallyDrop::new(ext_matrix.values);
    let mut len = values.len();
    let mut cap = values.capacity();
    let ptr = values.as_mut_ptr();
    // One EF element corresponds to EF::DIMENSION base-field elements.
    len *= EF::DIMENSION;
    cap *= EF::DIMENSION;
    // SAFETY: `ptr` came from a live Vec<EF>; given the layout preconditions
    // above, the same allocation is valid for `len`/`cap` elements of `F`
    // with matching alignment, and ownership is transferred exactly once.
    let base_values = Vec::from_raw_parts(ptr as *mut F, len, cap);
    RowMajorMatrix::new(base_values, width)
}