// p3_fri/prover.rs

use alloc::vec;
use alloc::vec::Vec;
use core::iter;

use itertools::{izip, Itertools};
use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
use p3_commit::Mmcs;
use p3_dft::{Radix2Dit, TwoAdicSubgroupDft};
use p3_field::{ExtensionField, Field, TwoAdicField};
use p3_matrix::dense::RowMajorMatrix;
use p3_util::{log2_strict_usize, reverse_slice_index_bits};
use tracing::{debug_span, info_span, instrument};

use crate::{CommitPhaseProofStep, FriConfig, FriGenericConfig, FriProof, QueryProof};

16#[instrument(name = "FRI prover", skip_all)]
17pub fn prove<G, Val, Challenge, M, Challenger>(
18    g: &G,
19    config: &FriConfig<M>,
20    inputs: Vec<Vec<Challenge>>,
21    challenger: &mut Challenger,
22    open_input: impl Fn(usize) -> G::InputProof,
23) -> FriProof<Challenge, M, Challenger::Witness, G::InputProof>
24where
25    Val: Field,
26    Challenge: ExtensionField<Val> + TwoAdicField,
27    M: Mmcs<Challenge>,
28    Challenger: FieldChallenger<Val> + GrindingChallenger + CanObserve<M::Commitment>,
29    G: FriGenericConfig<Challenge>,
30{
31    assert!(!inputs.is_empty());
32    assert!(
33        inputs
34            .iter()
35            .tuple_windows()
36            .all(|(l, r)| l.len() >= r.len()),
37        "Inputs are not sorted in descending order of length."
38    );
39
40    let log_max_height = log2_strict_usize(inputs[0].len());
41    let log_min_height = log2_strict_usize(inputs.last().unwrap().len());
42    if config.log_final_poly_len > 0 {
43        assert!(log_min_height > config.log_final_poly_len + config.log_blowup);
44    }
45
46    let commit_phase_result = commit_phase(g, config, inputs, challenger);
47
48    let pow_witness = challenger.grind(config.proof_of_work_bits);
49
50    let query_proofs = info_span!("query phase").in_scope(|| {
51        iter::repeat_with(|| challenger.sample_bits(log_max_height + g.extra_query_index_bits()))
52            .take(config.num_queries)
53            .map(|index| QueryProof {
54                input_proof: open_input(index),
55                commit_phase_openings: answer_query(
56                    config,
57                    &commit_phase_result.data,
58                    index >> g.extra_query_index_bits(),
59                ),
60            })
61            .collect()
62    });
63
64    FriProof {
65        commit_phase_commits: commit_phase_result.commits,
66        query_proofs,
67        final_poly: commit_phase_result.final_poly,
68        pow_witness,
69    }
70}
71
72struct CommitPhaseResult<F: Field, M: Mmcs<F>> {
73    commits: Vec<M::Commitment>,
74    data: Vec<M::ProverData<RowMajorMatrix<F>>>,
75    final_poly: Vec<F>,
76}
77
78#[instrument(name = "commit phase", skip_all)]
79fn commit_phase<G, Val, Challenge, M, Challenger>(
80    g: &G,
81    config: &FriConfig<M>,
82    inputs: Vec<Vec<Challenge>>,
83    challenger: &mut Challenger,
84) -> CommitPhaseResult<Challenge, M>
85where
86    Val: Field,
87    Challenge: ExtensionField<Val> + TwoAdicField,
88    M: Mmcs<Challenge>,
89    Challenger: FieldChallenger<Val> + CanObserve<M::Commitment>,
90    G: FriGenericConfig<Challenge>,
91{
92    let mut inputs_iter = inputs.into_iter().peekable();
93    let mut folded = inputs_iter.next().unwrap();
94    let mut commits = vec![];
95    let mut data = vec![];
96
97    while folded.len() > config.blowup() * config.final_poly_len() {
98        let leaves = RowMajorMatrix::new(folded, 2);
99        let (commit, prover_data) = config.mmcs.commit_matrix(leaves);
100        challenger.observe(commit.clone());
101
102        let beta: Challenge = challenger.sample_ext_element();
103        // We passed ownership of `current` to the MMCS, so get a reference to it
104        let leaves = config.mmcs.get_matrices(&prover_data).pop().unwrap();
105        folded = g.fold_matrix(beta, leaves.as_view());
106
107        commits.push(commit);
108        data.push(prover_data);
109
110        if let Some(v) = inputs_iter.next_if(|v| v.len() == folded.len()) {
111            izip!(&mut folded, v).for_each(|(c, x)| *c += x);
112        }
113    }
114
115    // After repeated folding steps, we end up working over a coset hJ instead of the original
116    // domain. The IDFT we apply operates over a subgroup J, not hJ. This means the polynomial we
117    // recover is G(x), where G(x) = F(hx), and F is the polynomial whose evaluations we actually
118    // observed. For our current construction, this does not cause issues since degree properties
119    // and zero-checks remain valid. If we changed our domain construction (e.g., using multiple
120    // cosets), we would need to carefully reconsider these assumptions.
121
122    reverse_slice_index_bits(&mut folded);
123    // TODO: For better performance, we could run the IDFT on only the first half
124    //       (or less, depending on `log_blowup`) of `final_poly`.
125    let final_poly = debug_span!("idft final poly").in_scope(|| Radix2Dit::default().idft(folded));
126
127    // The evaluation domain is "blown-up" relative to the polynomial degree of `final_poly`,
128    // so all coefficients after the first final_poly_len should be zero.
129    debug_assert!(
130        final_poly
131            .iter()
132            .skip(1 << config.log_final_poly_len)
133            .all(|x| x.is_zero()),
134        "All coefficients beyond final_poly_len must be zero"
135    );
136
137    // Observe all coefficients of the final polynomial.
138    for &x in &final_poly {
139        challenger.observe_ext_element(x);
140    }
141
142    CommitPhaseResult {
143        commits,
144        data,
145        final_poly,
146    }
147}
148
149fn answer_query<F, M>(
150    config: &FriConfig<M>,
151    commit_phase_commits: &[M::ProverData<RowMajorMatrix<F>>],
152    index: usize,
153) -> Vec<CommitPhaseProofStep<F, M>>
154where
155    F: Field,
156    M: Mmcs<F>,
157{
158    commit_phase_commits
159        .iter()
160        .enumerate()
161        .map(|(i, commit)| {
162            let index_i = index >> i;
163            let index_i_sibling = index_i ^ 1;
164            let index_pair = index_i >> 1;
165
166            let (mut opened_rows, opening_proof) = config.mmcs.open_batch(index_pair, commit);
167            assert_eq!(opened_rows.len(), 1);
168            let opened_row = opened_rows.pop().unwrap();
169            assert_eq!(opened_row.len(), 2, "Committed data should be in pairs");
170            let sibling_value = opened_row[index_i_sibling % 2];
171
172            CommitPhaseProofStep {
173                sibling_value,
174                opening_proof,
175            }
176        })
177        .collect()
178}