openvm_stark_backend/prover/
coordinator.rs1use std::{iter, marker::PhantomData};
2
3use itertools::{izip, Itertools};
4use p3_challenger::CanObserve;
5use p3_field::FieldAlgebra;
6use p3_util::log2_strict_usize;
7use tracing::{info, info_span, instrument};
8
9use super::{
10 hal::{ProverBackend, ProverDevice},
11 types::{HalProof, ProvingContext},
12 Prover,
13};
14#[cfg(feature = "metrics")]
15use crate::prover::metrics::trace_metrics;
16use crate::{
17 config::{Com, StarkGenericConfig, Val},
18 keygen::view::MultiStarkVerifyingKeyView,
19 proof::{AirProofData, Commitments},
20 prover::{
21 hal::MatrixDimensions,
22 types::{AirView, DeviceMultiStarkProvingKeyView},
23 },
24};
25
26pub struct Coordinator<SC: StarkGenericConfig, PB, PD> {
33 pub backend: PB,
34 pub device: PD,
35 challenger: SC::Challenger,
36 phantom: PhantomData<(SC, PB)>,
37}
38
39impl<SC: StarkGenericConfig, PB, PD> Coordinator<SC, PB, PD> {
40 pub fn new(backend: PB, device: PD, challenger: SC::Challenger) -> Self {
41 Self {
42 backend,
43 device,
44 challenger,
45 phantom: PhantomData,
46 }
47 }
48}
49
50impl<SC, PB, PD> Prover for Coordinator<SC, PB, PD>
51where
52 SC: StarkGenericConfig,
53 PB: ProverBackend<
54 Val = Val<SC>,
55 Challenge = SC::Challenge,
56 Commitment = Com<SC>,
57 Challenger = SC::Challenger,
58 >,
59 PD: ProverDevice<PB>,
60{
61 type Proof = HalProof<PB>;
62 type ProvingKeyView<'a>
63 = DeviceMultiStarkProvingKeyView<'a, PB>
64 where
65 Self: 'a;
66
67 type ProvingContext<'a>
68 = ProvingContext<PB>
69 where
70 Self: 'a;
71
72 #[instrument(name = "stark_prove_excluding_trace", level = "info", skip_all)]
79 fn prove<'a>(
80 &'a mut self,
81 mpk: Self::ProvingKeyView<'a>,
82 ctx: Self::ProvingContext<'a>,
83 ) -> Self::Proof {
84 assert!(mpk.validate(&ctx), "Invalid proof input");
85 self.challenger.observe(mpk.vk_pre_hash.clone());
86
87 let num_air = ctx.per_air.len();
88 self.challenger
89 .observe(Val::<SC>::from_canonical_usize(num_air));
90 info!(num_air);
91 #[allow(clippy::type_complexity)]
92 let (cached_commits_per_air, cached_views_per_air, common_main_per_air, pvs_per_air): (
93 Vec<Vec<PB::Commitment>>,
94 Vec<Vec<(PB::Matrix, PB::PcsData)>>,
95 Vec<Option<PB::Matrix>>,
96 Vec<Vec<PB::Val>>,
97 ) = ctx
98 .into_iter()
99 .map(|(air_id, ctx)| {
100 self.challenger.observe(Val::<SC>::from_canonical_usize(air_id));
101 let (cached_commits, cached_views): (Vec<_>, Vec<_>) =
102 ctx.cached_mains.into_iter().map(|cm| (cm.commitment, (cm.trace, cm.data))).unzip();
103 (
104 cached_commits,
105 cached_views,
106 ctx.common_main,
107 ctx.public_values,
108 )
109 })
110 .multiunzip();
111
112 let (common_main_traces, (common_main_commit, common_main_pcs_data)) =
116 info_span!("main_trace_commit").in_scope(|| {
117 let traces = common_main_per_air.into_iter().flatten().collect_vec();
118 let prover_data = self.device.commit(&traces);
119 (traces, prover_data)
120 });
121
122 let main_trace_commitments: Vec<PB::Commitment> = cached_commits_per_air
128 .iter()
129 .flatten()
130 .chain(iter::once(&common_main_commit))
131 .cloned()
132 .collect();
133
134 let mut common_main_traces_it = common_main_traces.into_iter();
137 let mut log_trace_height_per_air: Vec<u8> = Vec::with_capacity(num_air);
138 let mut air_trace_views_per_air = Vec::with_capacity(num_air);
139 let mut cached_pcs_datas_per_air = Vec::with_capacity(num_air);
140 for (pk, cached_views, pvs) in izip!(&mpk.per_air, cached_views_per_air, &pvs_per_air) {
141 let (mut main_trace_views, cached_pcs_datas): (Vec<PB::Matrix>, Vec<PB::PcsData>) =
142 cached_views.into_iter().unzip();
143 cached_pcs_datas_per_air.push(cached_pcs_datas);
144 if pk.vk.has_common_main() {
145 main_trace_views.push(common_main_traces_it.next().expect("expected common main"));
146 }
147 let trace_height = main_trace_views.first().expect("no main trace").height();
148 let log_trace_height: u8 = log2_strict_usize(trace_height).try_into().unwrap();
149 let air_trace_view = AirView {
150 partitioned_main: main_trace_views,
151 public_values: pvs.to_vec(),
152 };
153 log_trace_height_per_air.push(log_trace_height);
154 air_trace_views_per_air.push(air_trace_view);
155 }
156 #[cfg(feature = "metrics")]
157 trace_metrics(&mpk, &log_trace_height_per_air).emit();
158
159 for pvs in &pvs_per_air {
162 self.challenger.observe_slice(pvs);
163 }
164
165 let mvk = mpk.vk_view();
167 let preprocessed_commits = mvk.flattened_preprocessed_commits();
168 self.challenger.observe_slice(&preprocessed_commits);
169 self.challenger.observe_slice(&main_trace_commitments);
170 self.challenger.observe_slice(
172 &log_trace_height_per_air
173 .iter()
174 .copied()
175 .map(Val::<SC>::from_canonical_u8)
176 .collect_vec(),
177 );
178
179 let (rap_partial_proof, prover_data_after) =
182 self.device
183 .partially_prove(&mut self.challenger, &mpk, air_trace_views_per_air);
184 for (commit, _) in &prover_data_after.committed_pcs_data_per_phase {
188 self.challenger.observe(commit.clone());
189 }
190
191 let exposed_values_per_air = (0..num_air)
194 .map(|i| {
195 let mut values = prover_data_after
196 .rap_views_per_phase
197 .iter()
198 .map(|per_air| {
199 per_air
200 .get(i)
201 .and_then(|v| v.inner.map(|_| v.exposed_values.clone()))
202 })
203 .collect_vec();
204 while let Some(last) = values.last() {
206 if last.is_none() {
207 values.pop();
208 } else {
209 break;
210 }
211 }
212 values
213 .into_iter()
214 .map(|v| v.unwrap_or_default())
215 .collect_vec()
216 })
217 .collect_vec();
218
219 let (quotient_commit, quotient_data) = self.device.eval_and_commit_quotient(
223 &mut self.challenger,
224 &mpk.per_air,
225 &pvs_per_air,
226 &cached_pcs_datas_per_air,
227 &common_main_pcs_data,
228 &prover_data_after,
229 );
230 self.challenger.observe(quotient_commit.clone());
232
233 let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
234 .committed_pcs_data_per_phase
235 .into_iter()
236 .unzip();
237 let opening = info_span!("pcs_opening").in_scope(|| {
239 let mut quotient_degrees = Vec::with_capacity(mpk.per_air.len());
240 let mut preprocessed = Vec::new();
241
242 for pk in mpk.per_air {
243 quotient_degrees.push(pk.vk.quotient_degree);
244 if let Some(preprocessed_data) = &pk.preprocessed_data {
245 preprocessed.push(&preprocessed_data.data);
246 }
247 }
248
249 let main = cached_pcs_datas_per_air
250 .into_iter()
251 .flatten()
252 .chain(iter::once(common_main_pcs_data))
253 .collect();
254 self.device.open(
255 &mut self.challenger,
256 preprocessed,
257 main,
258 pcs_data_after,
259 quotient_data,
260 "ient_degrees,
261 )
262 });
263
264 let commitments = Commitments {
267 main_trace: main_trace_commitments,
268 after_challenge: commitments_after,
269 quotient: quotient_commit,
270 };
271 HalProof {
272 commitments,
273 opening,
274 per_air: izip!(
275 &mpk.air_ids,
276 log_trace_height_per_air,
277 exposed_values_per_air,
278 pvs_per_air
279 )
280 .map(
281 |(&air_id, log_height, exposed_values, public_values)| AirProofData {
282 air_id,
283 degree: 1 << log_height,
284 public_values,
285 exposed_values_after_challenge: exposed_values,
286 },
287 )
288 .collect(),
289 rap_partial_proof,
290 }
291 }
292}
293
294impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKeyView<'a, PB> {
295 pub(crate) fn validate(&self, ctx: &ProvingContext<PB>) -> bool {
296 ctx.per_air.len() == self.air_ids.len()
297 && ctx
298 .per_air
299 .iter()
300 .zip(&self.air_ids)
301 .all(|((id1, _), id2)| id1 == id2)
302 && ctx.per_air.iter().tuple_windows().all(|(a, b)| a.0 < b.0)
303 }
304
305 pub(crate) fn vk_view(&'a self) -> MultiStarkVerifyingKeyView<'a, PB::Val, PB::Commitment> {
306 MultiStarkVerifyingKeyView::new(
307 self.per_air.iter().map(|pk| &pk.vk).collect(),
308 self.trace_height_constraints,
309 self.vk_pre_hash.clone(),
310 )
311 }
312}