// openvm_stark_backend/prover/types.rs
use std::sync::Arc;
use derivative::Derivative;
use itertools::Itertools;
use p3_field::Field;
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use serde::{Deserialize, Serialize};
pub use super::trace::{ProverTraceData, TraceCommitter};
use crate::{
config::{Com, RapPhaseSeqPartialProof, StarkGenericConfig, Val},
keygen::types::{MultiStarkProvingKey, MultiStarkVerifyingKey},
prover::opener::OpeningProof,
rap::AnyRap,
};
/// Group of commitments carried inside a [`Proof`].
///
/// `Serialize`, `Deserialize` and `Clone` are only provided when `Com<SC>`
/// supports the corresponding trait, as stated by the explicit bounds below.
#[derive(Serialize, Deserialize, Derivative)]
#[serde(bound(
    serialize = "Com<SC>: Serialize",
    deserialize = "Com<SC>: Deserialize<'de>"
))]
#[derivative(Clone(bound = "Com<SC>: Clone"))]
pub struct Commitments<SC: StarkGenericConfig> {
    /// Commitments associated with the main trace; there may be several.
    pub main_trace: Vec<Com<SC>>,
    /// Commitments labelled "after challenge".
    // NOTE(review): presumably one entry per challenge phase — confirm against the prover.
    pub after_challenge: Vec<Com<SC>>,
    /// The single commitment associated with the quotient.
    pub quotient: Com<SC>,
}
/// A complete proof: commitments, the opening proof, per-AIR data and an
/// optional partial proof for the RAP phase sequence.
///
/// The empty `serde` bound replaces the bounds serde would otherwise infer for
/// `SC`; cloning requires `Com<SC>: Clone`.
#[derive(Serialize, Deserialize, Derivative)]
#[serde(bound = "")]
#[derivative(Clone(bound = "Com<SC>: Clone"))]
pub struct Proof<SC: StarkGenericConfig> {
    /// Commitments included in this proof.
    pub commitments: Commitments<SC>,
    /// Opening proof produced by the opener.
    pub opening: OpeningProof<SC>,
    /// One record per AIR covered by this proof.
    pub per_air: Vec<AirProofData<SC>>,
    /// Partial proof for the RAP phase sequence; `None` when there is none.
    pub rap_phase_seq_proof: Option<RapPhaseSeqPartialProof<SC>>,
}
/// Per-AIR portion of a [`Proof`].
///
/// The empty `serde` bound replaces the bounds serde would otherwise infer for
/// `SC`; cloning requires `SC::Challenge: Clone`.
#[derive(Serialize, Deserialize, Derivative)]
#[serde(bound = "")]
#[derivative(Clone(bound = "SC::Challenge: Clone"))]
pub struct AirProofData<SC: StarkGenericConfig> {
    /// Identifier of the AIR this record belongs to.
    pub air_id: usize,
    /// Degree recorded for this AIR.
    // NOTE(review): presumably the trace height — confirm against the prover.
    pub degree: usize,
    /// Exposed values after challenge, as a list of lists of challenge-field elements.
    // NOTE(review): the outer index is presumably the challenge phase — confirm.
    pub exposed_values_after_challenge: Vec<Vec<SC::Challenge>>,
    /// Public values of this AIR.
    pub public_values: Vec<Val<SC>>,
}
/// Prover input for a whole proof: a list of `(air id, per-AIR input)` pairs.
pub struct ProofInput<SC: StarkGenericConfig> {
    /// Each entry pairs an AIR id with the input for that AIR.
    ///
    /// `MultiStarkVerifyingKey::validate` accepts this list only when every id
    /// is below the key's AIR count and the ids are strictly increasing.
    pub per_air: Vec<(usize, AirProofInput<SC>)>,
}
impl<SC: StarkGenericConfig> ProofInput<SC> {
pub fn new(per_air: Vec<(usize, AirProofInput<SC>)>) -> Self {
Self { per_air }
}
pub fn into_air_proof_input_vec(self) -> Vec<AirProofInput<SC>> {
self.per_air.into_iter().map(|(_, x)| x).collect()
}
}
/// A raw trace matrix together with the prover data kept for it.
///
/// Serialization is available when `ProverTraceData<SC>` is serializable;
/// the derived `Clone` is bounded on `Com<SC>: Clone`.
#[derive(Serialize, Deserialize, Derivative)]
#[serde(bound(
    serialize = "ProverTraceData<SC>: Serialize",
    deserialize = "ProverTraceData<SC>: Deserialize<'de>"
))]
#[derivative(Clone(bound = "Com<SC>: Clone"))]
pub struct CommittedTraceData<SC: StarkGenericConfig> {
    /// The trace matrix itself, held behind an `Arc`.
    pub raw_data: Arc<RowMajorMatrix<Val<SC>>>,
    /// Prover data associated with `raw_data`.
    // NOTE(review): presumably the commitment plus PCS prover data — confirm in `super::trace`.
    pub prover_data: ProverTraceData<SC>,
}
/// Everything supplied for a single AIR when proving: the AIR object, prover
/// data for its cached main traces, and the raw trace/public-value input.
///
/// The derived `Clone` is bounded on `Com<SC>: Clone`.
#[derive(Derivative)]
#[derivative(Clone(bound = "Com<SC>: Clone"))]
pub struct AirProofInput<SC: StarkGenericConfig> {
    /// The AIR, as a shared trait object.
    pub air: Arc<dyn AnyRap<SC>>,
    /// Prover data for the cached main traces.
    // NOTE(review): presumably parallel to `raw.cached_mains` — confirm with callers.
    pub cached_mains_pdata: Vec<ProverTraceData<SC>>,
    /// Raw matrices and public values for this AIR.
    pub raw: AirProofRawInput<Val<SC>>,
}
/// Raw (uncommitted) input for one AIR: trace matrices and public values.
///
/// `height` expects every matrix stored here to have the same height.
#[derive(Clone, Debug)]
pub struct AirProofRawInput<F: Field> {
    /// Cached main trace matrices, each held behind an `Arc`.
    pub cached_mains: Vec<Arc<RowMajorMatrix<F>>>,
    /// The common main trace matrix, if there is one.
    pub common_main: Option<RowMajorMatrix<F>>,
    /// Public values for this AIR.
    pub public_values: Vec<F>,
}
impl<SC: StarkGenericConfig> Proof<SC> {
pub fn get_air_ids(&self) -> Vec<usize> {
self.per_air.iter().map(|p| p.air_id).collect()
}
pub fn get_public_values(&self) -> Vec<Vec<Val<SC>>> {
self.per_air
.iter()
.map(|p| p.public_values.clone())
.collect()
}
}
impl<SC: StarkGenericConfig> ProofInput<SC> {
    /// Sorts `per_air` in place by ascending AIR id.
    ///
    /// The sort is stable, so entries that share an id keep their relative order.
    pub fn sort(&mut self) {
        self.per_air.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
    }
}
impl<SC: StarkGenericConfig> MultiStarkVerifyingKey<SC> {
    /// Checks that `proof_input` is well-formed with respect to this key.
    ///
    /// Returns `true` only when both conditions hold:
    /// 1. every AIR id is less than the number of AIRs in this key, and
    /// 2. the AIR ids are strictly increasing, which also rules out duplicates.
    ///
    /// An empty input satisfies both conditions.
    pub fn validate(&self, proof_input: &ProofInput<SC>) -> bool {
        let num_airs = self.per_air.len();
        let entries = &proof_input.per_air;

        // `&&` short-circuits, so the ordering check only runs when every id is in range.
        entries.iter().all(|(air_id, _)| *air_id < num_airs)
            && entries
                .iter()
                .tuple_windows()
                .all(|(prev, next)| prev.0 < next.0)
    }
}
impl<SC: StarkGenericConfig> MultiStarkProvingKey<SC> {
    /// Validates `proof_input` by delegating to the `validate` method of the
    /// verifying key returned by `get_vk`.
    pub fn validate(&self, proof_input: &ProofInput<SC>) -> bool {
        let vk = self.get_vk();
        vk.validate(proof_input)
    }
}
impl<F: Field> AirProofRawInput<F> {
    /// Returns the height shared by all trace matrices of this input.
    ///
    /// The first available matrix (cached mains first, then the common main)
    /// sets the expected height. Returns `0` when there are no matrices at all.
    ///
    /// # Panics
    ///
    /// Panics if any matrix has a height different from the first one.
    pub fn height(&self) -> usize {
        let cached_heights = self.cached_mains.iter().map(|trace| trace.height());
        let common_height = self.common_main.iter().map(|trace| trace.height());
        // Cached mains come first so the reference height is chosen in the same order as before.
        let mut heights = cached_heights.chain(common_height);

        match heights.next() {
            None => 0,
            Some(expected) => {
                heights.for_each(|actual| assert_eq!(expected, actual));
                expected
            }
        }
    }
}