p3_fri/
two_adic_pcs.rs

use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::Debug;
use core::marker::PhantomData;

use itertools::{izip, Itertools};
use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
use p3_commit::{Mmcs, OpenedValues, Pcs, PolynomialSpace, TwoAdicMultiplicativeCoset};
use p3_dft::TwoAdicSubgroupDft;
use p3_field::{
    batch_multiplicative_inverse, cyclic_subgroup_coset_known_order, dot_product, ExtensionField,
    Field, TwoAdicField,
};
use p3_interpolation::interpolate_coset;
use p3_matrix::bitrev::{BitReversableMatrix, BitReversalPerm, BitReversedMatrixView};
use p3_matrix::dense::{DenseMatrix, RowMajorMatrix};
use p3_matrix::{Dimensions, Matrix};
use p3_maybe_rayon::prelude::*;
use p3_util::linear_map::LinearMap;
use p3_util::zip_eq::zip_eq;
use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
use serde::{Deserialize, Serialize};
use tracing::{info_span, instrument};

use crate::verifier::{self, FriError};
use crate::{prover, FriConfig, FriGenericConfig, FriProof};

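/// A polynomial commitment scheme built from FRI over two-adic multiplicative cosets:
/// polynomials are committed as low-degree extensions (computed with `dft` and committed
/// with the input `mmcs`), and openings are proven with the FRI protocol configured by `fri`.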
#[derive(Debug)]
pub struct TwoAdicFriPcs<Val, Dft, InputMmcs, FriMmcs> {
    dft: Dft,
    mmcs: InputMmcs,
    fri: FriConfig<FriMmcs>,
    _phantom: PhantomData<Val>,
}

impl<Val, Dft, InputMmcs, FriMmcs> TwoAdicFriPcs<Val, Dft, InputMmcs, FriMmcs> {
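    /// Constructs a PCS from a DFT backend, an input MMCS for committing to LDEs, and a
    /// FRI configuration.
    ///
    /// A minimal construction sketch; `Dft::default()`, `input_mmcs`, and `fri_config`
    /// are placeholder values supplied by the caller, not names fixed by this crate:
    ///
    /// ```ignore
    /// let pcs = TwoAdicFriPcs::new(Dft::default(), input_mmcs, fri_config);
    /// ```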
    pub const fn new(dft: Dft, mmcs: InputMmcs, fri: FriConfig<FriMmcs>) -> Self {
        Self {
            dft,
            mmcs,
            fri,
            _phantom: PhantomData,
        }
    }
}

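/// The opening of one batch of committed matrices at a single query index: the opened
/// row values for each matrix in the batch, together with the MMCS proof authenticating
/// them against the batch commitment.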
#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
pub struct BatchOpening<Val: Field, InputMmcs: Mmcs<Val>> {
    pub opened_values: Vec<Vec<Val>>,
    pub opening_proof: <InputMmcs as Mmcs<Val>>::Proof,
}

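/// The generic FRI configuration for two-adic domains, implementing arity-2 folding.
/// Generic over the input proof and error types so it can be paired with any input MMCS.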
pub struct TwoAdicFriGenericConfig<InputProof, InputError>(
    pub PhantomData<(InputProof, InputError)>,
);

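/// The instantiation of [`TwoAdicFriGenericConfig`] used by [`TwoAdicFriPcs`]: the input
/// proof for each FRI query is one [`BatchOpening`] per committed batch, and input errors
/// are the MMCS's errors.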
pub type TwoAdicFriGenericConfigForMmcs<F, M> =
    TwoAdicFriGenericConfig<Vec<BatchOpening<F, M>>, <M as Mmcs<F>>::Error>;

impl<F: TwoAdicField, InputProof, InputError: Debug> FriGenericConfig<F>
    for TwoAdicFriGenericConfig<InputProof, InputError>
{
    type InputProof = InputProof;
    type InputError = InputError;

    fn extra_query_index_bits(&self) -> usize {
        0
    }

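    /// Folds a single pair of sibling evaluations `(e0, e1)`, taken at the two points
    /// `x` and `-x` of the larger domain that fold to entry `index` of the half-size
    /// domain, by interpolating the line through them and evaluating it at `beta`.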
    fn fold_row(
        &self,
        index: usize,
        log_height: usize,
        beta: F,
        evals: impl Iterator<Item = F>,
    ) -> F {
        let arity = 2;
        let log_arity = 1;
        let (e0, e1) = evals
            .collect_tuple()
            .expect("TwoAdicFriFolder only supports arity=2");
        // If performance critical, make this API stateful to avoid recomputing these
        // generator powers on every call. This is a bit more math than necessary for
        // arity 2, but we leave it general in case we want higher arity in the future.
        let subgroup_start = F::two_adic_generator(log_height + log_arity)
            .exp_u64(reverse_bits_len(index, log_height) as u64);
        let mut xs = F::two_adic_generator(log_arity)
            .shifted_powers(subgroup_start)
            .take(arity)
            .collect_vec();
        reverse_slice_index_bits(&mut xs);
        assert_eq!(log_arity, 1, "can only interpolate two points for now");
        // Interpolate the line through (xs[0], e0) and (xs[1], e1), and evaluate it at beta.
        e0 + (beta - xs[0]) * (e1 - e0) / (xs[1] - xs[0])
    }

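    /// Folds an entire matrix at once: each row holds the pair of sibling evaluations
    /// for one point of the folded domain (in bit-reversed order), and is reduced with
    /// the same linear combination that `fold_row` applies to a single row.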
    fn fold_matrix<M: Matrix<F>>(&self, beta: F, m: M) -> Vec<F> {
        // We use the fact that
        //     p_e(x^2) = (p(x) + p(-x)) / 2
        //     p_o(x^2) = (p(x) - p(-x)) / (2 x)
        // that is,
        //     p_e(g^(2i)) = (p(g^i) + p(g^(n/2 + i))) / 2
        //     p_o(g^(2i)) = (p(g^i) - p(g^(n/2 + i))) / (2 g^i)
        // so
        //     result(g^(2i)) = p_e(g^(2i)) + beta p_o(g^(2i))
        //                    = (1/2 + beta/2 g_inv^i) p(g^i)
        //                    + (1/2 - beta/2 g_inv^i) p(g^(n/2 + i))
        let g_inv = F::two_adic_generator(log2_strict_usize(m.height()) + 1).inverse();
        let one_half = F::ONE.halve();
        let half_beta = beta * one_half;

        // TODO: vectorize this (after we have packed extension fields)

        // beta/2 times successive powers of g_inv
        let mut powers = g_inv
            .shifted_powers(half_beta)
            .take(m.height())
            .collect_vec();
        reverse_slice_index_bits(&mut powers);

        m.par_rows()
            .zip(powers)
            .map(|(mut row, power)| {
                let (lo, hi) = row.next_tuple().unwrap();
                (one_half + power) * lo + (one_half - power) * hi
            })
            .collect()
    }
}

impl<Val, Dft, InputMmcs, FriMmcs, Challenge, Challenger> Pcs<Challenge, Challenger>
    for TwoAdicFriPcs<Val, Dft, InputMmcs, FriMmcs>
where
    Val: TwoAdicField,
    Dft: TwoAdicSubgroupDft<Val>,
    InputMmcs: Mmcs<Val>,
    FriMmcs: Mmcs<Challenge>,
    Challenge: TwoAdicField + ExtensionField<Val>,
    Challenger:
        FieldChallenger<Val> + CanObserve<FriMmcs::Commitment> + GrindingChallenger<Witness = Val>,
{
    type Domain = TwoAdicMultiplicativeCoset<Val>;
    type Commitment = InputMmcs::Commitment;
    type ProverData = InputMmcs::ProverData<RowMajorMatrix<Val>>;
    type EvaluationsOnDomain<'a> = BitReversedMatrixView<DenseMatrix<Val, &'a [Val]>>;
    type Proof = FriProof<Challenge, FriMmcs, Val, Vec<BatchOpening<Val, InputMmcs>>>;
    type Error = FriError<FriMmcs::Error, InputMmcs::Error>;

    fn natural_domain_for_degree(&self, degree: usize) -> Self::Domain {
        let log_n = log2_strict_usize(degree);
        TwoAdicMultiplicativeCoset {
            log_n,
            shift: Val::ONE,
        }
    }

    fn commit(
        &self,
        evaluations: Vec<(Self::Domain, RowMajorMatrix<Val>)>,
    ) -> (Self::Commitment, Self::ProverData) {
        let ldes: Vec<_> = evaluations
            .into_iter()
            .map(|(domain, evals)| {
                assert_eq!(domain.size(), evals.height());
                let shift = Val::GENERATOR / domain.shift;
                // Commit to the bit-reversed LDE.
                self.dft
                    .coset_lde_batch(evals, self.fri.log_blowup, shift)
                    .bit_reverse_rows()
                    .to_row_major_matrix()
            })
            .collect();

        self.mmcs.commit(ldes)
    }

    fn get_evaluations_on_domain<'a>(
        &self,
        prover_data: &'a Self::ProverData,
        idx: usize,
        domain: Self::Domain,
    ) -> Self::EvaluationsOnDomain<'a> {
        // todo: handle extrapolation for LDEs we don't have
        assert_eq!(domain.shift, Val::GENERATOR);
        let lde = self.mmcs.get_matrices(prover_data)[idx];
        assert!(lde.height() >= domain.size());
        lde.split_rows(domain.size()).0.bit_reverse_rows()
    }

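    /// Opens every committed matrix at its requested points: evaluates via barycentric
    /// interpolation, reduces all quotients with powers of a batch challenge `alpha`,
    /// and produces a FRI proof of low degree for the combined polynomial.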
    fn open(
        &self,
        // For each round,
        rounds: Vec<(
            &Self::ProverData,
            // for each matrix,
            Vec<
                // points to open
                Vec<Challenge>,
            >,
        )>,
        challenger: &mut Challenger,
    ) -> (OpenedValues<Challenge>, Self::Proof) {
        /*

        A quick rundown of the optimizations in this function:
        We are trying to compute sum_i alpha^i * (p(X) - y)/(X - z),
        for each z an opening point, y = p(z). Each p(X) is given as evaluations in bit-reversed order
        in the columns of the matrices. y is computed by barycentric interpolation.
        X and p(X) are in the base field; alpha, y and z are in the extension.
        The primary goal is to minimize extension multiplications.

        - Instead of computing all alpha^i, we just compute alpha^i for i up to the largest width
        of a matrix, then multiply by an "alpha offset" when accumulating.
              a^0 x0 + a^1 x1 + a^2 x2 + a^3 x3 + ...
            = a^0 ( a^0 x0 + a^1 x1 ) + a^2 ( a^0 x2 + a^1 x3 ) + ...
            (see `alpha_pows`, `alpha_pow_offset`, `num_reduced`)

        - For each unique point z, we precompute 1/(X - z) for the largest subgroup opened at
        this point. (Concretely, `inv_denoms` stores 1/(z - X) and the accumulation negates the
        numerator, which comes to the same thing.) Since we compute it in bit-reversed order,
        smaller subgroups can simply truncate the vector.
            (see `inv_denoms`)

        - Then, for each matrix (with columns p_i) and opening point z, we want:
            for each row (corresponding to subgroup element X):
                reduced[X] += alpha_offset * sum_i [ alpha^i * inv_denom[X] * (p_i[X] - y[i]) ]

            We can factor out inv_denom, and expand what's left:
                reduced[X] += alpha_offset * inv_denom[X] * sum_i [ alpha^i * p_i[X] - alpha^i * y[i] ]

            And separate the sum:
                reduced[X] += alpha_offset * inv_denom[X] * [ sum_i [ alpha^i * p_i[X] ] - sum_i [ alpha^i * y[i] ] ]

            And now the last sum doesn't depend on X, so we can precompute that for the matrix, too.
            So the hot loop (that depends on both X and i) is just:
                sum_i [ alpha^i * p_i[X] ]

            with alpha^i an extension, p_i[X] a base

        */

        let mats_and_points = rounds
            .iter()
            .map(|(data, points)| {
                let mats = self
                    .mmcs
                    .get_matrices(data)
                    .into_iter()
                    .map(|m| m.as_view())
                    .collect_vec();
                debug_assert_eq!(
                    mats.len(),
                    points.len(),
                    "each matrix should have a corresponding set of evaluation points"
                );
                (mats, points)
            })
            .collect_vec();
        let mats = mats_and_points
            .iter()
            .flat_map(|(mats, _)| mats)
            .collect_vec();

        let global_max_height = mats.iter().map(|m| m.height()).max().unwrap();
        let log_global_max_height = log2_strict_usize(global_max_height);

        // For each unique opening point z, we will find the largest degree bound
        // for that point, and precompute 1/(z - X) for the largest subgroup (in bitrev order).
        let inv_denoms = compute_inverse_denominators(&mats_and_points, Val::GENERATOR);

        // Evaluate coset representations and write openings to the challenger
        let all_opened_values = mats_and_points
            .iter()
            .map(|(mats, points)| {
                izip!(mats.iter(), points.iter())
                    .map(|(mat, points_for_mat)| {
                        points_for_mat
                            .iter()
                            .map(|&point| {
                                let _guard =
                                    info_span!("evaluate matrix", dims = %mat.dimensions())
                                        .entered();

                                // Use barycentric interpolation to evaluate the matrix at the given point.
                                let ys =
                                    info_span!("compute opened values with Lagrange interpolation")
                                        .in_scope(|| {
                                            let h = mat.height() >> self.fri.log_blowup;
                                            let (low_coset, _) = mat.split_rows(h);
                                            let mut inv_denoms =
                                                inv_denoms.get(&point).unwrap()[..h].to_vec();
                                            reverse_slice_index_bits(&mut inv_denoms);
                                            interpolate_coset(
                                                &BitReversalPerm::new_view(low_coset),
                                                Val::GENERATOR,
                                                point,
                                                Some(&inv_denoms),
                                            )
                                        });
                                ys.iter().for_each(|&y| challenger.observe_ext_element(y));
                                ys
                            })
                            .collect_vec()
                    })
                    .collect_vec()
            })
            .collect_vec();

        // Batch combination challenge
        let alpha: Challenge = challenger.sample_ext_element();

        let mut num_reduced = [0; 32];
        let mut reduced_openings: [_; 32] = core::array::from_fn(|_| None);

        for ((mats, points), openings_for_round) in
            mats_and_points.iter().zip(all_opened_values.iter())
        {
            for (mat, points_for_mat, openings_for_mat) in
                izip!(mats.iter(), points.iter(), openings_for_round.iter())
            {
                let _guard =
                    info_span!("reduce matrix quotient", dims = %mat.dimensions()).entered();

                let log_height = log2_strict_usize(mat.height());
                let reduced_opening_for_log_height = reduced_openings[log_height]
                    .get_or_insert_with(|| vec![Challenge::ZERO; mat.height()]);
                debug_assert_eq!(reduced_opening_for_log_height.len(), mat.height());

                let mat_compressed = info_span!("compress mat")
                    .in_scope(|| mat.dot_ext_powers(alpha).collect::<Vec<_>>());

                for (&point, openings) in points_for_mat.iter().zip(openings_for_mat) {
                    let alpha_pow_offset = alpha.exp_u64(num_reduced[log_height] as u64);
                    let reduced_openings: Challenge =
                        dot_product(alpha.powers(), openings.iter().copied());

                    info_span!("reduce rows").in_scope(|| {
                        mat_compressed
                            .par_iter()
                            .zip(reduced_opening_for_log_height.par_iter_mut())
                            // `inv_denoms` may be longer than the matrix, but zip truncates
                            // to the smaller length (which is sound because both sides are
                            // in bit-reversed order).
                            .zip(inv_denoms.get(&point).unwrap().par_iter())
                            .for_each(|((&reduced_row, ro), &inv_denom)| {
                                *ro +=
                                    alpha_pow_offset * (reduced_openings - reduced_row) * inv_denom
                            });
                    });

                    num_reduced[log_height] += mat.width();
                }
            }
        }

        let fri_input = reduced_openings.into_iter().rev().flatten().collect_vec();

        let g: TwoAdicFriGenericConfigForMmcs<Val, InputMmcs> =
            TwoAdicFriGenericConfig(PhantomData);

        let fri_proof = prover::prove(&g, &self.fri, fri_input, challenger, |index| {
            rounds
                .iter()
                .map(|(data, _)| {
                    let log_max_height = log2_strict_usize(self.mmcs.get_max_height(data));
                    let bits_reduced = log_global_max_height - log_max_height;
                    let reduced_index = index >> bits_reduced;
                    let (opened_values, opening_proof) = self.mmcs.open_batch(reduced_index, data);
                    BatchOpening {
                        opened_values,
                        opening_proof,
                    }
                })
                .collect()
        });

        (all_opened_values, fri_proof)
    }

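    /// Verifies a batch of openings: replays the claimed values into the challenger to
    /// re-derive `alpha`, then runs the FRI verifier, recomputing each reduced opening
    /// at every query index from the input proof's batch openings.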
    fn verify(
        &self,
        // For each round:
        rounds: Vec<(
            Self::Commitment,
            // for each matrix:
            Vec<(
                // its domain,
                Self::Domain,
                // for each point:
                Vec<(
                    // the point,
                    Challenge,
                    // values at the point
                    Vec<Challenge>,
                )>,
            )>,
        )>,
        proof: &Self::Proof,
        challenger: &mut Challenger,
    ) -> Result<(), Self::Error> {
        // Write evaluations to challenger
        for (_, round) in rounds.iter() {
            for (_, mat) in round.iter() {
                for (_, point) in mat.iter() {
                    point
                        .iter()
                        .for_each(|&opening| challenger.observe_ext_element(opening));
                }
            }
        }

        // Batch combination challenge
        let alpha: Challenge = challenger.sample_ext_element();

        let log_global_max_height =
            proof.commit_phase_commits.len() + self.fri.log_blowup + self.fri.log_final_poly_len;

        let g: TwoAdicFriGenericConfigForMmcs<Val, InputMmcs> =
            TwoAdicFriGenericConfig(PhantomData);

        verifier::verify(&g, &self.fri, proof, challenger, |index, input_proof| {
            // TODO: separate this out into functions

            // log_height -> (alpha_pow, reduced_opening)
            let mut reduced_openings = BTreeMap::<usize, (Challenge, Challenge)>::new();

            for (batch_opening, (batch_commit, mats)) in
                zip_eq(input_proof, &rounds, FriError::InvalidProofShape)?
            {
                let batch_heights = mats
                    .iter()
                    .map(|(domain, _)| domain.size() << self.fri.log_blowup)
                    .collect_vec();
                let batch_dims = batch_heights
                    .iter()
                    // TODO: MMCS doesn't really need width; we put 0 for now.
                    .map(|&height| Dimensions { width: 0, height })
                    .collect_vec();

                if let Some(batch_max_height) = batch_heights.iter().max() {
                    let log_batch_max_height = log2_strict_usize(*batch_max_height);
                    let bits_reduced = log_global_max_height - log_batch_max_height;
                    let reduced_index = index >> bits_reduced;

                    self.mmcs.verify_batch(
                        batch_commit,
                        &batch_dims,
                        reduced_index,
                        &batch_opening.opened_values,
                        &batch_opening.opening_proof,
                    )
                } else {
                    // Empty batch?
                    self.mmcs.verify_batch(
                        batch_commit,
                        &[],
                        0,
                        &batch_opening.opened_values,
                        &batch_opening.opening_proof,
                    )
                }
                .map_err(FriError::InputError)?;

                for (mat_opening, (mat_domain, mat_points_and_values)) in zip_eq(
                    &batch_opening.opened_values,
                    mats,
                    FriError::InvalidProofShape,
                )? {
                    let log_height = log2_strict_usize(mat_domain.size()) + self.fri.log_blowup;

                    let bits_reduced = log_global_max_height - log_height;
                    let rev_reduced_index = reverse_bits_len(index >> bits_reduced, log_height);

                    // todo: this can be nicer with domain methods?

                    let x = Val::GENERATOR
                        * Val::two_adic_generator(log_height).exp_u64(rev_reduced_index as u64);

                    let (alpha_pow, ro) = reduced_openings
                        .entry(log_height)
                        .or_insert((Challenge::ONE, Challenge::ZERO));

                    for (z, ps_at_z) in mat_points_and_values {
                        for (&p_at_x, &p_at_z) in
                            zip_eq(mat_opening, ps_at_z, FriError::InvalidProofShape)?
                        {
                            let quotient = (-p_at_z + p_at_x) / (-*z + x);
                            *ro += *alpha_pow * quotient;
                            *alpha_pow *= alpha;
                        }
                    }
                }
            }

            // `reduced_openings` would have a log_height = log_blowup entry only if there was a
            // trace matrix of height 1. In this case the reduced opening can be skipped as it will
            // not be checked against any commit phase commit.
            if let Some((_alpha_pow, ro)) = reduced_openings.remove(&self.fri.log_blowup) {
                assert!(ro.is_zero());
            }

            // Return reduced openings descending by log_height.
            Ok(reduced_openings
                .into_iter()
                .rev()
                .map(|(log_height, (_alpha_pow, ro))| (log_height, ro))
                .collect())
        })?;

        Ok(())
    }
}

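/// For each unique opening point `z`, batch-inverts `z - x` over the largest coset opened
/// at `z`, in bit-reversed order. Because the order is bit-reversed, the vector for a
/// smaller coset is simply a prefix of the vector for a larger one.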
#[instrument(skip_all)]
fn compute_inverse_denominators<F: TwoAdicField, EF: ExtensionField<F>, M: Matrix<F>>(
    mats_and_points: &[(Vec<M>, &Vec<Vec<EF>>)],
    coset_shift: F,
) -> LinearMap<EF, Vec<EF>> {
    let mut max_log_height_for_point: LinearMap<EF, usize> = LinearMap::new();
    for (mats, points) in mats_and_points {
        for (mat, points_for_mat) in izip!(mats, *points) {
            let log_height = log2_strict_usize(mat.height());
            for &z in points_for_mat {
                if let Some(lh) = max_log_height_for_point.get_mut(&z) {
                    *lh = core::cmp::max(*lh, log_height);
                } else {
                    max_log_height_for_point.insert(z, log_height);
                }
            }
        }
    }

    // Compute the largest subgroup we will use, in bitrev order.
    let max_log_height = *max_log_height_for_point.values().max().unwrap();
    let mut subgroup = cyclic_subgroup_coset_known_order(
        F::two_adic_generator(max_log_height),
        coset_shift,
        1 << max_log_height,
    )
    .collect_vec();
    reverse_slice_index_bits(&mut subgroup);

    max_log_height_for_point
        .into_iter()
        .map(|(z, log_height)| {
            (
                z,
                batch_multiplicative_inverse(
                    &subgroup[..(1 << log_height)]
                        .iter()
                        .map(|&x| z - x)
                        .collect_vec(),
                ),
            )
        })
        .collect()
}
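
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{thread_rng, Rng};

    use super::*;

    // An illustrative consistency check, not part of the original file: the prover-side
    // `fold_matrix` and the verifier-side `fold_row` must agree row by row, since the
    // verifier recomputes the prover's fold at each queried index. Assumes `p3-baby-bear`
    // and `rand` are available as dev-dependencies, as elsewhere in this workspace.
    #[test]
    fn fold_row_matches_fold_matrix() {
        let mut rng = thread_rng();
        let log_height = 5;
        let height = 1 << log_height;

        let g: TwoAdicFriGenericConfig<(), ()> = TwoAdicFriGenericConfig(PhantomData);
        let beta: BabyBear = rng.gen();
        // Each row holds one pair of sibling evaluations, as in the FRI commit phase.
        let m = RowMajorMatrix::<BabyBear>::rand(&mut rng, height, 2);

        let folded = g.fold_matrix(beta, m.as_view());
        for (index, &folded_eval) in folded.iter().enumerate() {
            assert_eq!(folded_eval, g.fold_row(index, log_height, beta, m.row(index)));
        }
    }
}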