openvm_stark_backend/
engine.rs

1use itertools::{zip_eq, Itertools};
2
3use crate::{
4    air_builders::debug::debug_constraints_and_interactions,
5    config::{Com, PcsProof, RapPhaseSeqPartialProof, StarkGenericConfig, Val},
6    keygen::{
7        types::{MultiStarkProvingKey, MultiStarkVerifyingKey, StarkProvingKey},
8        MultiStarkKeygenBuilder,
9    },
10    proof::{OpeningProof, Proof},
11    prover::{
12        coordinator::Coordinator,
13        hal::{DeviceDataTransporter, ProverBackend, ProverDevice},
14        types::{AirProofRawInput, AirProvingContext, DeviceMultiStarkProvingKey, ProvingContext},
15        Prover,
16    },
17    verifier::{MultiTraceStarkVerifier, VerificationError},
18    AirRef,
19};
20
21/// Data for verifying a Stark proof.
22pub struct VerificationData<SC: StarkGenericConfig> {
23    pub vk: MultiStarkVerifyingKey<SC>,
24    pub proof: Proof<SC>,
25}
26
27/// A helper trait to collect the different steps in multi-trace STARK
28/// keygen and proving. Currently this trait is CPU specific.
29pub trait StarkEngine
30where
31    <Self::PB as ProverBackend>::OpeningProof:
32        Into<OpeningProof<PcsProof<Self::SC>, <Self::SC as StarkGenericConfig>::Challenge>>,
33    <Self::PB as ProverBackend>::RapPartialProof: Into<Option<RapPhaseSeqPartialProof<Self::SC>>>,
34{
35    type SC: StarkGenericConfig;
36    type PB: ProverBackend<
37        Val = Val<Self::SC>,
38        Challenge = <Self::SC as StarkGenericConfig>::Challenge,
39        Commitment = Com<Self::SC>,
40        Challenger = <Self::SC as StarkGenericConfig>::Challenger,
41    >;
42    type PD: ProverDevice<Self::PB> + DeviceDataTransporter<Self::SC, Self::PB>;
43
44    /// Stark config
45    fn config(&self) -> &Self::SC;
46
47    /// During keygen, the circuit may be optimized but it will **try** to keep the
48    /// constraint degree at most this value.
49    fn max_constraint_degree(&self) -> Option<usize> {
50        None
51    }
52
53    /// Creates a new challenger with a deterministic state.
54    /// Creating new challenger for prover and verifier separately will result in
55    /// them having the same starting state.
56    fn new_challenger(&self) -> <Self::SC as StarkGenericConfig>::Challenger;
57
58    fn keygen_builder(&self) -> MultiStarkKeygenBuilder<'_, Self::SC> {
59        let mut builder = MultiStarkKeygenBuilder::new(self.config());
60        if let Some(max_constraint_degree) = self.max_constraint_degree() {
61            builder.set_max_constraint_degree(max_constraint_degree);
62        }
63        builder
64    }
65
66    fn device(&self) -> &Self::PD;
67
68    fn prover(&self) -> Coordinator<Self::SC, Self::PB, Self::PD>;
69
70    fn verifier(&self) -> MultiTraceStarkVerifier<'_, Self::SC> {
71        MultiTraceStarkVerifier::new(self.config())
72    }
73
74    /// Add AIRs and get AIR IDs
75    fn set_up_keygen_builder(
76        &self,
77        keygen_builder: &mut MultiStarkKeygenBuilder<'_, Self::SC>,
78        airs: &[AirRef<Self::SC>],
79    ) -> Vec<usize> {
80        airs.iter()
81            .map(|air| keygen_builder.add_air(air.clone()))
82            .collect()
83    }
84
85    /// As a convenience, this function also transports the proving key from host to device.
86    /// Note that the [Self::prove] function starts from a [DeviceMultiStarkProvingKey],
87    /// which should be used if the proving key is already cached in device memory.
88    fn prove_then_verify(
89        &self,
90        pk: &MultiStarkProvingKey<Self::SC>,
91        ctx: ProvingContext<Self::PB>,
92    ) -> Result<Proof<Self::SC>, VerificationError> {
93        let pk_device = self.device().transport_pk_to_device(pk);
94        let proof = self.prove(&pk_device, ctx);
95        self.verify(&pk.get_vk(), &proof)?;
96        Ok(proof)
97    }
98
99    fn prove(
100        &self,
101        pk: &DeviceMultiStarkProvingKey<Self::PB>,
102        ctx: ProvingContext<Self::PB>,
103    ) -> Proof<Self::SC> {
104        let mpk_view = pk.view(ctx.air_ids());
105        let mut prover = self.prover();
106        let proof = prover.prove(mpk_view, ctx);
107        proof.into()
108    }
109
110    fn verify(
111        &self,
112        vk: &MultiStarkVerifyingKey<Self::SC>,
113        proof: &Proof<Self::SC>,
114    ) -> Result<(), VerificationError> {
115        let mut challenger = self.new_challenger();
116        let verifier = self.verifier();
117        verifier.verify(&mut challenger, vk, proof)
118    }
119
120    // mpk can be removed if we use BaseAir trait to regenerate preprocessed traces
121    fn debug(
122        &self,
123        airs: &[AirRef<Self::SC>],
124        pk: &[StarkProvingKey<Self::SC>],
125        proof_inputs: &[AirProofRawInput<Val<Self::SC>>],
126    ) {
127        let (trace_views, pvs): (Vec<_>, Vec<_>) = proof_inputs
128            .iter()
129            .map(|input| {
130                let mut views = input
131                    .cached_mains
132                    .iter()
133                    .map(|trace| trace.as_view())
134                    .collect_vec();
135                if let Some(trace) = input.common_main.as_ref() {
136                    views.push(trace.as_view());
137                }
138                (views, input.public_values.clone())
139            })
140            .unzip();
141        debug_constraints_and_interactions(airs, pk, &trace_views, &pvs);
142    }
143
144    /// Runs a single end-to-end test for a given set of chips and traces partitions.
145    /// This includes proving/verifying key generation, creating a proof, and verifying the proof.
146    fn run_test_impl(
147        &self,
148        airs: Vec<AirRef<Self::SC>>,
149        ctx: Vec<AirProvingContext<Self::PB>>,
150    ) -> Result<VerificationData<Self::SC>, VerificationError> {
151        let mut keygen_builder = self.keygen_builder();
152        let air_ids = self.set_up_keygen_builder(&mut keygen_builder, &airs);
153        let pk = keygen_builder.generate_pk();
154        let device = self.prover().device;
155        let proof_inputs = ctx
156            .iter()
157            .map(|air_ctx| {
158                let cached_mains = air_ctx
159                    .cached_mains
160                    .iter()
161                    .map(|pre| device.transport_matrix_from_device_to_host(&pre.trace))
162                    .collect_vec();
163                let common_main = air_ctx
164                    .common_main
165                    .as_ref()
166                    .map(|m| device.transport_matrix_from_device_to_host(m));
167                let public_values = air_ctx.public_values.clone();
168                AirProofRawInput {
169                    cached_mains,
170                    common_main,
171                    public_values,
172                }
173            })
174            .collect_vec();
175        self.debug(&airs, &pk.per_air, &proof_inputs);
176        let vk = pk.get_vk();
177        let ctx = ProvingContext {
178            per_air: zip_eq(air_ids, ctx).collect(),
179        };
180        let proof = self.prove_then_verify(&pk, ctx)?;
181        Ok(VerificationData { vk, proof })
182    }
183}