// openvm_stark_backend/prover/coordinator.rs

use std::{iter, marker::PhantomData};

use itertools::{izip, Itertools};
use p3_challenger::CanObserve;
use p3_field::FieldAlgebra;
use p3_util::log2_strict_usize;
use tracing::{info, instrument};

use super::{
    hal::{ProverBackend, ProverDevice},
    types::{DeviceMultiStarkProvingKey, HalProof, ProvingContext},
    Prover,
};
#[cfg(feature = "bench-metrics")]
use crate::prover::metrics::trace_metrics;
use crate::{
    config::{Com, StarkGenericConfig, Val},
    keygen::view::MultiStarkVerifyingKeyView,
    proof::{AirProofData, Commitments},
    prover::{
        hal::MatrixDimensions,
        types::{PairView, SingleCommitPreimage},
    },
    utils::metrics_span,
};

27/// Host-to-device coordinator for full prover implementation.
28///
29/// The generics are:
30/// - `SC`: Stark configuration for proving key (from host)
31/// - `PB`: Prover backend types
32/// - `PD`: Prover device methods
33pub struct Coordinator<SC: StarkGenericConfig, PB, PD> {
34    pub backend: PB,
35    pub device: PD,
36    challenger: SC::Challenger,
37    phantom: PhantomData<(SC, PB)>,
38}
39
40impl<SC: StarkGenericConfig, PB, PD> Coordinator<SC, PB, PD> {
41    pub fn new(backend: PB, device: PD, challenger: SC::Challenger) -> Self {
42        Self {
43            backend,
44            device,
45            challenger,
46            phantom: PhantomData,
47        }
48    }
49}
50
51impl<SC, PB, PD> Prover for Coordinator<SC, PB, PD>
52where
53    SC: StarkGenericConfig,
54    PB: ProverBackend<
55        Val = Val<SC>,
56        Challenge = SC::Challenge,
57        Commitment = Com<SC>,
58        Challenger = SC::Challenger,
59    >,
60    PD: ProverDevice<PB>,
61{
62    type Proof = HalProof<PB>;
63    type ProvingKeyView<'a>
64        = &'a DeviceMultiStarkProvingKey<'a, PB>
65    where
66        Self: 'a;
67
68    type ProvingContext<'a>
69        = ProvingContext<'a, PB>
70    where
71        Self: 'a;
72
73    /// Specialized prove for InteractiveAirs.
74    /// Handles trace generation of the permutation traces.
75    /// Assumes the main traces have been generated and committed already.
76    ///
77    /// The [DeviceMultiStarkProvingKey] should already be filtered to only include the relevant AIR's proving keys.
78    #[instrument(name = "Coordinator::prove", level = "info", skip_all)]
79    fn prove<'a>(
80        &'a mut self,
81        mpk: Self::ProvingKeyView<'a>,
82        ctx: Self::ProvingContext<'a>,
83    ) -> Self::Proof {
84        #[cfg(feature = "bench-metrics")]
85        let start = std::time::Instant::now();
86        assert!(mpk.validate(&ctx), "Invalid proof input");
87        self.challenger.observe(mpk.vk_pre_hash.clone());
88
89        let num_air = ctx.per_air.len();
90        self.challenger
91            .observe(Val::<SC>::from_canonical_usize(num_air));
92        info!(num_air);
93        #[allow(clippy::type_complexity)]
94        let (cached_commits_per_air, cached_views_per_air, common_main_per_air, pvs_per_air): (
95            Vec<Vec<PB::Commitment>>,
96            Vec<Vec<SingleCommitPreimage<&'a PB::Matrix, &'a PB::PcsData>>>,
97            Vec<Option<PB::Matrix>>,
98            Vec<Vec<PB::Val>>,
99        ) = ctx
100            .into_iter()
101            .map(|(air_id, ctx)| {
102                self.challenger.observe(Val::<SC>::from_canonical_usize(air_id));
103                let (cached_commits, cached_views): (Vec<_>, Vec<_>) =
104                    ctx.cached_mains.into_iter().unzip();
105                (
106                    cached_commits,
107                    cached_views,
108                    ctx.common_main,
109                    ctx.public_values,
110                )
111            })
112            .multiunzip();
113
114        // ==================== All trace commitments that do not require challenges ====================
115        // Commit all common main traces in a commitment. Traces inside are ordered by AIR id.
116        let (common_main_traces, (common_main_commit, common_main_pcs_data)) =
117            metrics_span("main_trace_commit_time_ms", || {
118                let traces = common_main_per_air.into_iter().flatten().collect_vec();
119                let prover_data = self.device.commit(&traces);
120                (traces, prover_data)
121            });
122
123        // Commitments order:
124        // - for each air:
125        //   - for each cached main trace
126        //     - 1 commitment
127        // - 1 commitment of all common main traces
128        let main_trace_commitments: Vec<PB::Commitment> = cached_commits_per_air
129            .iter()
130            .flatten()
131            .chain(iter::once(&common_main_commit))
132            .cloned()
133            .collect();
134
135        // All commitments that don't require challenges have been made, so we collect them into trace views:
136        let mut common_main_idx = 0;
137        let mut log_trace_height_per_air: Vec<u8> = Vec::with_capacity(num_air);
138        let mut pair_trace_view_per_air = Vec::with_capacity(num_air);
139        for (pk, cached_views, pvs) in izip!(&mpk.per_air, &cached_views_per_air, &pvs_per_air) {
140            let mut main_trace_views: Vec<&PB::Matrix> =
141                cached_views.iter().map(|view| view.trace).collect_vec();
142            if pk.vk.has_common_main() {
143                main_trace_views.push(&common_main_traces[common_main_idx]);
144                common_main_idx += 1;
145            }
146            let trace_height = main_trace_views.first().expect("no main trace").height();
147            let log_trace_height: u8 = log2_strict_usize(trace_height).try_into().unwrap();
148            let pair_trace_view = PairView {
149                log_trace_height,
150                preprocessed: pk.preprocessed_data.as_ref().map(|d| &d.trace),
151                partitioned_main: main_trace_views,
152                public_values: pvs.to_vec(),
153            };
154            log_trace_height_per_air.push(log_trace_height);
155            pair_trace_view_per_air.push(pair_trace_view);
156        }
157        #[cfg(feature = "bench-metrics")]
158        trace_metrics(mpk, &log_trace_height_per_air).emit();
159
160        // ============ Challenger observations before additional RAP phases =============
161        // Observe public values:
162        for pvs in &pvs_per_air {
163            self.challenger.observe_slice(pvs);
164        }
165
166        // Observes preprocessed and main commitments:
167        let mvk = mpk.vk_view();
168        let preprocessed_commits = mvk.flattened_preprocessed_commits();
169        self.challenger.observe_slice(&preprocessed_commits);
170        self.challenger.observe_slice(&main_trace_commitments);
171        // Observe trace domain size per air:
172        self.challenger.observe_slice(
173            &log_trace_height_per_air
174                .iter()
175                .copied()
176                .map(Val::<SC>::from_canonical_u8)
177                .collect_vec(),
178        );
179
180        // ==================== Partially prove all RAP phases that require challenges ====================
181        let (rap_partial_proof, prover_data_after) =
182            self.device
183                .partially_prove(&mut self.challenger, mpk, pair_trace_view_per_air);
184        // Challenger observes additional commitments if any exist:
185        for (commit, _) in &prover_data_after.committed_pcs_data_per_phase {
186            self.challenger.observe(commit.clone());
187        }
188
189        // Collect exposed_values_per_air for the proof:
190        // - transpose per_phase, per_air -> per_air, per_phase
191        let exposed_values_per_air = (0..num_air)
192            .map(|i| {
193                let mut values = prover_data_after
194                    .rap_views_per_phase
195                    .iter()
196                    .map(|per_air| {
197                        per_air
198                            .get(i)
199                            .and_then(|v| v.inner.map(|_| v.exposed_values.clone()))
200                    })
201                    .collect_vec();
202                // Prune Nones
203                while let Some(last) = values.last() {
204                    if last.is_none() {
205                        values.pop();
206                    } else {
207                        break;
208                    }
209                }
210                values
211                    .into_iter()
212                    .map(|v| v.unwrap_or_default())
213                    .collect_vec()
214            })
215            .collect_vec();
216
217        // ==================== Quotient polynomial computation and commitment, if any ====================
218        // Note[jpw]: Currently we always call this step, we could add a flag to skip it for protocols that
219        // do not require quotient poly.
220        let (quotient_commit, quotient_data) = self.device.eval_and_commit_quotient(
221            &mut self.challenger,
222            &mpk.per_air,
223            &pvs_per_air,
224            &cached_views_per_air,
225            &common_main_pcs_data,
226            &prover_data_after,
227        );
228        // Observe quotient commitment
229        self.challenger.observe(quotient_commit.clone());
230
231        let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
232            .committed_pcs_data_per_phase
233            .into_iter()
234            .unzip();
235        // ==================== Polynomial Opening Proofs ====================
236        let opening = metrics_span("pcs_opening_time_ms", || {
237            let mut quotient_degrees = Vec::with_capacity(mpk.per_air.len());
238            let mut preprocessed = Vec::new();
239
240            for pk in &mpk.per_air {
241                quotient_degrees.push(pk.vk.quotient_degree);
242                if let Some(data) = pk.preprocessed_data.as_ref().map(|d| &d.data) {
243                    preprocessed.push(data);
244                }
245            }
246
247            let main = cached_views_per_air
248                .into_iter()
249                .flatten()
250                .map(|cv| cv.data)
251                .chain(iter::once(&common_main_pcs_data))
252                .collect();
253            self.device.open(
254                &mut self.challenger,
255                preprocessed,
256                main,
257                pcs_data_after,
258                quotient_data,
259                &quotient_degrees,
260            )
261        });
262
263        // ==================== Collect data into proof ====================
264        // Collect the commitments
265        let commitments = Commitments {
266            main_trace: main_trace_commitments,
267            after_challenge: commitments_after,
268            quotient: quotient_commit,
269        };
270        let proof = HalProof {
271            commitments,
272            opening,
273            per_air: izip!(
274                &mpk.air_ids,
275                log_trace_height_per_air,
276                exposed_values_per_air,
277                pvs_per_air
278            )
279            .map(
280                |(&air_id, log_height, exposed_values, public_values)| AirProofData {
281                    air_id,
282                    degree: 1 << log_height,
283                    public_values,
284                    exposed_values_after_challenge: exposed_values,
285                },
286            )
287            .collect(),
288            rap_partial_proof,
289        };
290
291        #[cfg(feature = "bench-metrics")]
292        ::metrics::gauge!("stark_prove_excluding_trace_time_ms")
293            .set(start.elapsed().as_millis() as f64);
294
295        proof
296    }
297}
298
299impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKey<'a, PB> {
300    pub(crate) fn validate(&self, ctx: &ProvingContext<PB>) -> bool {
301        ctx.per_air.len() == self.air_ids.len()
302            && ctx
303                .per_air
304                .iter()
305                .zip(&self.air_ids)
306                .all(|((id1, _), id2)| id1 == id2)
307            && ctx.per_air.iter().tuple_windows().all(|(a, b)| a.0 < b.0)
308    }
309
310    pub(crate) fn vk_view(&'a self) -> MultiStarkVerifyingKeyView<'a, PB::Val, PB::Commitment> {
311        MultiStarkVerifyingKeyView::new(
312            self.per_air.iter().map(|pk| pk.vk).collect(),
313            &self.trace_height_constraints,
314            self.vk_pre_hash.clone(),
315        )
316    }
317}