1use std::{iter::zip, sync::Arc};
2
3use itertools::{zip_eq, Itertools};
4use p3_matrix::Matrix;
5use p3_util::log2_strict_usize;
6
7use crate::{
8 air_builders::debug::debug_constraints_and_interactions,
9 config::StarkGenericConfig,
10 keygen::{
11 types::{MultiStarkProvingKey, MultiStarkVerifyingKey, StarkProvingKey},
12 MultiStarkKeygenBuilder,
13 },
14 proof::Proof,
15 prover::{
16 cpu::{CpuBackend, CpuDevice, PcsData},
17 hal::{DeviceDataTransporter, TraceCommitter},
18 types::{
19 AirProofInput, AirProvingContext, ProofInput, ProvingContext, SingleCommitPreimage,
20 },
21 MultiTraceStarkProver, Prover,
22 },
23 verifier::{MultiTraceStarkVerifier, VerificationError},
24 AirRef,
25};
26
/// Verifying key paired with a proof.
///
/// `StarkEngine::run_test_impl` returns this only after `verify` has accepted
/// `proof` under `vk`.
pub struct VerificationData<SC: StarkGenericConfig> {
    /// Verifying key; in `run_test_impl` it is obtained via `pk.get_vk()` from
    /// the same proving key that produced `proof`.
    pub vk: MultiStarkVerifyingKey<SC>,
    /// The proof that was checked against `vk`.
    pub proof: Proof<SC>,
}
32
33pub trait StarkEngine<SC: StarkGenericConfig> {
36 fn config(&self) -> &SC;
38
39 fn max_constraint_degree(&self) -> Option<usize> {
42 None
43 }
44
45 fn new_challenger(&self) -> SC::Challenger;
49
50 fn keygen_builder(&self) -> MultiStarkKeygenBuilder<SC> {
51 let mut builder = MultiStarkKeygenBuilder::new(self.config());
52 if let Some(max_constraint_degree) = self.max_constraint_degree() {
53 builder.set_max_constraint_degree(max_constraint_degree);
54 }
55 builder
56 }
57
58 fn prover<'a>(&'a self) -> MultiTraceStarkProver<'a, SC>
59 where
60 Self: 'a,
61 {
62 MultiTraceStarkProver::new(
63 CpuBackend::<SC>::default(),
64 CpuDevice::new(self.config()),
65 self.new_challenger(),
66 )
67 }
68
69 fn verifier(&self) -> MultiTraceStarkVerifier<SC> {
70 MultiTraceStarkVerifier::new(self.config())
71 }
72
73 fn set_up_keygen_builder(
75 &self,
76 keygen_builder: &mut MultiStarkKeygenBuilder<'_, SC>,
77 airs: &[AirRef<SC>],
78 ) -> Vec<usize> {
79 airs.iter()
80 .map(|air| keygen_builder.add_air(air.clone()))
81 .collect()
82 }
83
84 fn prove_then_verify(
85 &self,
86 mpk: &MultiStarkProvingKey<SC>,
87 proof_input: ProofInput<SC>,
88 ) -> Result<(), VerificationError> {
89 let proof = self.prove(mpk, proof_input);
90 self.verify(&mpk.get_vk(), &proof)
91 }
92
93 fn prove(&self, mpk: &MultiStarkProvingKey<SC>, proof_input: ProofInput<SC>) -> Proof<SC> {
94 let mut prover = self.prover();
95 let backend = prover.backend;
96 let air_ids = proof_input.per_air.iter().map(|(id, _)| *id).collect();
97 let cached_mains_per_air = proof_input
99 .per_air
100 .iter()
101 .map(|(_, input)| {
102 if input.cached_mains_pdata.len() != input.raw.cached_mains.len() {
103 input
104 .raw
105 .cached_mains
106 .iter()
107 .map(|trace| {
108 let trace = backend.transport_matrix_to_device(trace);
109 let (com, data) = prover.device.commit(&[trace.clone()]);
110 (
111 com,
112 SingleCommitPreimage {
113 trace,
114 data,
115 matrix_idx: 0,
116 },
117 )
118 })
119 .collect_vec()
120 } else {
121 zip(&input.cached_mains_pdata, &input.raw.cached_mains)
122 .map(|((com, data), trace)| {
123 let data_view = PcsData {
124 data: data.clone(),
125 log_trace_heights: vec![log2_strict_usize(trace.height()) as u8],
126 };
127 let preimage = SingleCommitPreimage {
128 trace: trace.clone(),
129 data: data_view,
130 matrix_idx: 0,
131 };
132 (com.clone(), preimage)
133 })
134 .collect_vec()
135 }
136 })
137 .collect_vec();
138 let ctx_per_air = zip(proof_input.per_air, &cached_mains_per_air)
139 .map(|((air_id, input), cached_mains)| {
140 let cached_mains = cached_mains
141 .iter()
142 .map(|(com, preimage)| {
143 (
144 com.clone(),
145 SingleCommitPreimage {
146 trace: &preimage.trace,
147 data: &preimage.data,
148 matrix_idx: preimage.matrix_idx,
149 },
150 )
151 })
152 .collect_vec();
153 let air_ctx = AirProvingContext {
154 cached_mains,
155 common_main: input.raw.common_main.map(Arc::new),
156 public_values: input.raw.public_values,
157 };
158 (air_id, air_ctx)
159 })
160 .collect();
161 let ctx = ProvingContext {
162 per_air: ctx_per_air,
163 };
164 let mpk_view = backend.transport_pk_to_device(mpk, air_ids);
165 let proof = Prover::prove(&mut prover, &mpk_view, ctx);
166 proof.into()
167 }
168
169 fn verify(
170 &self,
171 vk: &MultiStarkVerifyingKey<SC>,
172 proof: &Proof<SC>,
173 ) -> Result<(), VerificationError> {
174 let mut challenger = self.new_challenger();
175 let verifier = self.verifier();
176 verifier.verify(&mut challenger, vk, proof)
177 }
178
179 fn debug(
181 &self,
182 airs: &[AirRef<SC>],
183 pk: &[StarkProvingKey<SC>],
184 proof_inputs: &[AirProofInput<SC>],
185 ) {
186 let (trace_views, pvs): (Vec<_>, Vec<_>) = proof_inputs
187 .iter()
188 .map(|input| {
189 let mut views = input
190 .raw
191 .cached_mains
192 .iter()
193 .map(|trace| trace.as_view())
194 .collect_vec();
195 if let Some(trace) = input.raw.common_main.as_ref() {
196 views.push(trace.as_view());
197 }
198 (views, input.raw.public_values.clone())
199 })
200 .unzip();
201 debug_constraints_and_interactions(airs, pk, &trace_views, &pvs);
202 }
203
204 fn run_test_impl(
207 &self,
208 airs: Vec<AirRef<SC>>,
209 air_proof_inputs: Vec<AirProofInput<SC>>,
210 ) -> Result<VerificationData<SC>, VerificationError> {
211 let mut keygen_builder = self.keygen_builder();
212 let air_ids = self.set_up_keygen_builder(&mut keygen_builder, &airs);
213 let pk = keygen_builder.generate_pk();
214 self.debug(&airs, &pk.per_air, &air_proof_inputs);
215 let vk = pk.get_vk();
216 let proof_input = ProofInput {
217 per_air: zip_eq(air_ids, air_proof_inputs).collect(),
218 };
219 let proof = self.prove(&pk, proof_input);
220 self.verify(&vk, &proof)?;
221 Ok(VerificationData { vk, proof })
222 }
223}