openvm_stark_backend/
engine.rs

1use itertools::{zip_eq, Itertools};
2
3use crate::{
4    air_builders::debug::debug_constraints_and_interactions,
5    config::{Com, RapPhaseSeqPartialProof, StarkGenericConfig, Val},
6    keygen::{
7        types::{MultiStarkProvingKey, MultiStarkVerifyingKey, StarkProvingKey},
8        MultiStarkKeygenBuilder,
9    },
10    proof::{OpeningProof, Proof},
11    prover::{
12        coordinator::Coordinator,
13        hal::{DeviceDataTransporter, ProverBackend, ProverDevice},
14        types::{AirProofRawInput, AirProvingContext, DeviceMultiStarkProvingKey, ProvingContext},
15        Prover,
16    },
17    verifier::{MultiTraceStarkVerifier, VerificationError},
18    AirRef,
19};
20
21/// Data for verifying a Stark proof.
22pub struct VerificationData<SC: StarkGenericConfig> {
23    pub vk: MultiStarkVerifyingKey<SC>,
24    pub proof: Proof<SC>,
25}
26
27/// A helper trait to collect the different steps in multi-trace STARK
28/// keygen and proving. Currently this trait is CPU specific.
29pub trait StarkEngine
30where
31    <Self::PB as ProverBackend>::OpeningProof: Into<OpeningProof<Self::SC>>,
32    <Self::PB as ProverBackend>::RapPartialProof: Into<Option<RapPhaseSeqPartialProof<Self::SC>>>,
33{
34    type SC: StarkGenericConfig;
35    type PB: ProverBackend<
36        Val = Val<Self::SC>,
37        Challenge = <Self::SC as StarkGenericConfig>::Challenge,
38        Commitment = Com<Self::SC>,
39        Challenger = <Self::SC as StarkGenericConfig>::Challenger,
40    >;
41    type PD: ProverDevice<Self::PB> + DeviceDataTransporter<Self::SC, Self::PB>;
42
43    /// Stark config
44    fn config(&self) -> &Self::SC;
45
46    /// During keygen, the circuit may be optimized but it will **try** to keep the
47    /// constraint degree at most this value.
48    fn max_constraint_degree(&self) -> Option<usize> {
49        None
50    }
51
52    /// Creates a new challenger with a deterministic state.
53    /// Creating new challenger for prover and verifier separately will result in
54    /// them having the same starting state.
55    fn new_challenger(&self) -> <Self::SC as StarkGenericConfig>::Challenger;
56
57    fn keygen_builder(&self) -> MultiStarkKeygenBuilder<'_, Self::SC> {
58        let mut builder = MultiStarkKeygenBuilder::new(self.config());
59        if let Some(max_constraint_degree) = self.max_constraint_degree() {
60            builder.set_max_constraint_degree(max_constraint_degree);
61        }
62        builder
63    }
64
65    fn device(&self) -> &Self::PD;
66
67    fn prover(&self) -> Coordinator<Self::SC, Self::PB, Self::PD>;
68
69    fn verifier(&self) -> MultiTraceStarkVerifier<'_, Self::SC> {
70        MultiTraceStarkVerifier::new(self.config())
71    }
72
73    /// Add AIRs and get AIR IDs
74    fn set_up_keygen_builder(
75        &self,
76        keygen_builder: &mut MultiStarkKeygenBuilder<'_, Self::SC>,
77        airs: &[AirRef<Self::SC>],
78    ) -> Vec<usize> {
79        airs.iter()
80            .map(|air| keygen_builder.add_air(air.clone()))
81            .collect()
82    }
83
84    /// As a convenience, this function also transports the proving key from host to device.
85    /// Note that the [Self::prove] function starts from a [DeviceMultiStarkProvingKey],
86    /// which should be used if the proving key is already cached in device memory.
87    fn prove_then_verify(
88        &self,
89        pk: &MultiStarkProvingKey<Self::SC>,
90        ctx: ProvingContext<Self::PB>,
91    ) -> Result<Proof<Self::SC>, VerificationError> {
92        let pk_device = self.device().transport_pk_to_device(pk);
93        let proof = self.prove(&pk_device, ctx);
94        self.verify(&pk.get_vk(), &proof)?;
95        Ok(proof)
96    }
97
98    fn prove(
99        &self,
100        pk: &DeviceMultiStarkProvingKey<Self::PB>,
101        ctx: ProvingContext<Self::PB>,
102    ) -> Proof<Self::SC> {
103        let mpk_view = pk.view(ctx.air_ids());
104        let mut prover = self.prover();
105        let proof = prover.prove(mpk_view, ctx);
106        proof.into()
107    }
108
109    fn verify(
110        &self,
111        vk: &MultiStarkVerifyingKey<Self::SC>,
112        proof: &Proof<Self::SC>,
113    ) -> Result<(), VerificationError> {
114        let mut challenger = self.new_challenger();
115        let verifier = self.verifier();
116        verifier.verify(&mut challenger, vk, proof)
117    }
118
119    // mpk can be removed if we use BaseAir trait to regenerate preprocessed traces
120    fn debug(
121        &self,
122        airs: &[AirRef<Self::SC>],
123        pk: &[StarkProvingKey<Self::SC>],
124        proof_inputs: &[AirProofRawInput<Val<Self::SC>>],
125    ) {
126        let (trace_views, pvs): (Vec<_>, Vec<_>) = proof_inputs
127            .iter()
128            .map(|input| {
129                let mut views = input
130                    .cached_mains
131                    .iter()
132                    .map(|trace| trace.as_view())
133                    .collect_vec();
134                if let Some(trace) = input.common_main.as_ref() {
135                    views.push(trace.as_view());
136                }
137                (views, input.public_values.clone())
138            })
139            .unzip();
140        debug_constraints_and_interactions(airs, pk, &trace_views, &pvs);
141    }
142
143    /// Runs a single end-to-end test for a given set of chips and traces partitions.
144    /// This includes proving/verifying key generation, creating a proof, and verifying the proof.
145    fn run_test_impl(
146        &self,
147        airs: Vec<AirRef<Self::SC>>,
148        ctx: Vec<AirProvingContext<Self::PB>>,
149    ) -> Result<VerificationData<Self::SC>, VerificationError> {
150        let mut keygen_builder = self.keygen_builder();
151        let air_ids = self.set_up_keygen_builder(&mut keygen_builder, &airs);
152        let pk = keygen_builder.generate_pk();
153        let device = self.prover().device;
154        let proof_inputs = ctx
155            .iter()
156            .map(|air_ctx| {
157                let cached_mains = air_ctx
158                    .cached_mains
159                    .iter()
160                    .map(|pre| device.transport_matrix_from_device_to_host(&pre.trace))
161                    .collect_vec();
162                let common_main = air_ctx
163                    .common_main
164                    .as_ref()
165                    .map(|m| device.transport_matrix_from_device_to_host(m));
166                let public_values = air_ctx.public_values.clone();
167                AirProofRawInput {
168                    cached_mains,
169                    common_main,
170                    public_values,
171                }
172            })
173            .collect_vec();
174        self.debug(&airs, &pk.per_air, &proof_inputs);
175        let vk = pk.get_vk();
176        let ctx = ProvingContext {
177            per_air: zip_eq(air_ids, ctx).collect(),
178        };
179        let proof = self.prove_then_verify(&pk, ctx)?;
180        Ok(VerificationData { vk, proof })
181    }
182}