1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::Itertools;
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_commit::Mmcs;
7use p3_field::{ExtensionField, Field, TwoAdicField};
8use p3_matrix::Dimensions;
9use p3_util::reverse_bits_len;
10use p3_util::zip_eq::zip_eq;
11
12use crate::{CommitPhaseProofStep, FriConfig, FriGenericConfig, FriProof};
13
/// Reasons the FRI verifier can reject a proof.
///
/// `CommitMmcsErr` is the error type of the MMCS used to check commit-phase openings, and
/// `InputError` is the error type associated with opening the inputs.
#[derive(Debug)]
pub enum FriError<CommitMmcsErr, InputError> {
    /// The proof does not have the structure the verifier expects: for example the final
    /// polynomial has the wrong number of coefficients, the number of query proofs differs
    /// from the configured number of queries, the number of commit-phase openings does not
    /// match the number of folding rounds, or reduced openings are left over after folding.
    InvalidProofShape,
    /// The MMCS rejected a commit-phase opening; the underlying MMCS error is attached.
    CommitPhaseMmcsError(CommitMmcsErr),
    /// An input-related failure carrying an `InputError`.
    ///
    /// NOTE(review): this variant is not constructed in the visible verifier code; it is
    /// presumably produced by the `open_input` callback passed to `verify` -- confirm.
    InputError(InputError),
    /// The value obtained by folding a query down to the final domain differs from the
    /// final polynomial evaluated at the corresponding point.
    FinalPolyMismatch,
    /// The proof-of-work witness was rejected by the challenger.
    InvalidPowWitness,
    /// The first reduced opening of a query is not at the maximum height, so there is no
    /// value to start folding from.
    MissingInput,
}
23
24pub fn verify<G, Val, Challenge, M, Challenger>(
25 g: &G,
26 config: &FriConfig<M>,
27 proof: &FriProof<Challenge, M, Challenger::Witness, G::InputProof>,
28 challenger: &mut Challenger,
29 open_input: impl Fn(
30 usize,
31 &G::InputProof,
32 ) -> Result<Vec<(usize, Challenge)>, FriError<M::Error, G::InputError>>,
33) -> Result<(), FriError<M::Error, G::InputError>>
34where
35 Val: Field,
36 Challenge: ExtensionField<Val> + TwoAdicField,
37 M: Mmcs<Challenge>,
38 Challenger: FieldChallenger<Val> + GrindingChallenger + CanObserve<M::Commitment>,
39 G: FriGenericConfig<Challenge>,
40{
41 let betas: Vec<Challenge> = proof
42 .commit_phase_commits
43 .iter()
44 .map(|comm| {
45 challenger.observe(comm.clone());
46 challenger.sample_ext_element()
47 })
48 .collect();
49
50 if proof.final_poly.len() != config.final_poly_len() {
51 return Err(FriError::InvalidProofShape);
52 }
53
54 proof
56 .final_poly
57 .iter()
58 .for_each(|x| challenger.observe_ext_element(*x));
59
60 if proof.query_proofs.len() != config.num_queries {
61 return Err(FriError::InvalidProofShape);
62 }
63
64 if !challenger.check_witness(config.proof_of_work_bits, proof.pow_witness) {
66 return Err(FriError::InvalidPowWitness);
67 }
68
69 let log_max_height =
71 proof.commit_phase_commits.len() + config.log_blowup + config.log_final_poly_len;
72
73 let log_final_height = config.log_blowup + config.log_final_poly_len;
75
76 for qp in &proof.query_proofs {
77 let index = challenger.sample_bits(log_max_height + g.extra_query_index_bits());
78 let ro = open_input(index, &qp.input_proof)?;
79
80 debug_assert!(
81 ro.iter().tuple_windows().all(|((l, _), (r, _))| l > r),
82 "reduced openings sorted by height descending"
83 );
84
85 let mut domain_index = index >> g.extra_query_index_bits();
86
87 let folded_eval = verify_query(
92 g,
93 config,
94 &mut domain_index,
95 zip_eq(
96 zip_eq(
97 &betas,
98 &proof.commit_phase_commits,
99 FriError::InvalidProofShape,
100 )?,
101 &qp.commit_phase_openings,
102 FriError::InvalidProofShape,
103 )?,
104 ro,
105 log_max_height,
106 log_final_height,
107 )?;
108
109 let x = Challenge::two_adic_generator(log_max_height)
113 .exp_u64(reverse_bits_len(domain_index, log_max_height) as u64);
114
115 let mut eval = Challenge::ZERO;
117 for &coeff in proof.final_poly.iter().rev() {
118 eval = eval * x + coeff;
119 }
120
121 if eval != folded_eval {
122 return Err(FriError::FinalPolyMismatch);
123 }
124 }
125
126 Ok(())
127}
128
129type CommitStep<'a, F, M> = (
130 (
131 &'a F, &'a <M as Mmcs<F>>::Commitment, ),
134 &'a CommitPhaseProofStep<F, M>, );
136
137fn verify_query<'a, G, F, M>(
145 g: &G,
146 config: &FriConfig<M>,
147 index: &mut usize,
148 steps: impl ExactSizeIterator<Item = CommitStep<'a, F, M>>,
149 reduced_openings: Vec<(usize, F)>,
150 log_max_height: usize,
151 log_final_height: usize,
152) -> Result<F, FriError<M::Error, G::InputError>>
153where
154 F: Field,
155 M: Mmcs<F> + 'a,
156 G: FriGenericConfig<F>,
157{
158 let mut ro_iter = reduced_openings.into_iter().peekable();
159 let mut folded_eval = ro_iter
160 .next_if(|(lh, _)| *lh == log_max_height)
161 .map(|(_, ro)| ro)
162 .ok_or(FriError::MissingInput)?;
163
164 for (log_folded_height, ((&beta, comm), opening)) in zip_eq(
167 (log_final_height..log_max_height).rev(),
168 steps,
169 FriError::InvalidProofShape,
170 )? {
171 let index_sibling = *index ^ 1;
173
174 let mut evals = vec![folded_eval; 2];
175 evals[index_sibling % 2] = opening.sibling_value;
176
177 let dims = &[Dimensions {
178 width: 2,
179 height: 1 << log_folded_height,
180 }];
181
182 *index >>= 1;
184
185 config
187 .mmcs
188 .verify_batch(comm, dims, *index, &[evals.clone()], &opening.opening_proof)
189 .map_err(FriError::CommitPhaseMmcsError)?;
190
191 folded_eval = g.fold_row(*index, log_folded_height, beta, evals.into_iter());
193
194 if let Some((_, ro)) = ro_iter.next_if(|(lh, _)| *lh == log_folded_height) {
205 folded_eval += beta.square() * ro;
206 }
207 }
208
209 if ro_iter.next().is_some() {
211 return Err(FriError::InvalidProofShape);
212 }
213
214 Ok(folded_eval)
217}