use std::{iter::zip, marker::PhantomData, mem::ManuallyDrop, ops::Deref, sync::Arc};

use derivative::Derivative;
use itertools::{izip, zip_eq, Itertools};
use opener::OpeningProver;
use p3_challenger::FieldChallenger;
use p3_commit::{Pcs, PolynomialSpace};
use p3_field::{ExtensionField, Field, FieldExtensionAlgebra};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_util::log2_strict_usize;
use quotient::QuotientCommitter;
use tracing::info_span;

use super::{
    hal::{self, DeviceDataTransporter, MatrixDimensions, ProverBackend, ProverDevice},
    types::{
        AirView, DeviceMultiStarkProvingKey, DeviceStarkProvingKey, ProverDataAfterRapPhases,
        RapView, SingleCommitPreimage,
    },
};
use crate::{
    air_builders::symbolic::SymbolicConstraints,
    config::{
        Com, PcsProof, PcsProverData, RapPartialProvingKey, RapPhaseSeqPartialProof,
        StarkGenericConfig, Val,
    },
    interaction::RapPhaseSeq,
    keygen::types::MultiStarkProvingKey,
    proof::OpeningProof,
    prover::{
        hal::TraceCommitter,
        types::{CommittedTraceData, DeviceMultiStarkProvingKeyView, PairView, RapSinglePhaseView},
    },
};

pub mod opener;
pub mod quotient;

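/// CPU implementation of [ProverBackend], generic over a [StarkGenericConfig] `SC`. All proving
/// artifacts live in host memory and matrices are shared as `Arc<RowMajorMatrix<_>>`.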
#[derive(Derivative)]
#[derivative(Clone(bound = ""), Copy(bound = ""), Default(bound = ""))]
pub struct CpuBackend<SC> {
    phantom: PhantomData<SC>,
}

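/// CPU prover device holding a shared [StarkGenericConfig] and the log2 blowup factor of its PCS.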
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct CpuDevice<SC> {
    pub config: Arc<SC>,
    log_blowup_factor: usize,
}

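/// Associated types for the CPU backend: matrices are `Arc`-shared row-major matrices over the
/// base field, and PCS data is the native prover data of the configured PCS.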
impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
    const CHALLENGE_EXT_DEGREE: u8 = <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D as u8;

    type Val = Val<SC>;
    type Challenge = SC::Challenge;
    type OpeningProof = OpeningProof<PcsProof<SC>, SC::Challenge>;
    type RapPartialProof = Option<RapPhaseSeqPartialProof<SC>>;
    type Commitment = Com<SC>;
    type Challenger = SC::Challenger;
    type Matrix = Arc<RowMajorMatrix<Val<SC>>>;
    type PcsData = PcsData<SC>;
    type RapPartialProvingKey = RapPartialProvingKey<SC>;
}

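/// Prover data for a single PCS commitment, together with the log2 heights of the committed
/// trace matrices (in commitment order).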
#[derive(Derivative, derive_new::new)]
#[derivative(Clone(bound = ""))]
pub struct PcsData<SC: StarkGenericConfig> {
    pub data: Arc<PcsProverData<SC>>,
    pub log_trace_heights: Vec<u8>,
}

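/// Host matrices report their dimensions by delegating to the underlying [Matrix].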
impl<T: Send + Sync + Clone> MatrixDimensions for Arc<RowMajorMatrix<T>> {
    fn height(&self) -> usize {
        self.deref().height()
    }
    fn width(&self) -> usize {
        self.deref().width()
    }
}

impl<SC> CpuDevice<SC> {
    pub fn config(&self) -> &SC {
        &self.config
    }
}

impl<SC: StarkGenericConfig> CpuDevice<SC> {
    pub fn pcs(&self) -> &SC::Pcs {
        self.config.pcs()
    }
}

impl<SC: StarkGenericConfig> ProverDevice<CpuBackend<SC>> for CpuDevice<SC> {}

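/// Commits to the given traces with the PCS. Each trace is copied into a buffer pre-allocated
/// with `len << log_blowup_factor` capacity before being handed to the PCS.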
impl<SC: StarkGenericConfig> TraceCommitter<CpuBackend<SC>> for CpuDevice<SC> {
    fn commit(&self, traces: &[Arc<RowMajorMatrix<Val<SC>>>]) -> (Com<SC>, PcsData<SC>) {
        let log_blowup_factor = self.log_blowup_factor;
        let pcs = self.pcs();
        let (log_trace_heights, traces_with_domains): (Vec<_>, Vec<_>) = traces
            .iter()
            .map(|matrix| {
                let height = matrix.height();
                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
                let domain = pcs.natural_domain_for_degree(height);
                let trace_slice = &matrix.as_ref().values;
                // Allocate the new buffer with enough capacity for the blown-up trace.
                let new_buffer_size = trace_slice
                    .len()
                    .checked_shl(log_blowup_factor.try_into().unwrap())
                    .unwrap();
                let mut new_buffer = Vec::with_capacity(new_buffer_size);
                // SAFETY: `new_buffer` has capacity for at least `trace_slice.len()` elements,
                // the source and destination do not overlap, and `set_len` is only called after
                // the first `trace_slice.len()` elements have been initialized by the copy.
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        trace_slice.as_ptr(),
                        new_buffer.as_mut_ptr(),
                        trace_slice.len(),
                    );
                    new_buffer.set_len(trace_slice.len());
                }
                (
                    log_height,
                    (domain, RowMajorMatrix::new(new_buffer, matrix.width)),
                )
            })
            .unzip();
        let (commit, data) = pcs.commit(traces_with_domains);
        (
            commit,
            PcsData {
                data: Arc::new(data),
                log_trace_heights,
            },
        )
    }
}

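/// Partially proves the RAP challenge phases using the [RapPhaseSeq] from the config. Any
/// after-challenge (permutation) traces produced by the phase are committed together in a
/// single PCS commitment under the `perm_trace_commit` span.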
impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice<SC> {
    fn partially_prove(
        &self,
        challenger: &mut SC::Challenger,
        mpk: &DeviceMultiStarkProvingKeyView<CpuBackend<SC>>,
        trace_views: Vec<AirView<Arc<RowMajorMatrix<Val<SC>>>, Val<SC>>>,
    ) -> (
        Option<RapPhaseSeqPartialProof<SC>>,
        ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) {
        let num_airs = mpk.per_air.len();
        assert_eq!(num_airs, trace_views.len());

        let (constraints_per_air, rap_pk_per_air): (Vec<_>, Vec<_>) = mpk
            .per_air
            .iter()
            .map(|pk| {
                (
                    SymbolicConstraints::from(&pk.vk.symbolic_constraints),
                    &pk.rap_partial_pk,
                )
            })
            .unzip();

        let trace_views = zip(&mpk.per_air, trace_views)
            .map(|(pk, v)| PairView {
                log_trace_height: log2_strict_usize(v.partitioned_main.first().unwrap().height())
                    as u8,
                preprocessed: pk.preprocessed_data.as_ref().map(|p| p.trace.clone()),
                partitioned_main: v.partitioned_main,
                public_values: v.public_values,
            })
            .collect_vec();
        let (rap_phase_seq_proof, rap_phase_seq_data) = self
            .config()
            .rap_phase_seq()
            .partially_prove(
                challenger,
                &constraints_per_air.iter().collect_vec(),
                &rap_pk_per_air,
                trace_views,
            )
            .map_or((None, None), |(p, d)| (Some(p), Some(d)));

        let mvk_view = mpk.vk_view();

        let mut perm_matrix_idx = 0usize;
        let rap_views_per_phase;
        let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data {
            assert_eq!(mvk_view.num_phases(), 1);
            assert_eq!(
                mvk_view.num_challenges_in_phase(0),
                phase_data.challenges.len()
            );
            let perm_views = zip_eq(
                &phase_data.after_challenge_trace_per_air,
                phase_data.exposed_values_per_air,
            )
            .map(|(perm_trace, exposed_values)| {
                // Only AIRs that actually have a permutation trace get a matrix index into the
                // single after-challenge commitment.
                let mut matrix_idx = None;
                if perm_trace.is_some() {
                    matrix_idx = Some(perm_matrix_idx);
                    perm_matrix_idx += 1;
                }
                RapSinglePhaseView {
                    inner: matrix_idx,
                    challenges: phase_data.challenges.clone(),
                    exposed_values: exposed_values.unwrap_or_default(),
                }
            })
            .collect_vec();
            rap_views_per_phase = vec![perm_views];
            phase_data.after_challenge_trace_per_air
        } else {
            assert_eq!(mvk_view.num_phases(), 0);
            rap_views_per_phase = vec![];
            vec![None; num_airs]
        };

        // Commit to all permutation traces (if any) in a single PCS commitment.
        let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> =
            info_span!("perm_trace_commit")
                .in_scope(|| {
                    let (log_trace_heights, flattened_traces): (Vec<_>, Vec<_>) =
                        perm_trace_per_air
                            .into_iter()
                            .flatten()
                            .map(|perm_trace| {
                                // View the extension-field permutation trace as a base-field
                                // matrix before committing.
                                let trace = unsafe { transmute_to_base(perm_trace) };
                                let height = trace.height();
                                let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
                                let domain = self.pcs().natural_domain_for_degree(height);
                                (log_height, (domain, trace))
                            })
                            .collect();
                    if !flattened_traces.is_empty() {
                        let (commit, data) = self.pcs().commit(flattened_traces);
                        Some((commit, PcsData::new(Arc::new(data), log_trace_heights)))
                    } else {
                        None
                    }
                })
                .into_iter()
                .collect();
        let prover_view = ProverDataAfterRapPhases {
            committed_pcs_data_per_phase,
            rap_views_per_phase,
        };
        (rap_phase_seq_proof, prover_view)
    }
}

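/// Samples the constraint-folding challenge `alpha`, evaluates every AIR's constraints on its
/// quotient domain, and commits to the resulting quotient polynomial evaluations with the PCS.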
impl<SC: StarkGenericConfig> hal::QuotientCommitter<CpuBackend<SC>> for CpuDevice<SC> {
    fn eval_and_commit_quotient(
        &self,
        challenger: &mut SC::Challenger,
        pk_views: &[&DeviceStarkProvingKey<CpuBackend<SC>>],
        public_values: &[Vec<Val<SC>>],
        cached_pcs_datas_per_air: &[Vec<PcsData<SC>>],
        common_main_pcs_data: &PcsData<SC>,
        prover_data_after: &ProverDataAfterRapPhases<CpuBackend<SC>>,
    ) -> (Com<SC>, PcsData<SC>) {
        let pcs = self.pcs();
        let alpha: SC::Challenge = challenger.sample_ext_element();
        tracing::debug!("alpha: {alpha:?}");
        let mut common_main_idx = 0;
        let extended_views = izip!(pk_views, cached_pcs_datas_per_air, public_values)
            .enumerate()
            .map(|(i, (pk, cached_pcs_datas, pvs))| {
                let quotient_degree = pk.vk.quotient_degree;
                let log_trace_height = if pk.vk.has_common_main() {
                    common_main_pcs_data.log_trace_heights[common_main_idx]
                } else {
                    cached_pcs_datas[0].log_trace_heights[0]
                };
                let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
                let quotient_domain = trace_domain
                    .create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
                let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
                    pcs.get_evaluations_on_domain(
                        &cv.data.data,
                        cv.matrix_idx as usize,
                        quotient_domain,
                    )
                });
                let mut partitioned_main: Vec<_> = cached_pcs_datas
                    .iter()
                    .map(|cv| pcs.get_evaluations_on_domain(&cv.data, 0, quotient_domain))
                    .collect();
                if pk.vk.has_common_main() {
                    partitioned_main.push(pcs.get_evaluations_on_domain(
                        &common_main_pcs_data.data,
                        common_main_idx,
                        quotient_domain,
                    ));
                    common_main_idx += 1;
                }
                let mut per_phase = zip(
                    &prover_data_after.committed_pcs_data_per_phase,
                    &prover_data_after.rap_views_per_phase,
                )
                .map(|((_, pcs_data), rap_views)| -> Option<_> {
                    let rap_view = rap_views.get(i)?;
                    let matrix_idx = rap_view.inner?;
                    let extended_matrix =
                        pcs.get_evaluations_on_domain(&pcs_data.data, matrix_idx, quotient_domain);
                    Some(RapSinglePhaseView {
                        inner: Some(extended_matrix),
                        challenges: rap_view.challenges.clone(),
                        exposed_values: rap_view.exposed_values.clone(),
                    })
                })
                .collect_vec();
                // Drop trailing phases in which this AIR does not participate, then replace any
                // remaining gaps with default (empty) views.
                while let Some(last) = per_phase.last() {
                    if last.is_none() {
                        per_phase.pop();
                    } else {
                        break;
                    }
                }
                let per_phase = per_phase
                    .into_iter()
                    .map(|v| v.unwrap_or_default())
                    .collect();

                RapView {
                    log_trace_height,
                    preprocessed,
                    partitioned_main,
                    public_values: pvs.to_vec(),
                    per_phase,
                }
            })
            .collect_vec();

        let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = pk_views
            .iter()
            .map(|pk| {
                (
                    &pk.vk.symbolic_constraints.constraints,
                    pk.vk.quotient_degree,
                )
            })
            .unzip();
        let qc = QuotientCommitter::new(self.pcs(), alpha, self.log_blowup_factor);
        let quotient_values = qc.quotient_values(&constraints, extended_views, &quotient_degrees);

        qc.commit(quotient_values)
    }
}

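/// Samples the out-of-domain evaluation point `zeta` and delegates to [OpeningProver] to open
/// the preprocessed, main, after-phase, and quotient commitments.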
impl<SC: StarkGenericConfig> hal::OpeningProver<CpuBackend<SC>> for CpuDevice<SC> {
    fn open(
        &self,
        challenger: &mut SC::Challenger,
        preprocessed: Vec<&PcsData<SC>>,
        main: Vec<PcsData<SC>>,
        after_phase: Vec<PcsData<SC>>,
        quotient_data: PcsData<SC>,
        quotient_degrees: &[u8],
    ) -> OpeningProof<PcsProof<SC>, SC::Challenge> {
        let zeta: SC::Challenge = challenger.sample_ext_element();
        tracing::debug!("zeta: {zeta:?}");

        let pcs = self.pcs();
        let domain = |log_height| pcs.natural_domain_for_degree(1usize << log_height);
        let opener = OpeningProver::<SC>::new(pcs, zeta);
        let preprocessed = preprocessed
            .iter()
            .map(|v| {
                assert_eq!(v.log_trace_heights.len(), 1);
                (v.data.as_ref(), domain(v.log_trace_heights[0]))
            })
            .collect();
        let main = main
            .iter()
            .map(|v| {
                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
                (v.data.as_ref(), domains)
            })
            .collect();
        let after_phase: Vec<_> = after_phase
            .iter()
            .map(|v| {
                let domains = v.log_trace_heights.iter().copied().map(domain).collect();
                (v.data.as_ref(), domains)
            })
            .collect();
        opener.open(
            challenger,
            preprocessed,
            main,
            after_phase,
            &quotient_data.data,
            quotient_degrees,
        )
    }
}

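/// For the CPU backend, host and device memory coincide, so transporting data amounts to
/// cloning `Arc`s and repackaging the existing PCS prover data.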
impl<SC> DeviceDataTransporter<SC, CpuBackend<SC>> for CpuDevice<SC>
where
    SC: StarkGenericConfig,
{
    fn transport_pk_to_device(
        &self,
        mpk: &MultiStarkProvingKey<SC>,
    ) -> DeviceMultiStarkProvingKey<CpuBackend<SC>> {
        let per_air = mpk
            .per_air
            .iter()
            .map(|pk| {
                let preprocessed_data = pk.preprocessed_data.as_ref().map(|pd| {
                    let pcs_data_view = PcsData {
                        data: pd.data.clone(),
                        log_trace_heights: vec![log2_strict_usize(pd.trace.height()) as u8],
                    };
                    SingleCommitPreimage {
                        trace: pd.trace.clone(),
                        data: pcs_data_view,
                        matrix_idx: 0,
                    }
                });
                DeviceStarkProvingKey {
                    air_name: pk.air_name.clone(),
                    vk: pk.vk.clone(),
                    preprocessed_data,
                    rap_partial_pk: pk.rap_partial_pk.clone(),
                }
            })
            .collect();
        DeviceMultiStarkProvingKey::new(
            per_air,
            mpk.trace_height_constraints.clone(),
            mpk.vk_pre_hash.clone(),
        )
    }

    fn transport_matrix_to_device(
        &self,
        matrix: &Arc<RowMajorMatrix<Val<SC>>>,
    ) -> Arc<RowMajorMatrix<Val<SC>>> {
        matrix.clone()
    }

    fn transport_committed_trace_to_device(
        &self,
        commitment: Com<SC>,
        trace: &Arc<RowMajorMatrix<Val<SC>>>,
        prover_data: &Arc<PcsProverData<SC>>,
    ) -> CommittedTraceData<CpuBackend<SC>> {
        let log_trace_height: u8 = log2_strict_usize(trace.height()).try_into().unwrap();
        let data = PcsData::new(prover_data.clone(), vec![log_trace_height]);
        CommittedTraceData {
            commitment,
            trace: trace.clone(),
            data,
        }
    }

    fn transport_matrix_from_device_to_host(
        &self,
        matrix: &Arc<RowMajorMatrix<Val<SC>>>,
    ) -> Arc<RowMajorMatrix<Val<SC>>> {
        matrix.clone()
    }
}

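/// Reinterprets a matrix over the extension field `EF` as a matrix over the base field `F`,
/// widening each row by a factor of `EF::D` without copying the underlying buffer.
///
/// # Safety
/// Requires that `EF` is represented as exactly `EF::D` contiguous `F` elements with the same
/// alignment as `F` (checked by the debug assertions below).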
unsafe fn transmute_to_base<F: Field, EF: ExtensionField<F>>(
    ext_matrix: RowMajorMatrix<EF>,
) -> RowMajorMatrix<F> {
    debug_assert_eq!(align_of::<EF>(), align_of::<F>());
    debug_assert_eq!(size_of::<EF>(), size_of::<F>() * EF::D);
    let width = ext_matrix.width * EF::D;
    // Take ownership of the buffer without dropping it; the `Vec<F>` reconstructed below becomes
    // the sole owner of the allocation.
    let mut values = ManuallyDrop::new(ext_matrix.values);
    let mut len = values.len();
    let mut cap = values.capacity();
    let ptr = values.as_mut_ptr();
    len *= EF::D;
    cap *= EF::D;
    let base_values = Vec::from_raw_parts(ptr as *mut F, len, cap);
    RowMajorMatrix::new(base_values, width)
}