1use alloc::vec;
2use alloc::vec::Vec;
3use core::iter;
4
5use itertools::{izip, Itertools};
6use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
7use p3_commit::Mmcs;
8use p3_dft::{Radix2Dit, TwoAdicSubgroupDft};
9use p3_field::{ExtensionField, Field, TwoAdicField};
10use p3_matrix::dense::RowMajorMatrix;
11use p3_util::{log2_strict_usize, reverse_slice_index_bits};
12use tracing::{debug_span, info_span, instrument};
13
14use crate::{CommitPhaseProofStep, FriConfig, FriGenericConfig, FriProof, QueryProof};
15
16#[instrument(name = "FRI prover", skip_all)]
17pub fn prove<G, Val, Challenge, M, Challenger>(
18 g: &G,
19 config: &FriConfig<M>,
20 inputs: Vec<Vec<Challenge>>,
21 challenger: &mut Challenger,
22 open_input: impl Fn(usize) -> G::InputProof,
23) -> FriProof<Challenge, M, Challenger::Witness, G::InputProof>
24where
25 Val: Field,
26 Challenge: ExtensionField<Val> + TwoAdicField,
27 M: Mmcs<Challenge>,
28 Challenger: FieldChallenger<Val> + GrindingChallenger + CanObserve<M::Commitment>,
29 G: FriGenericConfig<Challenge>,
30{
31 assert!(!inputs.is_empty());
32 assert!(
33 inputs
34 .iter()
35 .tuple_windows()
36 .all(|(l, r)| l.len() >= r.len()),
37 "Inputs are not sorted in descending order of length."
38 );
39
40 let log_max_height = log2_strict_usize(inputs[0].len());
41 let log_min_height = log2_strict_usize(inputs.last().unwrap().len());
42 if config.log_final_poly_len > 0 {
43 assert!(log_min_height > config.log_final_poly_len + config.log_blowup);
44 }
45
46 let commit_phase_result = commit_phase(g, config, inputs, challenger);
47
48 let pow_witness = challenger.grind(config.proof_of_work_bits);
49
50 let query_proofs = info_span!("query phase").in_scope(|| {
51 iter::repeat_with(|| challenger.sample_bits(log_max_height + g.extra_query_index_bits()))
52 .take(config.num_queries)
53 .map(|index| QueryProof {
54 input_proof: open_input(index),
55 commit_phase_openings: answer_query(
56 config,
57 &commit_phase_result.data,
58 index >> g.extra_query_index_bits(),
59 ),
60 })
61 .collect()
62 });
63
64 FriProof {
65 commit_phase_commits: commit_phase_result.commits,
66 query_proofs,
67 final_poly: commit_phase_result.final_poly,
68 pow_witness,
69 }
70}
71
72struct CommitPhaseResult<F: Field, M: Mmcs<F>> {
73 commits: Vec<M::Commitment>,
74 data: Vec<M::ProverData<RowMajorMatrix<F>>>,
75 final_poly: Vec<F>,
76}
77
78#[instrument(name = "commit phase", skip_all)]
79fn commit_phase<G, Val, Challenge, M, Challenger>(
80 g: &G,
81 config: &FriConfig<M>,
82 inputs: Vec<Vec<Challenge>>,
83 challenger: &mut Challenger,
84) -> CommitPhaseResult<Challenge, M>
85where
86 Val: Field,
87 Challenge: ExtensionField<Val> + TwoAdicField,
88 M: Mmcs<Challenge>,
89 Challenger: FieldChallenger<Val> + CanObserve<M::Commitment>,
90 G: FriGenericConfig<Challenge>,
91{
92 let mut inputs_iter = inputs.into_iter().peekable();
93 let mut folded = inputs_iter.next().unwrap();
94 let mut commits = vec![];
95 let mut data = vec![];
96
97 while folded.len() > config.blowup() * config.final_poly_len() {
98 let leaves = RowMajorMatrix::new(folded, 2);
99 let (commit, prover_data) = config.mmcs.commit_matrix(leaves);
100 challenger.observe(commit.clone());
101
102 let beta: Challenge = challenger.sample_ext_element();
103 let leaves = config.mmcs.get_matrices(&prover_data).pop().unwrap();
105 folded = g.fold_matrix(beta, leaves.as_view());
106
107 commits.push(commit);
108 data.push(prover_data);
109
110 if let Some(v) = inputs_iter.next_if(|v| v.len() == folded.len()) {
111 izip!(&mut folded, v).for_each(|(c, x)| *c += x);
112 }
113 }
114
115 reverse_slice_index_bits(&mut folded);
123 let final_poly = debug_span!("idft final poly").in_scope(|| Radix2Dit::default().idft(folded));
126
127 debug_assert!(
130 final_poly
131 .iter()
132 .skip(1 << config.log_final_poly_len)
133 .all(|x| x.is_zero()),
134 "All coefficients beyond final_poly_len must be zero"
135 );
136
137 for &x in &final_poly {
139 challenger.observe_ext_element(x);
140 }
141
142 CommitPhaseResult {
143 commits,
144 data,
145 final_poly,
146 }
147}
148
149fn answer_query<F, M>(
150 config: &FriConfig<M>,
151 commit_phase_commits: &[M::ProverData<RowMajorMatrix<F>>],
152 index: usize,
153) -> Vec<CommitPhaseProofStep<F, M>>
154where
155 F: Field,
156 M: Mmcs<F>,
157{
158 commit_phase_commits
159 .iter()
160 .enumerate()
161 .map(|(i, commit)| {
162 let index_i = index >> i;
163 let index_i_sibling = index_i ^ 1;
164 let index_pair = index_i >> 1;
165
166 let (mut opened_rows, opening_proof) = config.mmcs.open_batch(index_pair, commit);
167 assert_eq!(opened_rows.len(), 1);
168 let opened_row = opened_rows.pop().unwrap();
169 assert_eq!(opened_row.len(), 2, "Committed data should be in pairs");
170 let sibling_value = opened_row[index_i_sibling % 2];
171
172 CommitPhaseProofStep {
173 sibling_value,
174 opening_proof,
175 }
176 })
177 .collect()
178}