openvm_stark_backend/prover/coordinator.rs
1use std::{iter, marker::PhantomData};
2
3use itertools::{izip, Itertools};
4use p3_challenger::CanObserve;
5use p3_field::FieldAlgebra;
6use p3_util::log2_strict_usize;
7use tracing::{info, instrument};
8
9use super::{
10 hal::{ProverBackend, ProverDevice},
11 types::{DeviceMultiStarkProvingKey, HalProof, ProvingContext},
12 Prover,
13};
14#[cfg(feature = "bench-metrics")]
15use crate::prover::metrics::trace_metrics;
16use crate::{
17 config::{Com, StarkGenericConfig, Val},
18 keygen::view::MultiStarkVerifyingKeyView,
19 proof::{AirProofData, Commitments},
20 prover::{
21 hal::MatrixDimensions,
22 types::{PairView, SingleCommitPreimage},
23 },
24 utils::metrics_span,
25};
26
27pub struct Coordinator<SC: StarkGenericConfig, PB, PD> {
34 pub backend: PB,
35 pub device: PD,
36 challenger: SC::Challenger,
37 phantom: PhantomData<(SC, PB)>,
38}
39
40impl<SC: StarkGenericConfig, PB, PD> Coordinator<SC, PB, PD> {
41 pub fn new(backend: PB, device: PD, challenger: SC::Challenger) -> Self {
42 Self {
43 backend,
44 device,
45 challenger,
46 phantom: PhantomData,
47 }
48 }
49}
50
51impl<SC, PB, PD> Prover for Coordinator<SC, PB, PD>
52where
53 SC: StarkGenericConfig,
54 PB: ProverBackend<
55 Val = Val<SC>,
56 Challenge = SC::Challenge,
57 Commitment = Com<SC>,
58 Challenger = SC::Challenger,
59 >,
60 PD: ProverDevice<PB>,
61{
62 type Proof = HalProof<PB>;
63 type ProvingKeyView<'a>
64 = &'a DeviceMultiStarkProvingKey<'a, PB>
65 where
66 Self: 'a;
67
68 type ProvingContext<'a>
69 = ProvingContext<'a, PB>
70 where
71 Self: 'a;
72
73 #[instrument(name = "Coordinator::prove", level = "info", skip_all)]
79 fn prove<'a>(
80 &'a mut self,
81 mpk: Self::ProvingKeyView<'a>,
82 ctx: Self::ProvingContext<'a>,
83 ) -> Self::Proof {
84 #[cfg(feature = "bench-metrics")]
85 let start = std::time::Instant::now();
86 assert!(mpk.validate(&ctx), "Invalid proof input");
87 self.challenger.observe(mpk.vk_pre_hash.clone());
88
89 let num_air = ctx.per_air.len();
90 self.challenger
91 .observe(Val::<SC>::from_canonical_usize(num_air));
92 info!(num_air);
93 #[allow(clippy::type_complexity)]
94 let (cached_commits_per_air, cached_views_per_air, common_main_per_air, pvs_per_air): (
95 Vec<Vec<PB::Commitment>>,
96 Vec<Vec<SingleCommitPreimage<&'a PB::Matrix, &'a PB::PcsData>>>,
97 Vec<Option<PB::Matrix>>,
98 Vec<Vec<PB::Val>>,
99 ) = ctx
100 .into_iter()
101 .map(|(air_id, ctx)| {
102 self.challenger.observe(Val::<SC>::from_canonical_usize(air_id));
103 let (cached_commits, cached_views): (Vec<_>, Vec<_>) =
104 ctx.cached_mains.into_iter().unzip();
105 (
106 cached_commits,
107 cached_views,
108 ctx.common_main,
109 ctx.public_values,
110 )
111 })
112 .multiunzip();
113
114 let (common_main_traces, (common_main_commit, common_main_pcs_data)) =
117 metrics_span("main_trace_commit_time_ms", || {
118 let traces = common_main_per_air.into_iter().flatten().collect_vec();
119 let prover_data = self.device.commit(&traces);
120 (traces, prover_data)
121 });
122
123 let main_trace_commitments: Vec<PB::Commitment> = cached_commits_per_air
129 .iter()
130 .flatten()
131 .chain(iter::once(&common_main_commit))
132 .cloned()
133 .collect();
134
135 let mut common_main_idx = 0;
137 let mut log_trace_height_per_air: Vec<u8> = Vec::with_capacity(num_air);
138 let mut pair_trace_view_per_air = Vec::with_capacity(num_air);
139 for (pk, cached_views, pvs) in izip!(&mpk.per_air, &cached_views_per_air, &pvs_per_air) {
140 let mut main_trace_views: Vec<&PB::Matrix> =
141 cached_views.iter().map(|view| view.trace).collect_vec();
142 if pk.vk.has_common_main() {
143 main_trace_views.push(&common_main_traces[common_main_idx]);
144 common_main_idx += 1;
145 }
146 let trace_height = main_trace_views.first().expect("no main trace").height();
147 let log_trace_height: u8 = log2_strict_usize(trace_height).try_into().unwrap();
148 let pair_trace_view = PairView {
149 log_trace_height,
150 preprocessed: pk.preprocessed_data.as_ref().map(|d| &d.trace),
151 partitioned_main: main_trace_views,
152 public_values: pvs.to_vec(),
153 };
154 log_trace_height_per_air.push(log_trace_height);
155 pair_trace_view_per_air.push(pair_trace_view);
156 }
157 #[cfg(feature = "bench-metrics")]
158 trace_metrics(mpk, &log_trace_height_per_air).emit();
159
160 for pvs in &pvs_per_air {
163 self.challenger.observe_slice(pvs);
164 }
165
166 let mvk = mpk.vk_view();
168 let preprocessed_commits = mvk.flattened_preprocessed_commits();
169 self.challenger.observe_slice(&preprocessed_commits);
170 self.challenger.observe_slice(&main_trace_commitments);
171 self.challenger.observe_slice(
173 &log_trace_height_per_air
174 .iter()
175 .copied()
176 .map(Val::<SC>::from_canonical_u8)
177 .collect_vec(),
178 );
179
180 let (rap_partial_proof, prover_data_after) =
182 self.device
183 .partially_prove(&mut self.challenger, mpk, pair_trace_view_per_air);
184 for (commit, _) in &prover_data_after.committed_pcs_data_per_phase {
186 self.challenger.observe(commit.clone());
187 }
188
189 let exposed_values_per_air = (0..num_air)
192 .map(|i| {
193 let mut values = prover_data_after
194 .rap_views_per_phase
195 .iter()
196 .map(|per_air| {
197 per_air
198 .get(i)
199 .and_then(|v| v.inner.map(|_| v.exposed_values.clone()))
200 })
201 .collect_vec();
202 while let Some(last) = values.last() {
204 if last.is_none() {
205 values.pop();
206 } else {
207 break;
208 }
209 }
210 values
211 .into_iter()
212 .map(|v| v.unwrap_or_default())
213 .collect_vec()
214 })
215 .collect_vec();
216
217 let (quotient_commit, quotient_data) = self.device.eval_and_commit_quotient(
221 &mut self.challenger,
222 &mpk.per_air,
223 &pvs_per_air,
224 &cached_views_per_air,
225 &common_main_pcs_data,
226 &prover_data_after,
227 );
228 self.challenger.observe(quotient_commit.clone());
230
231 let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
232 .committed_pcs_data_per_phase
233 .into_iter()
234 .unzip();
235 let opening = metrics_span("pcs_opening_time_ms", || {
237 let mut quotient_degrees = Vec::with_capacity(mpk.per_air.len());
238 let mut preprocessed = Vec::new();
239
240 for pk in &mpk.per_air {
241 quotient_degrees.push(pk.vk.quotient_degree);
242 if let Some(data) = pk.preprocessed_data.as_ref().map(|d| &d.data) {
243 preprocessed.push(data);
244 }
245 }
246
247 let main = cached_views_per_air
248 .into_iter()
249 .flatten()
250 .map(|cv| cv.data)
251 .chain(iter::once(&common_main_pcs_data))
252 .collect();
253 self.device.open(
254 &mut self.challenger,
255 preprocessed,
256 main,
257 pcs_data_after,
258 quotient_data,
259 "ient_degrees,
260 )
261 });
262
263 let commitments = Commitments {
266 main_trace: main_trace_commitments,
267 after_challenge: commitments_after,
268 quotient: quotient_commit,
269 };
270 let proof = HalProof {
271 commitments,
272 opening,
273 per_air: izip!(
274 &mpk.air_ids,
275 log_trace_height_per_air,
276 exposed_values_per_air,
277 pvs_per_air
278 )
279 .map(
280 |(&air_id, log_height, exposed_values, public_values)| AirProofData {
281 air_id,
282 degree: 1 << log_height,
283 public_values,
284 exposed_values_after_challenge: exposed_values,
285 },
286 )
287 .collect(),
288 rap_partial_proof,
289 };
290
291 #[cfg(feature = "bench-metrics")]
292 ::metrics::gauge!("stark_prove_excluding_trace_time_ms")
293 .set(start.elapsed().as_millis() as f64);
294
295 proof
296 }
297}
298
299impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKey<'a, PB> {
300 pub(crate) fn validate(&self, ctx: &ProvingContext<PB>) -> bool {
301 ctx.per_air.len() == self.air_ids.len()
302 && ctx
303 .per_air
304 .iter()
305 .zip(&self.air_ids)
306 .all(|((id1, _), id2)| id1 == id2)
307 && ctx.per_air.iter().tuple_windows().all(|(a, b)| a.0 < b.0)
308 }
309
310 pub(crate) fn vk_view(&'a self) -> MultiStarkVerifyingKeyView<'a, PB::Val, PB::Commitment> {
311 MultiStarkVerifyingKeyView::new(
312 self.per_air.iter().map(|pk| pk.vk).collect(),
313 &self.trace_height_constraints,
314 self.vk_pre_hash.clone(),
315 )
316 }
317}