// openvm_stark_backend/engine.rs
use std::sync::Arc;
use itertools::izip;
use p3_matrix::dense::DenseMatrix;
use crate::{
config::{StarkGenericConfig, Val},
keygen::{
types::{MultiStarkProvingKey, MultiStarkVerifyingKey},
MultiStarkKeygenBuilder,
},
prover::{
types::{AirProofInput, Proof, ProofInput, TraceCommitter},
MultiTraceStarkProver,
},
rap::AnyRap,
verifier::{MultiTraceStarkVerifier, VerificationError},
};
/// A verifying key paired with a proof.
///
/// `StarkEngine::run_test_impl` returns this after the proof has been checked against the
/// verifying key.
pub struct VerificationData<SC: StarkGenericConfig> {
/// Verifying key that `proof` is checked against.
pub vk: MultiStarkVerifyingKey<SC>,
/// The proof itself.
pub proof: Proof<SC>,
}
/// Bundles a STARK configuration with a way of creating challengers, and provides default
/// implementations of key generation, proving and verification on top of them.
///
/// Implementors only need to supply [`config`](StarkEngine::config) and
/// [`new_challenger`](StarkEngine::new_challenger); every other method has a default body.
pub trait StarkEngine<SC: StarkGenericConfig> {
    /// Returns the configuration that the keygen builder, prover and verifier are built from.
    fn config(&self) -> &SC;

    /// Creates a challenger.
    ///
    /// [`prove`](StarkEngine::prove) and [`verify`](StarkEngine::verify) each call this to
    /// obtain their own challenger instance.
    fn new_challenger(&self) -> SC::Challenger;

    /// Creates a keygen builder over [`config`](StarkEngine::config).
    fn keygen_builder(&self) -> MultiStarkKeygenBuilder<SC> {
        MultiStarkKeygenBuilder::new(self.config())
    }

    /// Creates a prover over [`config`](StarkEngine::config).
    fn prover(&self) -> MultiTraceStarkProver<SC> {
        MultiTraceStarkProver::new(self.config())
    }

    /// Creates a verifier over [`config`](StarkEngine::config).
    fn verifier(&self) -> MultiTraceStarkVerifier<SC> {
        MultiTraceStarkVerifier::new(self.config())
    }

    /// Builds the per-AIR inputs with `AirProofInput::multiple_simple` from `chips`, `traces`
    /// and `public_values`, then forwards them to [`run_test_impl`](StarkEngine::run_test_impl).
    ///
    /// # Errors
    ///
    /// Returns the [`VerificationError`] reported by verification.
    fn run_simple_test_impl(
        &self,
        chips: Vec<Arc<dyn AnyRap<SC>>>,
        traces: Vec<DenseMatrix<Val<SC>>>,
        public_values: Vec<Vec<Val<SC>>>,
    ) -> Result<VerificationData<SC>, VerificationError> {
        self.run_test_impl(AirProofInput::multiple_simple(chips, traces, public_values))
    }

    /// Runs key generation, proving and verification for the given AIR inputs.
    ///
    /// Every AIR is registered with a new keygen builder, a proving key is generated, the
    /// verifying key is derived from it, and the proof is produced and verified. On success
    /// the verifying key and proof are returned together.
    ///
    /// # Errors
    ///
    /// Returns the [`VerificationError`] reported by [`verify`](StarkEngine::verify).
    ///
    /// # Panics
    ///
    /// Panics under the conditions listed for
    /// [`set_up_keygen_builder`](StarkEngine::set_up_keygen_builder) and
    /// [`prove`](StarkEngine::prove).
    fn run_test_impl(
        &self,
        air_proof_inputs: Vec<AirProofInput<SC>>,
    ) -> Result<VerificationData<SC>, VerificationError> {
        let mut keygen_builder = self.keygen_builder();
        let air_ids = self.set_up_keygen_builder(&mut keygen_builder, &air_proof_inputs);
        // Pair every input with the id the keygen builder assigned to its AIR.
        let proof_input = ProofInput {
            per_air: izip!(air_ids, air_proof_inputs).collect(),
        };
        let pk = keygen_builder.generate_pk();
        let vk = pk.get_vk();
        let proof = self.prove(&pk, proof_input);
        self.verify(&vk, &proof)?;
        Ok(VerificationData { vk, proof })
    }

    /// Adds the AIR of every input to `keygen_builder` and returns the values returned by
    /// `add_air`, in input order.
    ///
    /// Before an AIR is added, its raw traces are checked against the widths the AIR declares.
    ///
    /// # Panics
    ///
    /// Panics if, for any input,
    /// - the number of cached main traces differs from the number of cached main widths
    ///   declared by the AIR,
    /// - the AIR declares a common main width of zero but a common main trace is present,
    /// - the AIR declares a non-zero common main width but no common main trace is present, or
    /// - the common main trace width differs from the declared width.
    fn set_up_keygen_builder(
        &self,
        keygen_builder: &mut MultiStarkKeygenBuilder<'_, SC>,
        air_proof_inputs: &[AirProofInput<SC>],
    ) -> Vec<usize> {
        air_proof_inputs
            .iter()
            .map(|air_proof_input| {
                let air = air_proof_input.air.clone();
                assert_eq!(
                    air_proof_input.raw.cached_mains.len(),
                    air.cached_main_widths().len(),
                    "number of cached main traces does not match the AIR's cached main widths"
                );
                let common_main_width = air.common_main_width();
                if common_main_width == 0 {
                    assert!(
                        air_proof_input.raw.common_main.is_none(),
                        "AIR has common main width 0 but a common main trace was provided"
                    );
                } else {
                    let common_main = air_proof_input
                        .raw
                        .common_main
                        .as_ref()
                        .expect("AIR has non-zero common main width but no common main trace");
                    assert_eq!(
                        common_main.width, common_main_width,
                        "common main trace width does not match the AIR's common main width"
                    );
                }
                keygen_builder.add_air(air)
            })
            .collect()
    }

    /// Proves `proof_input` under `pk` and verifies the result against the verifying key
    /// derived from `pk`.
    ///
    /// # Errors
    ///
    /// Returns the [`VerificationError`] reported by [`verify`](StarkEngine::verify).
    fn prove_then_verify(
        &self,
        pk: &MultiStarkProvingKey<SC>,
        proof_input: ProofInput<SC>,
    ) -> Result<(), VerificationError> {
        let proof = self.prove(pk, proof_input);
        self.verify(&pk.get_vk(), &proof)
    }

    /// Generates a proof for `proof_input` under the proving key `pk`.
    ///
    /// For every AIR input whose `cached_mains_pdata` is empty while raw cached main traces
    /// are present, each cached main trace is committed on its own and the results are stored
    /// in `cached_mains_pdata` before the prover runs.
    ///
    /// When the `bench-metrics` feature is enabled, the wall-clock duration of the prover call
    /// (in milliseconds) is written to the `stark_prove_excluding_trace_time_ms` gauge.
    ///
    /// # Panics
    ///
    /// Panics if an input's `cached_mains_pdata` and raw cached main traces differ in length,
    /// except when `cached_mains_pdata` is empty and is filled in as described above.
    fn prove(&self, pk: &MultiStarkProvingKey<SC>, proof_input: ProofInput<SC>) -> Proof<SC> {
        let prover = self.prover();
        let committer = TraceCommitter::new(prover.pcs());

        let air_proof_inputs = proof_input
            .per_air
            .into_iter()
            .map(|(air_id, mut air_proof_input)| {
                if air_proof_input.cached_mains_pdata.is_empty()
                    && !air_proof_input.raw.cached_mains.is_empty()
                {
                    // Commitments were not supplied: commit to each cached trace separately.
                    air_proof_input.cached_mains_pdata = air_proof_input
                        .raw
                        .cached_mains
                        .iter()
                        .map(|trace| committer.commit(vec![trace.as_ref().clone()]))
                        .collect();
                } else {
                    assert_eq!(
                        air_proof_input.cached_mains_pdata.len(),
                        air_proof_input.raw.cached_mains.len(),
                        "number of cached main commitments does not match the cached main traces"
                    );
                }
                (air_id, air_proof_input)
            })
            .collect();
        let proof_input = ProofInput {
            per_air: air_proof_inputs,
        };

        let mut challenger = self.new_challenger();

        #[cfg(feature = "bench-metrics")]
        let prove_start = std::time::Instant::now();
        let proof = prover.prove(&mut challenger, pk, proof_input);
        #[cfg(feature = "bench-metrics")]
        metrics::gauge!("stark_prove_excluding_trace_time_ms")
            .set(prove_start.elapsed().as_millis() as f64);

        proof
    }

    /// Verifies `proof` against `vk` using a newly created challenger and verifier.
    ///
    /// # Errors
    ///
    /// Returns the [`VerificationError`] reported by the verifier.
    fn verify(
        &self,
        vk: &MultiStarkVerifyingKey<SC>,
        proof: &Proof<SC>,
    ) -> Result<(), VerificationError> {
        let mut challenger = self.new_challenger();
        let verifier = self.verifier();
        verifier.verify(&mut challenger, vk, proof)
    }
}