use alloc::vec;
use alloc::vec::Vec;
use itertools::Itertools;
use p3_air::{Air, BaseAir};
use p3_challenger::{CanObserve, CanSample, FieldChallenger};
use p3_commit::{Pcs, PolynomialSpace};
use p3_field::{AbstractExtensionField, AbstractField, Field};
use p3_matrix::dense::RowMajorMatrixView;
use p3_matrix::stack::VerticalPair;
use tracing::instrument;
use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder};
use crate::{PcsError, Proof, StarkGenericConfig, Val, VerifierConstraintFolder};
/// Verifies a uni-STARK `proof` for `air` against the claimed `public_values`.
///
/// The function first checks that the proof has the expected shape. It then absorbs the
/// proof's commitments into `challenger` and samples the challenges `alpha` (constraint
/// folding) and `zeta` (out-of-domain point). Next it checks the PCS opening proof for the
/// trace (at `zeta` and its successor) and for the quotient chunks (at `zeta`). Finally it
/// checks that the folded constraints at `zeta`, multiplied by `inv_zeroifier`, equal the
/// quotient recombined from its chunks.
///
/// NOTE(review): `challenger` presumably has to be in the same state the prover's
/// challenger was in when proving began, since both sides must derive identical challenges.
/// Confirm this against the prover.
///
/// # Errors
///
/// * [`VerificationError::InvalidProofShape`] in two cases: `degree_bits` (plus the
///   AIR's log quotient degree) is too large for the domain size to fit in a `usize`, or
///   the opened trace rows / quotient chunks do not have the widths and counts implied
///   by `air`.
/// * [`VerificationError::InvalidOpeningArgument`] if `pcs.verify` rejects the opening proof.
/// * [`VerificationError::OodEvaluationMismatch`] if the out-of-domain constraint check
///   fails.
///
/// # Panics
///
/// Panics if `trace_domain.next_point(zeta)` returns `None`. NOTE(review): it is unclear
/// from here whether `natural_domain_for_degree` can also panic for degrees that are
/// representable in a `usize` but exceed what the field/PCS supports (e.g. its
/// two-adicity). If so, `degree_bits` should be bounded more tightly below.
#[instrument(skip_all)]
pub fn verify<SC, A>(
    config: &SC,
    air: &A,
    challenger: &mut SC::Challenger,
    proof: &Proof<SC>,
    public_values: &Vec<Val<SC>>,
) -> Result<(), VerificationError<PcsError<SC>>>
where
    SC: StarkGenericConfig,
    A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<VerifierConstraintFolder<'a, SC>>,
{
    let Proof {
        commitments,
        opened_values,
        opening_proof,
        degree_bits,
    } = proof;
    let degree_bits = *degree_bits;

    let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air, 0, public_values.len());

    // `degree_bits` is read from the proof, so it is attacker-controlled. Reject it before
    // using it as a shift amount. Otherwise `1 << degree_bits` (or the addition below)
    // would panic in debug builds or silently wrap in release builds instead of returning
    // an error.
    let log_quotient_domain_size = match degree_bits.checked_add(log_quotient_degree) {
        Some(bits) if bits < usize::BITS as usize => bits,
        _ => return Err(VerificationError::InvalidProofShape),
    };
    // Both shifts are in range: each exponent is <= `log_quotient_domain_size` < usize::BITS.
    let degree = 1usize << degree_bits;
    let quotient_degree = 1usize << log_quotient_degree;

    let pcs = config.pcs();
    let trace_domain = pcs.natural_domain_for_degree(degree);
    let quotient_domain = trace_domain.create_disjoint_domain(1usize << log_quotient_domain_size);
    let quotient_chunks_domains = quotient_domain.split_domains(quotient_degree);

    // Every opened row must match the AIR width. There must be one quotient chunk per
    // split domain, and each chunk must hold `D` values (one per extension coordinate).
    let air_width = <A as BaseAir<Val<SC>>>::width(air);
    let valid_shape = opened_values.trace_local.len() == air_width
        && opened_values.trace_next.len() == air_width
        && opened_values.quotient_chunks.len() == quotient_degree
        && opened_values
            .quotient_chunks
            .iter()
            .all(|qc| qc.len() == <SC::Challenge as AbstractExtensionField<Val<SC>>>::D);
    if !valid_shape {
        return Err(VerificationError::InvalidProofShape);
    }

    // Fiat-Shamir transcript: the observe/sample order here fixes which challenges are
    // derived, so it is kept exactly as before.
    challenger.observe(Val::<SC>::from_canonical_usize(degree_bits));
    challenger.observe(commitments.trace.clone());
    challenger.observe_slice(public_values);
    let alpha: SC::Challenge = challenger.sample_ext_element();
    challenger.observe(commitments.quotient_chunks.clone());
    let zeta: SC::Challenge = challenger.sample();
    let zeta_next = trace_domain.next_point(zeta).unwrap();

    pcs.verify(
        vec![
            (
                commitments.trace.clone(),
                vec![(
                    trace_domain,
                    vec![
                        (zeta, opened_values.trace_local.clone()),
                        (zeta_next, opened_values.trace_next.clone()),
                    ],
                )],
            ),
            (
                commitments.quotient_chunks.clone(),
                quotient_chunks_domains
                    .iter()
                    .zip(&opened_values.quotient_chunks)
                    .map(|(domain, values)| (*domain, vec![(zeta, values.clone())]))
                    .collect_vec(),
            ),
        ],
        opening_proof,
        challenger,
    )
    .map_err(VerificationError::InvalidOpeningArgument)?;

    // For chunk domain i, compute prod_{j != i} Z_{D_j}(zeta) / Z_{D_j}(first_point(D_i)).
    // These are the coefficients that combine the per-chunk openings into the full
    // quotient evaluation at `zeta`.
    let zps = quotient_chunks_domains
        .iter()
        .enumerate()
        .map(|(i, domain)| {
            quotient_chunks_domains
                .iter()
                .enumerate()
                .filter(|(j, _)| *j != i)
                .map(|(_, other_domain)| {
                    other_domain.zp_at_point(zeta)
                        * other_domain.zp_at_point(domain.first_point()).inverse()
                })
                .product::<SC::Challenge>()
        })
        .collect_vec();

    // Each chunk's `D` opened values are the coordinates of one extension element in the
    // monomial basis. Rebuild that element and weight it by the chunk's coefficient.
    // `zps` and `quotient_chunks` have equal length (both `quotient_degree`, enforced by
    // the shape check), so zipping loses nothing.
    let quotient = zps
        .iter()
        .zip(&opened_values.quotient_chunks)
        .map(|(&zp, chunk)| {
            chunk
                .iter()
                .enumerate()
                .map(|(e_i, &c)| zp * SC::Challenge::monomial(e_i) * c)
                .sum::<SC::Challenge>()
        })
        .sum::<SC::Challenge>();

    // Evaluate the AIR's constraints at `zeta`, folded together with powers of `alpha`.
    let sels = trace_domain.selectors_at_point(zeta);
    let main = VerticalPair::new(
        RowMajorMatrixView::new_row(&opened_values.trace_local),
        RowMajorMatrixView::new_row(&opened_values.trace_next),
    );
    let mut folder = VerifierConstraintFolder {
        main,
        public_values,
        is_first_row: sels.is_first_row,
        is_last_row: sels.is_last_row,
        is_transition: sels.is_transition,
        alpha,
        accumulator: SC::Challenge::ZERO,
    };
    air.eval(&mut folder);
    let folded_constraints = folder.accumulator;

    // `inv_zeroifier` is presumably 1 / Z_H(zeta) for the trace domain H (confirm in the
    // PolynomialSpace impl). The identity checked is constraints(zeta) / Z_H(zeta) == quotient(zeta).
    if folded_constraints * sels.inv_zeroifier != quotient {
        return Err(VerificationError::OodEvaluationMismatch);
    }
    Ok(())
}
/// The ways in which STARK verification can reject a proof.
///
/// `E` is the error type produced by the polynomial commitment scheme when it rejects an
/// opening proof.
#[derive(Debug)]
pub enum VerificationError<E> {
    /// The proof's structure does not match what the verifier expects, e.g. opened-value
    /// widths or the number/size of quotient chunks.
    InvalidProofShape,
    /// The PCS opening proof did not verify; wraps the PCS's own error.
    InvalidOpeningArgument(E),
    /// The out-of-domain check failed: the folded constraints at `zeta` did not match the
    /// quotient reconstructed from the opened chunks.
    OodEvaluationMismatch,
}