openvm_stark_backend/prover/
coordinator.rs1use std::{iter, marker::PhantomData};
2
3use itertools::{izip, Itertools};
4use p3_challenger::CanObserve;
5use p3_field::PrimeCharacteristicRing;
6use p3_util::log2_strict_usize;
7use tracing::{info, info_span, instrument};
8
9use super::{
10 hal::{ProverBackend, ProverDevice},
11 types::{HalProof, ProvingContext},
12 Prover,
13};
14#[cfg(feature = "metrics")]
15use crate::prover::metrics::trace_metrics;
16use crate::{
17 config::{Com, StarkGenericConfig, Val},
18 keygen::view::MultiStarkVerifyingKeyView,
19 proof::{AirProofData, Commitments},
20 prover::{
21 hal::MatrixDimensions,
22 types::{AirView, DeviceMultiStarkProvingKeyView},
23 },
24};
25
26pub struct Coordinator<SC: StarkGenericConfig, PB, PD> {
33 pub backend: PB,
34 pub device: PD,
35 challenger: SC::Challenger,
36 phantom: PhantomData<(SC, PB)>,
37}
38
39impl<SC: StarkGenericConfig, PB, PD> Coordinator<SC, PB, PD> {
40 pub fn new(backend: PB, device: PD, challenger: SC::Challenger) -> Self {
41 Self {
42 backend,
43 device,
44 challenger,
45 phantom: PhantomData,
46 }
47 }
48}
49
50impl<SC, PB, PD> Prover for Coordinator<SC, PB, PD>
51where
52 SC: StarkGenericConfig,
53 PB: ProverBackend<
54 Val = Val<SC>,
55 Challenge = SC::Challenge,
56 Commitment = Com<SC>,
57 Challenger = SC::Challenger,
58 >,
59 PD: ProverDevice<PB>,
60{
61 type Proof = HalProof<PB>;
62 type ProvingKeyView<'a>
63 = DeviceMultiStarkProvingKeyView<'a, PB>
64 where
65 Self: 'a;
66
67 type ProvingContext<'a>
68 = ProvingContext<PB>
69 where
70 Self: 'a;
71
72 #[instrument(name = "stark_prove_excluding_trace", level = "info", skip_all)]
79 fn prove<'a>(
80 &'a mut self,
81 mpk: Self::ProvingKeyView<'a>,
82 ctx: Self::ProvingContext<'a>,
83 ) -> Self::Proof {
84 assert!(mpk.validate(&ctx), "Invalid proof input");
85 self.challenger.observe(mpk.vk_pre_hash.clone());
86
87 let num_air = ctx.per_air.len();
88 self.challenger.observe(Val::<SC>::from_usize(num_air));
89 info!(num_air);
90 #[allow(clippy::type_complexity)]
91 let (cached_commits_per_air, cached_views_per_air, common_main_per_air, pvs_per_air): (
92 Vec<Vec<PB::Commitment>>,
93 Vec<Vec<(PB::Matrix, PB::PcsData)>>,
94 Vec<Option<PB::Matrix>>,
95 Vec<Vec<PB::Val>>,
96 ) = ctx
97 .into_iter()
98 .map(|(air_id, ctx)| {
99 self.challenger.observe(Val::<SC>::from_usize(air_id));
100 let (cached_commits, cached_views): (Vec<_>, Vec<_>) =
101 ctx.cached_mains.into_iter().map(|cm| (cm.commitment, (cm.trace, cm.data))).unzip();
102 (
103 cached_commits,
104 cached_views,
105 ctx.common_main,
106 ctx.public_values,
107 )
108 })
109 .multiunzip();
110
111 let (common_main_traces, (common_main_commit, common_main_pcs_data)) =
115 info_span!("main_trace_commit").in_scope(|| {
116 let traces = common_main_per_air.into_iter().flatten().collect_vec();
117 let prover_data = self.device.commit(&traces);
118 (traces, prover_data)
119 });
120
121 let main_trace_commitments: Vec<PB::Commitment> = cached_commits_per_air
127 .iter()
128 .flatten()
129 .chain(iter::once(&common_main_commit))
130 .cloned()
131 .collect();
132
133 let mut common_main_traces_it = common_main_traces.into_iter();
136 let mut log_trace_height_per_air: Vec<u8> = Vec::with_capacity(num_air);
137 let mut air_trace_views_per_air = Vec::with_capacity(num_air);
138 let mut cached_pcs_datas_per_air = Vec::with_capacity(num_air);
139 for (pk, cached_views, pvs) in izip!(&mpk.per_air, cached_views_per_air, &pvs_per_air) {
140 let (mut main_trace_views, cached_pcs_datas): (Vec<PB::Matrix>, Vec<PB::PcsData>) =
141 cached_views.into_iter().unzip();
142 cached_pcs_datas_per_air.push(cached_pcs_datas);
143 if pk.vk.has_common_main() {
144 main_trace_views.push(common_main_traces_it.next().expect("expected common main"));
145 }
146 let trace_height = main_trace_views.first().expect("no main trace").height();
147 let log_trace_height: u8 = log2_strict_usize(trace_height).try_into().unwrap();
148 let air_trace_view = AirView {
149 partitioned_main: main_trace_views,
150 public_values: pvs.to_vec(),
151 };
152 log_trace_height_per_air.push(log_trace_height);
153 air_trace_views_per_air.push(air_trace_view);
154 }
155 #[cfg(feature = "metrics")]
156 trace_metrics(&mpk, &log_trace_height_per_air).emit();
157
158 for pvs in &pvs_per_air {
161 self.challenger.observe_slice(pvs);
162 }
163
164 let mvk = mpk.vk_view();
166 let preprocessed_commits = mvk.flattened_preprocessed_commits();
167 self.challenger.observe_slice(&preprocessed_commits);
168 self.challenger.observe_slice(&main_trace_commitments);
169 self.challenger.observe_slice(
171 &log_trace_height_per_air
172 .iter()
173 .copied()
174 .map(|h| Val::<SC>::from_usize(h as usize))
175 .collect_vec(),
176 );
177
178 let (rap_partial_proof, prover_data_after) =
181 self.device
182 .partially_prove(&mut self.challenger, &mpk, air_trace_views_per_air);
183 for (commit, _) in &prover_data_after.committed_pcs_data_per_phase {
187 self.challenger.observe(commit.clone());
188 }
189
190 let exposed_values_per_air = (0..num_air)
193 .map(|i| {
194 let mut values = prover_data_after
195 .rap_views_per_phase
196 .iter()
197 .map(|per_air| {
198 per_air
199 .get(i)
200 .and_then(|v| v.inner.map(|_| v.exposed_values.clone()))
201 })
202 .collect_vec();
203 while let Some(last) = values.last() {
205 if last.is_none() {
206 values.pop();
207 } else {
208 break;
209 }
210 }
211 values
212 .into_iter()
213 .map(|v| v.unwrap_or_default())
214 .collect_vec()
215 })
216 .collect_vec();
217
218 let (quotient_commit, quotient_data) = self.device.eval_and_commit_quotient(
222 &mut self.challenger,
223 &mpk.per_air,
224 &pvs_per_air,
225 &cached_pcs_datas_per_air,
226 &common_main_pcs_data,
227 &prover_data_after,
228 );
229 self.challenger.observe(quotient_commit.clone());
231
232 let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
233 .committed_pcs_data_per_phase
234 .into_iter()
235 .unzip();
236 let opening = info_span!("pcs_opening").in_scope(|| {
238 let mut quotient_degrees = Vec::with_capacity(mpk.per_air.len());
239 let mut preprocessed = Vec::new();
240
241 for pk in mpk.per_air {
242 quotient_degrees.push(pk.vk.quotient_degree);
243 if let Some(preprocessed_data) = &pk.preprocessed_data {
244 preprocessed.push(&preprocessed_data.data);
245 }
246 }
247
248 let main = cached_pcs_datas_per_air
249 .into_iter()
250 .flatten()
251 .chain(iter::once(common_main_pcs_data))
252 .collect();
253 self.device.open(
254 &mut self.challenger,
255 preprocessed,
256 main,
257 pcs_data_after,
258 quotient_data,
259 "ient_degrees,
260 )
261 });
262
263 let commitments = Commitments {
266 main_trace: main_trace_commitments,
267 after_challenge: commitments_after,
268 quotient: quotient_commit,
269 };
270 HalProof {
271 commitments,
272 opening,
273 per_air: izip!(
274 &mpk.air_ids,
275 log_trace_height_per_air,
276 exposed_values_per_air,
277 pvs_per_air
278 )
279 .map(
280 |(&air_id, log_height, exposed_values, public_values)| AirProofData {
281 air_id,
282 degree: 1 << log_height,
283 public_values,
284 exposed_values_after_challenge: exposed_values,
285 },
286 )
287 .collect(),
288 rap_partial_proof,
289 }
290 }
291}
292
293impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKeyView<'a, PB> {
294 pub(crate) fn validate(&self, ctx: &ProvingContext<PB>) -> bool {
295 ctx.per_air.len() == self.air_ids.len()
296 && ctx
297 .per_air
298 .iter()
299 .zip(&self.air_ids)
300 .all(|((id1, _), id2)| id1 == id2)
301 && ctx.per_air.iter().tuple_windows().all(|(a, b)| a.0 < b.0)
302 }
303
304 pub(crate) fn vk_view(&'a self) -> MultiStarkVerifyingKeyView<'a, PB::Val, PB::Commitment> {
305 MultiStarkVerifyingKeyView::new(
306 self.per_air.iter().map(|pk| &pk.vk).collect(),
307 self.trace_height_constraints,
308 self.vk_pre_hash.clone(),
309 )
310 }
311}