openvm_stark_backend/
engine.rs

1use std::{iter::zip, sync::Arc};
2
3use itertools::{zip_eq, Itertools};
4use p3_matrix::Matrix;
5use p3_util::log2_strict_usize;
6
7use crate::{
8    air_builders::debug::debug_constraints_and_interactions,
9    config::StarkGenericConfig,
10    keygen::{
11        types::{MultiStarkProvingKey, MultiStarkVerifyingKey, StarkProvingKey},
12        MultiStarkKeygenBuilder,
13    },
14    proof::Proof,
15    prover::{
16        cpu::{CpuBackend, CpuDevice, PcsData},
17        hal::{DeviceDataTransporter, TraceCommitter},
18        types::{
19            AirProofInput, AirProvingContext, ProofInput, ProvingContext, SingleCommitPreimage,
20        },
21        MultiTraceStarkProver, Prover,
22    },
23    verifier::{MultiTraceStarkVerifier, VerificationError},
24    AirRef,
25};
26
27/// Data for verifying a Stark proof.
28pub struct VerificationData<SC: StarkGenericConfig> {
29    pub vk: MultiStarkVerifyingKey<SC>,
30    pub proof: Proof<SC>,
31}
32
33/// A helper trait to collect the different steps in multi-trace STARK
34/// keygen and proving. Currently this trait is CPU specific.
35pub trait StarkEngine<SC: StarkGenericConfig> {
36    /// Stark config
37    fn config(&self) -> &SC;
38
39    /// During keygen, the circuit may be optimized but it will **try** to keep the
40    /// constraint degree at most this value.
41    fn max_constraint_degree(&self) -> Option<usize> {
42        None
43    }
44
45    /// Creates a new challenger with a deterministic state.
46    /// Creating new challenger for prover and verifier separately will result in
47    /// them having the same starting state.
48    fn new_challenger(&self) -> SC::Challenger;
49
50    fn keygen_builder(&self) -> MultiStarkKeygenBuilder<SC> {
51        let mut builder = MultiStarkKeygenBuilder::new(self.config());
52        if let Some(max_constraint_degree) = self.max_constraint_degree() {
53            builder.set_max_constraint_degree(max_constraint_degree);
54        }
55        builder
56    }
57
58    fn prover<'a>(&'a self) -> MultiTraceStarkProver<'a, SC>
59    where
60        Self: 'a,
61    {
62        MultiTraceStarkProver::new(
63            CpuBackend::<SC>::default(),
64            CpuDevice::new(self.config()),
65            self.new_challenger(),
66        )
67    }
68
69    fn verifier(&self) -> MultiTraceStarkVerifier<SC> {
70        MultiTraceStarkVerifier::new(self.config())
71    }
72
73    /// Add AIRs and get AIR IDs
74    fn set_up_keygen_builder(
75        &self,
76        keygen_builder: &mut MultiStarkKeygenBuilder<'_, SC>,
77        airs: &[AirRef<SC>],
78    ) -> Vec<usize> {
79        airs.iter()
80            .map(|air| keygen_builder.add_air(air.clone()))
81            .collect()
82    }
83
84    fn prove_then_verify(
85        &self,
86        mpk: &MultiStarkProvingKey<SC>,
87        proof_input: ProofInput<SC>,
88    ) -> Result<(), VerificationError> {
89        let proof = self.prove(mpk, proof_input);
90        self.verify(&mpk.get_vk(), &proof)
91    }
92
93    fn prove(&self, mpk: &MultiStarkProvingKey<SC>, proof_input: ProofInput<SC>) -> Proof<SC> {
94        let mut prover = self.prover();
95        let backend = prover.backend;
96        let air_ids = proof_input.per_air.iter().map(|(id, _)| *id).collect();
97        // Commit cached traces if they are not provided
98        let cached_mains_per_air = proof_input
99            .per_air
100            .iter()
101            .map(|(_, input)| {
102                if input.cached_mains_pdata.len() != input.raw.cached_mains.len() {
103                    input
104                        .raw
105                        .cached_mains
106                        .iter()
107                        .map(|trace| {
108                            let trace = backend.transport_matrix_to_device(trace);
109                            let (com, data) = prover.device.commit(&[trace.clone()]);
110                            (
111                                com,
112                                SingleCommitPreimage {
113                                    trace,
114                                    data,
115                                    matrix_idx: 0,
116                                },
117                            )
118                        })
119                        .collect_vec()
120                } else {
121                    zip(&input.cached_mains_pdata, &input.raw.cached_mains)
122                        .map(|((com, data), trace)| {
123                            let data_view = PcsData {
124                                data: data.clone(),
125                                log_trace_heights: vec![log2_strict_usize(trace.height()) as u8],
126                            };
127                            let preimage = SingleCommitPreimage {
128                                trace: trace.clone(),
129                                data: data_view,
130                                matrix_idx: 0,
131                            };
132                            (com.clone(), preimage)
133                        })
134                        .collect_vec()
135                }
136            })
137            .collect_vec();
138        let ctx_per_air = zip(proof_input.per_air, &cached_mains_per_air)
139            .map(|((air_id, input), cached_mains)| {
140                let cached_mains = cached_mains
141                    .iter()
142                    .map(|(com, preimage)| {
143                        (
144                            com.clone(),
145                            SingleCommitPreimage {
146                                trace: &preimage.trace,
147                                data: &preimage.data,
148                                matrix_idx: preimage.matrix_idx,
149                            },
150                        )
151                    })
152                    .collect_vec();
153                let air_ctx = AirProvingContext {
154                    cached_mains,
155                    common_main: input.raw.common_main.map(Arc::new),
156                    public_values: input.raw.public_values,
157                };
158                (air_id, air_ctx)
159            })
160            .collect();
161        let ctx = ProvingContext {
162            per_air: ctx_per_air,
163        };
164        let mpk_view = backend.transport_pk_to_device(mpk, air_ids);
165        let proof = Prover::prove(&mut prover, &mpk_view, ctx);
166        proof.into()
167    }
168
169    fn verify(
170        &self,
171        vk: &MultiStarkVerifyingKey<SC>,
172        proof: &Proof<SC>,
173    ) -> Result<(), VerificationError> {
174        let mut challenger = self.new_challenger();
175        let verifier = self.verifier();
176        verifier.verify(&mut challenger, vk, proof)
177    }
178
179    // mpk can be removed if we use BaseAir trait to regenerate preprocessed traces
180    fn debug(
181        &self,
182        airs: &[AirRef<SC>],
183        pk: &[StarkProvingKey<SC>],
184        proof_inputs: &[AirProofInput<SC>],
185    ) {
186        let (trace_views, pvs): (Vec<_>, Vec<_>) = proof_inputs
187            .iter()
188            .map(|input| {
189                let mut views = input
190                    .raw
191                    .cached_mains
192                    .iter()
193                    .map(|trace| trace.as_view())
194                    .collect_vec();
195                if let Some(trace) = input.raw.common_main.as_ref() {
196                    views.push(trace.as_view());
197                }
198                (views, input.raw.public_values.clone())
199            })
200            .unzip();
201        debug_constraints_and_interactions(airs, pk, &trace_views, &pvs);
202    }
203
204    /// Runs a single end-to-end test for a given set of chips and traces partitions.
205    /// This includes proving/verifying key generation, creating a proof, and verifying the proof.
206    fn run_test_impl(
207        &self,
208        airs: Vec<AirRef<SC>>,
209        air_proof_inputs: Vec<AirProofInput<SC>>,
210    ) -> Result<VerificationData<SC>, VerificationError> {
211        let mut keygen_builder = self.keygen_builder();
212        let air_ids = self.set_up_keygen_builder(&mut keygen_builder, &airs);
213        let pk = keygen_builder.generate_pk();
214        self.debug(&airs, &pk.per_air, &air_proof_inputs);
215        let vk = pk.get_vk();
216        let proof_input = ProofInput {
217            per_air: zip_eq(air_ids, air_proof_inputs).collect(),
218        };
219        let proof = self.prove(&pk, proof_input);
220        self.verify(&vk, &proof)?;
221        Ok(VerificationData { vk, proof })
222    }
223}