openvm_stark_backend/prover/coordinator.rs

1use std::{iter, marker::PhantomData};
2
3use itertools::{izip, Itertools};
4use p3_challenger::CanObserve;
5use p3_field::PrimeCharacteristicRing;
6use p3_util::log2_strict_usize;
7use tracing::{info, info_span, instrument};
8
9use super::{
10    hal::{ProverBackend, ProverDevice},
11    types::{HalProof, ProvingContext},
12    Prover,
13};
14#[cfg(feature = "metrics")]
15use crate::prover::metrics::trace_metrics;
16use crate::{
17    config::{Com, StarkGenericConfig, Val},
18    keygen::view::MultiStarkVerifyingKeyView,
19    proof::{AirProofData, Commitments},
20    prover::{
21        hal::MatrixDimensions,
22        types::{AirView, DeviceMultiStarkProvingKeyView},
23    },
24};
25
26/// Host-to-device coordinator for full prover implementation.
27///
28/// The generics are:
29/// - `SC`: Stark configuration for proving key (from host)
30/// - `PB`: Prover backend types
31/// - `PD`: Prover device methods
32pub struct Coordinator<SC: StarkGenericConfig, PB, PD> {
33    pub backend: PB,
34    pub device: PD,
35    challenger: SC::Challenger,
36    phantom: PhantomData<(SC, PB)>,
37}
38
39impl<SC: StarkGenericConfig, PB, PD> Coordinator<SC, PB, PD> {
40    pub fn new(backend: PB, device: PD, challenger: SC::Challenger) -> Self {
41        Self {
42            backend,
43            device,
44            challenger,
45            phantom: PhantomData,
46        }
47    }
48}
49
50impl<SC, PB, PD> Prover for Coordinator<SC, PB, PD>
51where
52    SC: StarkGenericConfig,
53    PB: ProverBackend<
54        Val = Val<SC>,
55        Challenge = SC::Challenge,
56        Commitment = Com<SC>,
57        Challenger = SC::Challenger,
58    >,
59    PD: ProverDevice<PB>,
60{
61    type Proof = HalProof<PB>;
62    type ProvingKeyView<'a>
63        = DeviceMultiStarkProvingKeyView<'a, PB>
64    where
65        Self: 'a;
66
67    type ProvingContext<'a>
68        = ProvingContext<PB>
69    where
70        Self: 'a;
71
72    /// Specialized prove for InteractiveAirs.
73    /// Handles trace generation of the permutation traces.
74    /// Assumes the main traces have been generated and committed already.
75    ///
76    /// The [DeviceMultiStarkProvingKey] should already be filtered to only include the relevant
77    /// AIR's proving keys.
78    #[instrument(name = "stark_prove_excluding_trace", level = "info", skip_all)]
79    fn prove<'a>(
80        &'a mut self,
81        mpk: Self::ProvingKeyView<'a>,
82        ctx: Self::ProvingContext<'a>,
83    ) -> Self::Proof {
84        assert!(mpk.validate(&ctx), "Invalid proof input");
85        self.challenger.observe(mpk.vk_pre_hash.clone());
86
87        let num_air = ctx.per_air.len();
88        self.challenger.observe(Val::<SC>::from_usize(num_air));
89        info!(num_air);
90        #[allow(clippy::type_complexity)]
91        let (cached_commits_per_air, cached_views_per_air, common_main_per_air, pvs_per_air): (
92            Vec<Vec<PB::Commitment>>,
93            Vec<Vec<(PB::Matrix, PB::PcsData)>>,
94            Vec<Option<PB::Matrix>>,
95            Vec<Vec<PB::Val>>,
96        ) = ctx
97            .into_iter()
98            .map(|(air_id, ctx)| {
99                self.challenger.observe(Val::<SC>::from_usize(air_id));
100                let (cached_commits, cached_views): (Vec<_>, Vec<_>) =
101                    ctx.cached_mains.into_iter().map(|cm| (cm.commitment, (cm.trace, cm.data))).unzip();
102                (
103                    cached_commits,
104                    cached_views,
105                    ctx.common_main,
106                    ctx.public_values,
107                )
108            })
109            .multiunzip();
110
111        // ==================== All trace commitments that do not require challenges
112        // ==================== Commit all common main traces in a commitment. Traces inside
113        // are ordered by AIR id.
114        let (common_main_traces, (common_main_commit, common_main_pcs_data)) =
115            info_span!("main_trace_commit").in_scope(|| {
116                let traces = common_main_per_air.into_iter().flatten().collect_vec();
117                let prover_data = self.device.commit(&traces);
118                (traces, prover_data)
119            });
120
121        // Commitments order:
122        // - for each air:
123        //   - for each cached main trace
124        //     - 1 commitment
125        // - 1 commitment of all common main traces
126        let main_trace_commitments: Vec<PB::Commitment> = cached_commits_per_air
127            .iter()
128            .flatten()
129            .chain(iter::once(&common_main_commit))
130            .cloned()
131            .collect();
132
133        // All commitments that don't require challenges have been made, so we collect them into
134        // trace views:
135        let mut common_main_traces_it = common_main_traces.into_iter();
136        let mut log_trace_height_per_air: Vec<u8> = Vec::with_capacity(num_air);
137        let mut air_trace_views_per_air = Vec::with_capacity(num_air);
138        let mut cached_pcs_datas_per_air = Vec::with_capacity(num_air);
139        for (pk, cached_views, pvs) in izip!(&mpk.per_air, cached_views_per_air, &pvs_per_air) {
140            let (mut main_trace_views, cached_pcs_datas): (Vec<PB::Matrix>, Vec<PB::PcsData>) =
141                cached_views.into_iter().unzip();
142            cached_pcs_datas_per_air.push(cached_pcs_datas);
143            if pk.vk.has_common_main() {
144                main_trace_views.push(common_main_traces_it.next().expect("expected common main"));
145            }
146            let trace_height = main_trace_views.first().expect("no main trace").height();
147            let log_trace_height: u8 = log2_strict_usize(trace_height).try_into().unwrap();
148            let air_trace_view = AirView {
149                partitioned_main: main_trace_views,
150                public_values: pvs.to_vec(),
151            };
152            log_trace_height_per_air.push(log_trace_height);
153            air_trace_views_per_air.push(air_trace_view);
154        }
155        #[cfg(feature = "metrics")]
156        trace_metrics(&mpk, &log_trace_height_per_air).emit();
157
158        // ============ Challenger observations before additional RAP phases =============
159        // Observe public values:
160        for pvs in &pvs_per_air {
161            self.challenger.observe_slice(pvs);
162        }
163
164        // Observes preprocessed and main commitments:
165        let mvk = mpk.vk_view();
166        let preprocessed_commits = mvk.flattened_preprocessed_commits();
167        self.challenger.observe_slice(&preprocessed_commits);
168        self.challenger.observe_slice(&main_trace_commitments);
169        // Observe trace domain size per air:
170        self.challenger.observe_slice(
171            &log_trace_height_per_air
172                .iter()
173                .copied()
174                .map(|h| Val::<SC>::from_usize(h as usize))
175                .collect_vec(),
176        );
177
178        // ==================== Partially prove all RAP phases that require challenges
179        // ====================
180        let (rap_partial_proof, prover_data_after) =
181            self.device
182                .partially_prove(&mut self.challenger, &mpk, air_trace_views_per_air);
183        // At this point, main trace should be dropped
184
185        // Challenger observes additional commitments if any exist:
186        for (commit, _) in &prover_data_after.committed_pcs_data_per_phase {
187            self.challenger.observe(commit.clone());
188        }
189
190        // Collect exposed_values_per_air for the proof:
191        // - transpose per_phase, per_air -> per_air, per_phase
192        let exposed_values_per_air = (0..num_air)
193            .map(|i| {
194                let mut values = prover_data_after
195                    .rap_views_per_phase
196                    .iter()
197                    .map(|per_air| {
198                        per_air
199                            .get(i)
200                            .and_then(|v| v.inner.map(|_| v.exposed_values.clone()))
201                    })
202                    .collect_vec();
203                // Prune Nones
204                while let Some(last) = values.last() {
205                    if last.is_none() {
206                        values.pop();
207                    } else {
208                        break;
209                    }
210                }
211                values
212                    .into_iter()
213                    .map(|v| v.unwrap_or_default())
214                    .collect_vec()
215            })
216            .collect_vec();
217
218        // ==================== Quotient polynomial computation and commitment, if any
219        // ==================== Note[jpw]: Currently we always call this step, we could add
220        // a flag to skip it for protocols that do not require quotient poly.
221        let (quotient_commit, quotient_data) = self.device.eval_and_commit_quotient(
222            &mut self.challenger,
223            &mpk.per_air,
224            &pvs_per_air,
225            &cached_pcs_datas_per_air,
226            &common_main_pcs_data,
227            &prover_data_after,
228        );
229        // Observe quotient commitment
230        self.challenger.observe(quotient_commit.clone());
231
232        let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
233            .committed_pcs_data_per_phase
234            .into_iter()
235            .unzip();
236        // ==================== Polynomial Opening Proofs ====================
237        let opening = info_span!("pcs_opening").in_scope(|| {
238            let mut quotient_degrees = Vec::with_capacity(mpk.per_air.len());
239            let mut preprocessed = Vec::new();
240
241            for pk in mpk.per_air {
242                quotient_degrees.push(pk.vk.quotient_degree);
243                if let Some(preprocessed_data) = &pk.preprocessed_data {
244                    preprocessed.push(&preprocessed_data.data);
245                }
246            }
247
248            let main = cached_pcs_datas_per_air
249                .into_iter()
250                .flatten()
251                .chain(iter::once(common_main_pcs_data))
252                .collect();
253            self.device.open(
254                &mut self.challenger,
255                preprocessed,
256                main,
257                pcs_data_after,
258                quotient_data,
259                &quotient_degrees,
260            )
261        });
262
263        // ==================== Collect data into proof ====================
264        // Collect the commitments
265        let commitments = Commitments {
266            main_trace: main_trace_commitments,
267            after_challenge: commitments_after,
268            quotient: quotient_commit,
269        };
270        HalProof {
271            commitments,
272            opening,
273            per_air: izip!(
274                &mpk.air_ids,
275                log_trace_height_per_air,
276                exposed_values_per_air,
277                pvs_per_air
278            )
279            .map(
280                |(&air_id, log_height, exposed_values, public_values)| AirProofData {
281                    air_id,
282                    degree: 1 << log_height,
283                    public_values,
284                    exposed_values_after_challenge: exposed_values,
285                },
286            )
287            .collect(),
288            rap_partial_proof,
289        }
290    }
291}
292
293impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKeyView<'a, PB> {
294    pub(crate) fn validate(&self, ctx: &ProvingContext<PB>) -> bool {
295        ctx.per_air.len() == self.air_ids.len()
296            && ctx
297                .per_air
298                .iter()
299                .zip(&self.air_ids)
300                .all(|((id1, _), id2)| id1 == id2)
301            && ctx.per_air.iter().tuple_windows().all(|(a, b)| a.0 < b.0)
302    }
303
304    pub(crate) fn vk_view(&'a self) -> MultiStarkVerifyingKeyView<'a, PB::Val, PB::Commitment> {
305        MultiStarkVerifyingKeyView::new(
306            self.per_air.iter().map(|pk| &pk.vk).collect(),
307            self.trace_height_constraints,
308            self.vk_pre_hash.clone(),
309        )
310    }
311}