openvm_stark_backend/prover/coordinator.rs

1use std::{iter, marker::PhantomData};
2
3use itertools::{izip, Itertools};
4use p3_challenger::CanObserve;
5use p3_field::FieldAlgebra;
6use p3_util::log2_strict_usize;
7use tracing::{info, info_span, instrument};
8
9use super::{
10    hal::{ProverBackend, ProverDevice},
11    types::{HalProof, ProvingContext},
12    Prover,
13};
14#[cfg(feature = "metrics")]
15use crate::prover::metrics::trace_metrics;
16use crate::{
17    config::{Com, StarkGenericConfig, Val},
18    keygen::view::MultiStarkVerifyingKeyView,
19    proof::{AirProofData, Commitments},
20    prover::{
21        hal::MatrixDimensions,
22        types::{AirView, DeviceMultiStarkProvingKeyView},
23    },
24};
25
26/// Host-to-device coordinator for full prover implementation.
27///
28/// The generics are:
29/// - `SC`: Stark configuration for proving key (from host)
30/// - `PB`: Prover backend types
31/// - `PD`: Prover device methods
32pub struct Coordinator<SC: StarkGenericConfig, PB, PD> {
33    pub backend: PB,
34    pub device: PD,
35    challenger: SC::Challenger,
36    phantom: PhantomData<(SC, PB)>,
37}
38
39impl<SC: StarkGenericConfig, PB, PD> Coordinator<SC, PB, PD> {
40    pub fn new(backend: PB, device: PD, challenger: SC::Challenger) -> Self {
41        Self {
42            backend,
43            device,
44            challenger,
45            phantom: PhantomData,
46        }
47    }
48}
49
50impl<SC, PB, PD> Prover for Coordinator<SC, PB, PD>
51where
52    SC: StarkGenericConfig,
53    PB: ProverBackend<
54        Val = Val<SC>,
55        Challenge = SC::Challenge,
56        Commitment = Com<SC>,
57        Challenger = SC::Challenger,
58    >,
59    PD: ProverDevice<PB>,
60{
61    type Proof = HalProof<PB>;
62    type ProvingKeyView<'a>
63        = DeviceMultiStarkProvingKeyView<'a, PB>
64    where
65        Self: 'a;
66
67    type ProvingContext<'a>
68        = ProvingContext<PB>
69    where
70        Self: 'a;
71
72    /// Specialized prove for InteractiveAirs.
73    /// Handles trace generation of the permutation traces.
74    /// Assumes the main traces have been generated and committed already.
75    ///
76    /// The [DeviceMultiStarkProvingKey] should already be filtered to only include the relevant
77    /// AIR's proving keys.
78    #[instrument(name = "stark_prove_excluding_trace", level = "info", skip_all)]
79    fn prove<'a>(
80        &'a mut self,
81        mpk: Self::ProvingKeyView<'a>,
82        ctx: Self::ProvingContext<'a>,
83    ) -> Self::Proof {
84        assert!(mpk.validate(&ctx), "Invalid proof input");
85        self.challenger.observe(mpk.vk_pre_hash.clone());
86
87        let num_air = ctx.per_air.len();
88        self.challenger
89            .observe(Val::<SC>::from_canonical_usize(num_air));
90        info!(num_air);
91        #[allow(clippy::type_complexity)]
92        let (cached_commits_per_air, cached_views_per_air, common_main_per_air, pvs_per_air): (
93            Vec<Vec<PB::Commitment>>,
94            Vec<Vec<(PB::Matrix, PB::PcsData)>>,
95            Vec<Option<PB::Matrix>>,
96            Vec<Vec<PB::Val>>,
97        ) = ctx
98            .into_iter()
99            .map(|(air_id, ctx)| {
100                self.challenger.observe(Val::<SC>::from_canonical_usize(air_id));
101                let (cached_commits, cached_views): (Vec<_>, Vec<_>) =
102                    ctx.cached_mains.into_iter().map(|cm| (cm.commitment, (cm.trace, cm.data))).unzip();
103                (
104                    cached_commits,
105                    cached_views,
106                    ctx.common_main,
107                    ctx.public_values,
108                )
109            })
110            .multiunzip();
111
112        // ==================== All trace commitments that do not require challenges
113        // ==================== Commit all common main traces in a commitment. Traces inside
114        // are ordered by AIR id.
115        let (common_main_traces, (common_main_commit, common_main_pcs_data)) =
116            info_span!("main_trace_commit").in_scope(|| {
117                let traces = common_main_per_air.into_iter().flatten().collect_vec();
118                let prover_data = self.device.commit(&traces);
119                (traces, prover_data)
120            });
121
122        // Commitments order:
123        // - for each air:
124        //   - for each cached main trace
125        //     - 1 commitment
126        // - 1 commitment of all common main traces
127        let main_trace_commitments: Vec<PB::Commitment> = cached_commits_per_air
128            .iter()
129            .flatten()
130            .chain(iter::once(&common_main_commit))
131            .cloned()
132            .collect();
133
134        // All commitments that don't require challenges have been made, so we collect them into
135        // trace views:
136        let mut common_main_traces_it = common_main_traces.into_iter();
137        let mut log_trace_height_per_air: Vec<u8> = Vec::with_capacity(num_air);
138        let mut air_trace_views_per_air = Vec::with_capacity(num_air);
139        let mut cached_pcs_datas_per_air = Vec::with_capacity(num_air);
140        for (pk, cached_views, pvs) in izip!(&mpk.per_air, cached_views_per_air, &pvs_per_air) {
141            let (mut main_trace_views, cached_pcs_datas): (Vec<PB::Matrix>, Vec<PB::PcsData>) =
142                cached_views.into_iter().unzip();
143            cached_pcs_datas_per_air.push(cached_pcs_datas);
144            if pk.vk.has_common_main() {
145                main_trace_views.push(common_main_traces_it.next().expect("expected common main"));
146            }
147            let trace_height = main_trace_views.first().expect("no main trace").height();
148            let log_trace_height: u8 = log2_strict_usize(trace_height).try_into().unwrap();
149            let air_trace_view = AirView {
150                partitioned_main: main_trace_views,
151                public_values: pvs.to_vec(),
152            };
153            log_trace_height_per_air.push(log_trace_height);
154            air_trace_views_per_air.push(air_trace_view);
155        }
156        #[cfg(feature = "metrics")]
157        trace_metrics(&mpk, &log_trace_height_per_air).emit();
158
159        // ============ Challenger observations before additional RAP phases =============
160        // Observe public values:
161        for pvs in &pvs_per_air {
162            self.challenger.observe_slice(pvs);
163        }
164
165        // Observes preprocessed and main commitments:
166        let mvk = mpk.vk_view();
167        let preprocessed_commits = mvk.flattened_preprocessed_commits();
168        self.challenger.observe_slice(&preprocessed_commits);
169        self.challenger.observe_slice(&main_trace_commitments);
170        // Observe trace domain size per air:
171        self.challenger.observe_slice(
172            &log_trace_height_per_air
173                .iter()
174                .copied()
175                .map(Val::<SC>::from_canonical_u8)
176                .collect_vec(),
177        );
178
179        // ==================== Partially prove all RAP phases that require challenges
180        // ====================
181        let (rap_partial_proof, prover_data_after) =
182            self.device
183                .partially_prove(&mut self.challenger, &mpk, air_trace_views_per_air);
184        // At this point, main trace should be dropped
185
186        // Challenger observes additional commitments if any exist:
187        for (commit, _) in &prover_data_after.committed_pcs_data_per_phase {
188            self.challenger.observe(commit.clone());
189        }
190
191        // Collect exposed_values_per_air for the proof:
192        // - transpose per_phase, per_air -> per_air, per_phase
193        let exposed_values_per_air = (0..num_air)
194            .map(|i| {
195                let mut values = prover_data_after
196                    .rap_views_per_phase
197                    .iter()
198                    .map(|per_air| {
199                        per_air
200                            .get(i)
201                            .and_then(|v| v.inner.map(|_| v.exposed_values.clone()))
202                    })
203                    .collect_vec();
204                // Prune Nones
205                while let Some(last) = values.last() {
206                    if last.is_none() {
207                        values.pop();
208                    } else {
209                        break;
210                    }
211                }
212                values
213                    .into_iter()
214                    .map(|v| v.unwrap_or_default())
215                    .collect_vec()
216            })
217            .collect_vec();
218
219        // ==================== Quotient polynomial computation and commitment, if any
220        // ==================== Note[jpw]: Currently we always call this step, we could add
221        // a flag to skip it for protocols that do not require quotient poly.
222        let (quotient_commit, quotient_data) = self.device.eval_and_commit_quotient(
223            &mut self.challenger,
224            &mpk.per_air,
225            &pvs_per_air,
226            &cached_pcs_datas_per_air,
227            &common_main_pcs_data,
228            &prover_data_after,
229        );
230        // Observe quotient commitment
231        self.challenger.observe(quotient_commit.clone());
232
233        let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
234            .committed_pcs_data_per_phase
235            .into_iter()
236            .unzip();
237        // ==================== Polynomial Opening Proofs ====================
238        let opening = info_span!("pcs_opening").in_scope(|| {
239            let mut quotient_degrees = Vec::with_capacity(mpk.per_air.len());
240            let mut preprocessed = Vec::new();
241
242            for pk in mpk.per_air {
243                quotient_degrees.push(pk.vk.quotient_degree);
244                if let Some(preprocessed_data) = &pk.preprocessed_data {
245                    preprocessed.push(&preprocessed_data.data);
246                }
247            }
248
249            let main = cached_pcs_datas_per_air
250                .into_iter()
251                .flatten()
252                .chain(iter::once(common_main_pcs_data))
253                .collect();
254            self.device.open(
255                &mut self.challenger,
256                preprocessed,
257                main,
258                pcs_data_after,
259                quotient_data,
260                &quotient_degrees,
261            )
262        });
263
264        // ==================== Collect data into proof ====================
265        // Collect the commitments
266        let commitments = Commitments {
267            main_trace: main_trace_commitments,
268            after_challenge: commitments_after,
269            quotient: quotient_commit,
270        };
271        HalProof {
272            commitments,
273            opening,
274            per_air: izip!(
275                &mpk.air_ids,
276                log_trace_height_per_air,
277                exposed_values_per_air,
278                pvs_per_air
279            )
280            .map(
281                |(&air_id, log_height, exposed_values, public_values)| AirProofData {
282                    air_id,
283                    degree: 1 << log_height,
284                    public_values,
285                    exposed_values_after_challenge: exposed_values,
286                },
287            )
288            .collect(),
289            rap_partial_proof,
290        }
291    }
292}
293
294impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKeyView<'a, PB> {
295    pub(crate) fn validate(&self, ctx: &ProvingContext<PB>) -> bool {
296        ctx.per_air.len() == self.air_ids.len()
297            && ctx
298                .per_air
299                .iter()
300                .zip(&self.air_ids)
301                .all(|((id1, _), id2)| id1 == id2)
302            && ctx.per_air.iter().tuple_windows().all(|(a, b)| a.0 < b.0)
303    }
304
305    pub(crate) fn vk_view(&'a self) -> MultiStarkVerifyingKeyView<'a, PB::Val, PB::Commitment> {
306        MultiStarkVerifyingKeyView::new(
307            self.per_air.iter().map(|pk| &pk.vk).collect(),
308            self.trace_height_constraints,
309            self.vk_pre_hash.clone(),
310        )
311    }
312}