use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::Debug;
use core::marker::PhantomData;

use itertools::{izip, Itertools};
use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
use p3_commit::{Mmcs, OpenedValues, Pcs, PolynomialSpace, TwoAdicMultiplicativeCoset};
use p3_dft::TwoAdicSubgroupDft;
use p3_field::{
    batch_multiplicative_inverse, cyclic_subgroup_coset_known_order, dot_product, ExtensionField,
    Field, TwoAdicField,
};
use p3_interpolation::interpolate_coset;
use p3_matrix::bitrev::{BitReversableMatrix, BitReversalPerm, BitReversedMatrixView};
use p3_matrix::dense::{DenseMatrix, RowMajorMatrix};
use p3_matrix::{Dimensions, Matrix};
use p3_maybe_rayon::prelude::*;
use p3_util::linear_map::LinearMap;
use p3_util::zip_eq::zip_eq;
use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
use serde::{Deserialize, Serialize};
use tracing::{info_span, instrument};

use crate::verifier::{self, FriError};
use crate::{prover, FriConfig, FriGenericConfig, FriProof};

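/// A FRI-based polynomial commitment scheme over two-adic multiplicative cosets: evaluation
/// matrices are committed via their bit-reversed coset LDEs using the input MMCS, and opening
/// claims are reduced to a single low-degree test proved with FRI.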
#[derive(Debug)]
pub struct TwoAdicFriPcs<Val, Dft, InputMmcs, FriMmcs> {
    dft: Dft,
    mmcs: InputMmcs,
    fri: FriConfig<FriMmcs>,
    _phantom: PhantomData<Val>,
}

impl<Val, Dft, InputMmcs, FriMmcs> TwoAdicFriPcs<Val, Dft, InputMmcs, FriMmcs> {
    pub const fn new(dft: Dft, mmcs: InputMmcs, fri: FriConfig<FriMmcs>) -> Self {
        Self {
            dft,
            mmcs,
            fri,
            _phantom: PhantomData,
        }
    }
}

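/// The opened rows of every matrix in one committed batch at a single query index, together
/// with the MMCS proof authenticating them against the batch commitment.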
#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub struct BatchOpening<Val: Field, InputMmcs: Mmcs<Val>> {
    pub opened_values: Vec<Vec<Val>>,
    pub opening_proof: <InputMmcs as Mmcs<Val>>::Proof,
}

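/// The `FriGenericConfig` instantiation for plain arity-2 folding over two-adic subgroups.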
pub struct TwoAdicFriGenericConfig<InputProof, InputError>(
    pub PhantomData<(InputProof, InputError)>,
);

pub type TwoAdicFriGenericConfigForMmcs<F, M> =
    TwoAdicFriGenericConfig<Vec<BatchOpening<F, M>>, <M as Mmcs<F>>::Error>;

impl<F: TwoAdicField, InputProof, InputError: Debug> FriGenericConfig<F>
    for TwoAdicFriGenericConfig<InputProof, InputError>
{
    type InputProof = InputProof;
    type InputError = InputError;

    fn extra_query_index_bits(&self) -> usize {
        0
    }

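    // Fold one queried row: `evals` holds the two sibling evaluations of the current codeword
    // at the paired points {x, -x} determined by `index`. Interpolate the line through those
    // two points and evaluate it at the folding challenge `beta`.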
    fn fold_row(
        &self,
        index: usize,
        log_height: usize,
        beta: F,
        evals: impl Iterator<Item = F>,
    ) -> F {
        let arity = 2;
        let log_arity = 1;
        let (e0, e1) = evals
            .collect_tuple()
            .expect("TwoAdicFriFolder only supports arity=2");
        let subgroup_start = F::two_adic_generator(log_height + log_arity)
            .exp_u64(reverse_bits_len(index, log_height) as u64);
        let mut xs = F::two_adic_generator(log_arity)
            .shifted_powers(subgroup_start)
            .take(arity)
            .collect_vec();
        reverse_slice_index_bits(&mut xs);
        assert_eq!(log_arity, 1, "can only interpolate two points for now");
        e0 + (beta - xs[0]) * (e1 - e0) / (xs[1] - xs[0])
    }

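    // Fold a whole codeword at once. Writing p(X) = p_e(X^2) + X * p_o(X^2), the folded
    // codeword is p_e + beta * p_o on the half-size subgroup. Since each row holds the pair
    // (p(g^i), p(-g^i)) in bit-reversed order,
    //     p_e(g^(2i)) + beta * p_o(g^(2i))
    //         = (1/2 + (beta/2) * g^(-i)) * p(g^i) + (1/2 - (beta/2) * g^(-i)) * p(-g^i),
    // which is exactly the row-wise map computed below.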
    fn fold_matrix<M: Matrix<F>>(&self, beta: F, m: M) -> Vec<F> {
        let g_inv = F::two_adic_generator(log2_strict_usize(m.height()) + 1).inverse();
        let one_half = F::ONE.halve();
        let half_beta = beta * one_half;

        // beta/2 times successive powers of g_inv, in bit-reversed order.
        let mut powers = g_inv
            .shifted_powers(half_beta)
            .take(m.height())
            .collect_vec();
        reverse_slice_index_bits(&mut powers);

        m.par_rows()
            .zip(powers)
            .map(|(mut row, power)| {
                let (lo, hi) = row.next_tuple().unwrap();
                (one_half + power) * lo + (one_half - power) * hi
            })
            .collect()
    }
}

impl<Val, Dft, InputMmcs, FriMmcs, Challenge, Challenger> Pcs<Challenge, Challenger>
    for TwoAdicFriPcs<Val, Dft, InputMmcs, FriMmcs>
where
    Val: TwoAdicField,
    Dft: TwoAdicSubgroupDft<Val>,
    InputMmcs: Mmcs<Val>,
    FriMmcs: Mmcs<Challenge>,
    Challenge: TwoAdicField + ExtensionField<Val>,
    Challenger:
        FieldChallenger<Val> + CanObserve<FriMmcs::Commitment> + GrindingChallenger<Witness = Val>,
{
    type Domain = TwoAdicMultiplicativeCoset<Val>;
    type Commitment = InputMmcs::Commitment;
    type ProverData = InputMmcs::ProverData<RowMajorMatrix<Val>>;
    type EvaluationsOnDomain<'a> = BitReversedMatrixView<DenseMatrix<Val, &'a [Val]>>;
    type Proof = FriProof<Challenge, FriMmcs, Val, Vec<BatchOpening<Val, InputMmcs>>>;
    type Error = FriError<FriMmcs::Error, InputMmcs::Error>;

    fn natural_domain_for_degree(&self, degree: usize) -> Self::Domain {
        let log_n = log2_strict_usize(degree);
        TwoAdicMultiplicativeCoset {
            log_n,
            shift: Val::ONE,
        }
    }

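    // Commit to each evaluation matrix by computing its coset LDE (blown up by
    // `fri.log_blowup`, with the evaluation domain shifted onto the generator coset),
    // storing rows in bit-reversed order, and committing to all LDEs with the input MMCS.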
    fn commit(
        &self,
        evaluations: Vec<(Self::Domain, RowMajorMatrix<Val>)>,
    ) -> (Self::Commitment, Self::ProverData) {
        let ldes: Vec<_> = evaluations
            .into_iter()
            .map(|(domain, evals)| {
                assert_eq!(domain.size(), evals.height());
                let shift = Val::GENERATOR / domain.shift;
                self.dft
                    .coset_lde_batch(evals, self.fri.log_blowup, shift)
                    .bit_reverse_rows()
                    .to_row_major_matrix()
            })
            .collect();

        self.mmcs.commit(ldes)
    }

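    // Return the committed LDE restricted to the requested domain size. Only domains shifted
    // by the field generator are supported (no extrapolation to other cosets), hence the
    // assertion on `domain.shift`.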
    fn get_evaluations_on_domain<'a>(
        &self,
        prover_data: &'a Self::ProverData,
        idx: usize,
        domain: Self::Domain,
    ) -> Self::EvaluationsOnDomain<'a> {
        assert_eq!(domain.shift, Val::GENERATOR);
        let lde = self.mmcs.get_matrices(prover_data)[idx];
        assert!(lde.height() >= domain.size());
        lde.split_rows(domain.size()).0.bit_reverse_rows()
    }

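    // Open every committed matrix at its requested points and prove the claims with FRI.
    //
    // For each opening point z and each column p with claimed value y = p(z), we accumulate
    // alpha^i * (p(X) - y) / (X - z) over all columns, batching everything of the same height
    // into one reduced polynomial per height; those reduced polynomials are the FRI inputs.
    // Two optimizations keep extension-field work small:
    //  - 1/(X - z) is precomputed once per distinct z over the largest relevant coset, in
    //    bit-reversed order, so shorter matrices use a prefix (`compute_inverse_denominators`);
    //  - powers of alpha are only taken up to the matrix width, each matrix's contribution is
    //    shifted by `alpha_pow_offset = alpha^num_reduced`, and the per-row work reduces to a
    //    dot product with alpha powers (`dot_ext_powers`) minus a precomputed per-point constant.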
    fn open(
        &self,
        // For each round,
        rounds: Vec<(
            &Self::ProverData,
            // for each committed matrix,
            Vec<
                // the points to open it at.
                Vec<Challenge>,
            >,
        )>,
        challenger: &mut Challenger,
    ) -> (OpenedValues<Challenge>, Self::Proof) {
        let mats_and_points = rounds
            .iter()
            .map(|(data, points)| {
                let mats = self
                    .mmcs
                    .get_matrices(data)
                    .into_iter()
                    .map(|m| m.as_view())
                    .collect_vec();
                debug_assert_eq!(
                    mats.len(),
                    points.len(),
                    "each matrix should have a corresponding set of evaluation points"
                );
                (mats, points)
            })
            .collect_vec();
        let mats = mats_and_points
            .iter()
            .flat_map(|(mats, _)| mats)
            .collect_vec();

        let global_max_height = mats.iter().map(|m| m.height()).max().unwrap();
        let log_global_max_height = log2_strict_usize(global_max_height);

        // For each distinct opening point z, 1/(X - z) over the largest coset opened at that
        // point, in bit-reversed order.
        let inv_denoms = compute_inverse_denominators(&mats_and_points, Val::GENERATOR);

        // Evaluate each matrix at each of its points (by interpolation over the low coset of
        // its stored LDE) and observe the claimed values.
        let all_opened_values = mats_and_points
            .iter()
            .map(|(mats, points)| {
                izip!(mats.iter(), points.iter())
                    .map(|(mat, points_for_mat)| {
                        points_for_mat
                            .iter()
                            .map(|&point| {
                                let _guard =
                                    info_span!("evaluate matrix", dims = %mat.dimensions())
                                        .entered();

                                let ys =
                                    info_span!("compute opened values with Lagrange interpolation")
                                        .in_scope(|| {
                                            let h = mat.height() >> self.fri.log_blowup;
                                            let (low_coset, _) = mat.split_rows(h);
                                            let mut inv_denoms =
                                                inv_denoms.get(&point).unwrap()[..h].to_vec();
                                            reverse_slice_index_bits(&mut inv_denoms);
                                            interpolate_coset(
                                                &BitReversalPerm::new_view(low_coset),
                                                Val::GENERATOR,
                                                point,
                                                Some(&inv_denoms),
                                            )
                                        });
                                ys.iter().for_each(|&y| challenger.observe_ext_element(y));
                                ys
                            })
                            .collect_vec()
                    })
                    .collect_vec()
            })
            .collect_vec();

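        // Batch-combination challenge: all columns and opening points are folded into a single
        // claim per height using powers of `alpha`.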
        let alpha: Challenge = challenger.sample_ext_element();

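        // `num_reduced[log_height]` counts the columns already folded in at each height, giving
        // the alpha-power offset for the next matrix; `reduced_openings[log_height]` lazily holds
        // that height's accumulated quotient evaluations, in bit-reversed order.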
        let mut num_reduced = [0; 32];
        let mut reduced_openings: [_; 32] = core::array::from_fn(|_| None);

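        // Accumulate every matrix's quotients into the accumulator for its height.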
        for ((mats, points), openings_for_round) in
            mats_and_points.iter().zip(all_opened_values.iter())
        {
            for (mat, points_for_mat, openings_for_mat) in
                izip!(mats.iter(), points.iter(), openings_for_round.iter())
            {
                let _guard =
                    info_span!("reduce matrix quotient", dims = %mat.dimensions()).entered();

                let log_height = log2_strict_usize(mat.height());
                let reduced_opening_for_log_height = reduced_openings[log_height]
                    .get_or_insert_with(|| vec![Challenge::ZERO; mat.height()]);
                debug_assert_eq!(reduced_opening_for_log_height.len(), mat.height());

                let mat_compressed = info_span!("compress mat")
                    .in_scope(|| mat.dot_ext_powers(alpha).collect::<Vec<_>>());

                for (&point, openings) in points_for_mat.iter().zip(openings_for_mat) {
                    let alpha_pow_offset = alpha.exp_u64(num_reduced[log_height] as u64);
                    let reduced_openings: Challenge =
                        dot_product(alpha.powers(), openings.iter().copied());

                    info_span!("reduce rows").in_scope(|| {
                        mat_compressed
                            .par_iter()
                            .zip(reduced_opening_for_log_height.par_iter_mut())
                            // `inv_denoms` may be longer than this matrix; `zip` truncates it,
                            // which is fine since both sides are in bit-reversed order.
                            .zip(inv_denoms.get(&point).unwrap().par_iter())
                            .for_each(|((&reduced_row, ro), &inv_denom)| {
                                *ro +=
                                    alpha_pow_offset * (reduced_openings - reduced_row) * inv_denom
                            });
                    });

                    num_reduced[log_height] += mat.width();
                }
            }
        }

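        // Concatenate the accumulators, tallest height first, to form the FRI input.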
        let fri_input = reduced_openings.into_iter().rev().flatten().collect_vec();

        let g: TwoAdicFriGenericConfigForMmcs<Val, InputMmcs> =
            TwoAdicFriGenericConfig(PhantomData);

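        // Run the FRI prover. The callback answers each query by opening every input batch at
        // that index, shifting the index down for batches whose tallest matrix is shorter than
        // the global maximum height.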
        let fri_proof = prover::prove(&g, &self.fri, fri_input, challenger, |index| {
            rounds
                .iter()
                .map(|(data, _)| {
                    let log_max_height = log2_strict_usize(self.mmcs.get_max_height(data));
                    let bits_reduced = log_global_max_height - log_max_height;
                    let reduced_index = index >> bits_reduced;
                    let (opened_values, opening_proof) = self.mmcs.open_batch(reduced_index, data);
                    BatchOpening {
                        opened_values,
                        opening_proof,
                    }
                })
                .collect()
        });

        (all_opened_values, fri_proof)
    }

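    // Verify a batch of opening claims: replay the transcript (observe the claimed values,
    // sample `alpha`), then run the FRI verifier. For each query, the callback checks the
    // input MMCS openings and recombines the opened rows and claimed values into the reduced
    // opening sum_i alpha^i * (p_i(x) - p_i(z)) / (x - z) expected at each height.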
    fn verify(
        &self,
        // For each round:
        rounds: Vec<(
            Self::Commitment,
            // for each committed matrix:
            Vec<(
                // its domain,
                Self::Domain,
                // and, for each opening point,
                Vec<(
                    // the point
                    Challenge,
                    // and the claimed values at that point.
                    Vec<Challenge>,
                )>,
            )>,
        )>,
        proof: &Self::Proof,
        challenger: &mut Challenger,
    ) -> Result<(), Self::Error> {
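        // Observe all claimed evaluations before sampling the batch-combination challenge.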
        for (_, round) in rounds.iter() {
            for (_, mat) in round.iter() {
                for (_, point) in mat.iter() {
                    point
                        .iter()
                        .for_each(|&opening| challenger.observe_ext_element(opening));
                }
            }
        }

        let alpha: Challenge = challenger.sample_ext_element();

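        // The log-height of the tallest committed LDE is implied by the proof shape: one
        // commit-phase commitment per folding step, plus the final polynomial length and the
        // blowup.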
        let log_global_max_height =
            proof.commit_phase_commits.len() + self.fri.log_blowup + self.fri.log_final_poly_len;

        let g: TwoAdicFriGenericConfigForMmcs<Val, InputMmcs> =
            TwoAdicFriGenericConfig(PhantomData);

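        // For each FRI query, check the input openings against their commitments and compute
        // the reduced opening at every height appearing in this proof.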
        verifier::verify(&g, &self.fri, proof, challenger, |index, input_proof| {
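            // log_height -> (current alpha power, accumulated reduced opening).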
            let mut reduced_openings = BTreeMap::<usize, (Challenge, Challenge)>::new();

            for (batch_opening, (batch_commit, mats)) in
                zip_eq(input_proof, &rounds, FriError::InvalidProofShape)?
            {
                let batch_heights = mats
                    .iter()
                    .map(|(domain, _)| domain.size() << self.fri.log_blowup)
                    .collect_vec();
                let batch_dims = batch_heights
                    .iter()
                    // The MMCS only needs the heights here, so a placeholder width of 0 is used.
                    .map(|&height| Dimensions { width: 0, height })
                    .collect_vec();

                if let Some(batch_max_height) = batch_heights.iter().max() {
                    let log_batch_max_height = log2_strict_usize(*batch_max_height);
                    let bits_reduced = log_global_max_height - log_batch_max_height;
                    let reduced_index = index >> bits_reduced;

                    self.mmcs.verify_batch(
                        batch_commit,
                        &batch_dims,
                        reduced_index,
                        &batch_opening.opened_values,
                        &batch_opening.opening_proof,
                    )
                } else {
                    // Batch with no matrices; verify the (empty) opening at index 0.
                    self.mmcs.verify_batch(
                        batch_commit,
                        &[],
                        0,
                        &batch_opening.opened_values,
                        &batch_opening.opening_proof,
                    )
                }
                .map_err(FriError::InputError)?;

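                // Fold each matrix's opened row into the accumulator for its height:
                // quotient = (p(x) - p(z)) / (x - z), weighted by successive powers of alpha.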
                for (mat_opening, (mat_domain, mat_points_and_values)) in zip_eq(
                    &batch_opening.opened_values,
                    mats,
                    FriError::InvalidProofShape,
                )? {
                    let log_height = log2_strict_usize(mat_domain.size()) + self.fri.log_blowup;

                    let bits_reduced = log_global_max_height - log_height;
                    let rev_reduced_index = reverse_bits_len(index >> bits_reduced, log_height);

                    // The base-field point x corresponding to this query index within the
                    // matrix's LDE coset (shifted by the field generator).
                    let x = Val::GENERATOR
                        * Val::two_adic_generator(log_height).exp_u64(rev_reduced_index as u64);

                    let (alpha_pow, ro) = reduced_openings
                        .entry(log_height)
                        .or_insert((Challenge::ONE, Challenge::ZERO));

                    for (z, ps_at_z) in mat_points_and_values {
                        for (&p_at_x, &p_at_z) in
                            zip_eq(mat_opening, ps_at_z, FriError::InvalidProofShape)?
                        {
                            let quotient = (-p_at_z + p_at_x) / (-*z + x);
                            *ro += *alpha_pow * quotient;
                            *alpha_pow *= alpha;
                        }
                    }
                }
            }

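            // An entry at log_height == log_blowup can only come from a trace matrix of height 1;
            // it is not checked against any commit-phase commitment, so it must already be zero.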
            if let Some((_alpha_pow, ro)) = reduced_openings.remove(&self.fri.log_blowup) {
                assert!(ro.is_zero());
            }

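            // Return the reduced openings in descending order of height, as the FRI verifier
            // expects.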
            Ok(reduced_openings
                .into_iter()
                .rev()
                .map(|(log_height, (_alpha_pow, ro))| (log_height, ro))
                .collect())
        })?;

        Ok(())
    }
}

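/// For every distinct opening point z, precompute the inverses 1/(x - z) for all x in the
/// largest coset at which z is opened, in bit-reversed order so that matrices of smaller
/// height can use a prefix of the vector.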
#[instrument(skip_all)]
fn compute_inverse_denominators<F: TwoAdicField, EF: ExtensionField<F>, M: Matrix<F>>(
    mats_and_points: &[(Vec<M>, &Vec<Vec<EF>>)],
    coset_shift: F,
) -> LinearMap<EF, Vec<EF>> {
    let mut max_log_height_for_point: LinearMap<EF, usize> = LinearMap::new();
    for (mats, points) in mats_and_points {
        for (mat, points_for_mat) in izip!(mats, *points) {
            let log_height = log2_strict_usize(mat.height());
            for &z in points_for_mat {
                if let Some(lh) = max_log_height_for_point.get_mut(&z) {
                    *lh = core::cmp::max(*lh, log_height);
                } else {
                    max_log_height_for_point.insert(z, log_height);
                }
            }
        }
    }

    let max_log_height = *max_log_height_for_point.values().max().unwrap();
    // Compute the largest coset we need once, in bit-reversed order; every point's denominator
    // vector is then built from a prefix of it.
    let mut subgroup = cyclic_subgroup_coset_known_order(
        F::two_adic_generator(max_log_height),
        coset_shift,
        1 << max_log_height,
    )
    .collect_vec();
    reverse_slice_index_bits(&mut subgroup);

    max_log_height_for_point
        .into_iter()
        .map(|(z, log_height)| {
            (
                z,
                batch_multiplicative_inverse(
                    &subgroup[..(1 << log_height)]
                        .iter()
                        .map(|&x| z - x)
                        .collect_vec(),
                ),
            )
        })
        .collect()
}