openvm_stark_backend/keygen/mod.rs

1use std::{collections::HashMap, iter::zip, sync::Arc};
2
3use itertools::Itertools;
4use p3_commit::Pcs;
5use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
6use p3_matrix::{dense::RowMajorMatrix, Matrix};
7use tracing::instrument;
8use types::MultiStarkVerifyingKey0;
9
10use crate::{
11    air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder},
12    config::{Com, RapPartialProvingKey, StarkGenericConfig, Val},
13    interaction::{RapPhaseSeq, RapPhaseSeqKind},
14    keygen::types::{
15        LinearConstraint, MultiStarkProvingKey, ProverOnlySinglePreprocessedData, StarkProvingKey,
16        StarkVerifyingKey, TraceWidth, VerifierSinglePreprocessedData,
17    },
18    rap::AnyRap,
19};
20
21pub mod types;
22pub mod view;
23
24struct AirKeygenBuilder<SC: StarkGenericConfig> {
25    air: Arc<dyn AnyRap<SC>>,
26    rap_phase_seq_kind: RapPhaseSeqKind,
27    prep_keygen_data: PrepKeygenData<SC>,
28}
29
30/// Stateful builder to create multi-stark proving and verifying keys
31/// for system of multiple RAPs with multiple multi-matrix commitments
32pub struct MultiStarkKeygenBuilder<'a, SC: StarkGenericConfig> {
33    pub config: &'a SC,
34    /// Information for partitioned AIRs.
35    partitioned_airs: Vec<AirKeygenBuilder<SC>>,
36    max_constraint_degree: usize,
37}
38
39impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> {
40    pub fn new(config: &'a SC) -> Self {
41        Self {
42            config,
43            partitioned_airs: vec![],
44            max_constraint_degree: 0,
45        }
46    }
47
48    /// The builder will **try** to keep the max constraint degree across all AIRs below this value.
49    /// If it is given AIRs that exceed this value, it will still include them.
50    ///
51    /// Currently this is only used for interaction chunking in FRI logup.
52    pub fn set_max_constraint_degree(&mut self, max_constraint_degree: usize) {
53        self.max_constraint_degree = max_constraint_degree;
54    }
55
56    /// Default way to add a single Interactive AIR.
57    /// Returns `air_id`
58    #[instrument(level = "debug", skip_all)]
59    pub fn add_air(&mut self, air: Arc<dyn AnyRap<SC>>) -> usize {
60        self.partitioned_airs.push(AirKeygenBuilder::new(
61            self.config.pcs(),
62            SC::RapPhaseSeq::ID,
63            air,
64        ));
65        self.partitioned_airs.len() - 1
66    }
67
68    /// Consume the builder and generate proving key.
69    /// The verifying key can be obtained from the proving key.
70    pub fn generate_pk(mut self) -> MultiStarkProvingKey<SC> {
71        let air_max_constraint_degree = self
72            .partitioned_airs
73            .iter()
74            .map(|keygen_builder| {
75                let max_constraint_degree = keygen_builder.max_constraint_degree();
76                tracing::debug!(
77                    "{} has constraint degree {}",
78                    keygen_builder.air.name(),
79                    max_constraint_degree
80                );
81                max_constraint_degree
82            })
83            .max()
84            .unwrap();
85        tracing::info!(
86            "Max constraint (excluding logup constraints) degree across all AIRs: {}",
87            air_max_constraint_degree
88        );
89        if self.max_constraint_degree != 0 && air_max_constraint_degree > self.max_constraint_degree
90        {
91            // This means the quotient polynomial is already going to be higher degree, so we
92            // might as well use it.
93            tracing::info!(
94                "Setting max_constraint_degree from {} to {air_max_constraint_degree}",
95                self.max_constraint_degree
96            );
97            self.max_constraint_degree = air_max_constraint_degree;
98        }
99        // First pass: get symbolic constraints and interactions but RAP phase constraints are not
100        // final
101        let symbolic_constraints_per_air = self
102            .partitioned_airs
103            .iter()
104            .map(|keygen_builder| keygen_builder.get_symbolic_builder(None).constraints())
105            .collect_vec();
106        // Note: due to the need to go through a trait, there is some duplicate computation
107        // (e.g., FRI logup will calculate the interaction chunking both here and in the second pass
108        // below)
109        let rap_partial_pk_per_air = self
110            .config
111            .rap_phase_seq()
112            .generate_pk_per_air(&symbolic_constraints_per_air, self.max_constraint_degree);
113        let pk_per_air: Vec<_> = zip(self.partitioned_airs, rap_partial_pk_per_air)
114            .map(|(keygen_builder, rap_partial_pk)| {
115                // Second pass: get final constraints, where RAP phase constraints may have changed
116                keygen_builder.generate_pk(rap_partial_pk, self.max_constraint_degree)
117            })
118            .collect();
119
120        for pk in pk_per_air.iter() {
121            let width = &pk.vk.params.width;
122            tracing::info!("{:<20} | Quotient Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | Perm Cols = {:<4} | {:4} Constraints | {:3} Interactions",
123                pk.air_name,
124                pk.vk.quotient_degree,
125                width.preprocessed.unwrap_or(0),
126                format!("{:?}",width.main_widths()),
127                format!("{:?}",width.after_challenge.iter().map(|&x| x * <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D).collect_vec()),
128                pk.vk.symbolic_constraints.constraints.constraint_idx.len(),
129                pk.vk.symbolic_constraints.interactions.len(),
130            );
131            tracing::debug!(
132                "On Buses {:?}",
133                pk.vk
134                    .symbolic_constraints
135                    .interactions
136                    .iter()
137                    .map(|i| i.bus_index)
138                    .collect_vec()
139            );
140            #[cfg(feature = "metrics")]
141            {
142                let labels = [("air_name", pk.air_name.clone())];
143                metrics::counter!("quotient_deg", &labels).absolute(pk.vk.quotient_degree as u64);
144                // column info will be logged by prover later
145                metrics::counter!("constraints", &labels)
146                    .absolute(pk.vk.symbolic_constraints.constraints.constraint_idx.len() as u64);
147                metrics::counter!("interactions", &labels)
148                    .absolute(pk.vk.symbolic_constraints.interactions.len() as u64);
149            }
150        }
151
152        let num_airs = symbolic_constraints_per_air.len();
153        let base_order = Val::<SC>::order().to_u32_digits()[0];
154        let mut count_weight_per_air_per_bus_index = HashMap::new();
155
156        // We compute the a_i's for the constraints of the form a_0 n_0 + ... + a_{k-1} n_{k-1} <
157        // a_k, First the constraints that the total number of interactions on each bus is
158        // at most the base field order.
159        for (i, constraints_per_air) in symbolic_constraints_per_air.iter().enumerate() {
160            for interaction in &constraints_per_air.interactions {
161                // Also make sure that this of interaction is valid given the security params.
162                // +1 because of the bus
163                let max_msg_len = self
164                    .config
165                    .rap_phase_seq()
166                    .log_up_security_params()
167                    .max_message_length();
168                // plus one because of the bus
169                let total_message_length = interaction.message.len() + 1;
170                assert!(
171                    total_message_length <= max_msg_len,
172                    "interaction message with bus has length {}, which is more than max {max_msg_len}",
173                    total_message_length,
174                );
175
176                let b = interaction.bus_index;
177                let constraint = count_weight_per_air_per_bus_index
178                    .entry(b)
179                    .or_insert_with(|| LinearConstraint {
180                        coefficients: vec![0; num_airs],
181                        threshold: base_order,
182                    });
183                constraint.coefficients[i] += interaction.count_weight;
184            }
185        }
186
187        // Sorting by bus index is not necessary, but makes debugging/testing easier.
188        let mut trace_height_constraints = count_weight_per_air_per_bus_index
189            .into_iter()
190            .sorted_by_key(|(bus_index, _)| *bus_index)
191            .map(|(_, constraint)| constraint)
192            .collect_vec();
193
194        let log_up_security_params = self.config.rap_phase_seq().log_up_security_params();
195
196        // Add a constraint for the total number of interactions.
197        trace_height_constraints.push(LinearConstraint {
198            coefficients: symbolic_constraints_per_air
199                .iter()
200                .map(|c| c.interactions.len() as u32)
201                .collect(),
202            threshold: log_up_security_params.max_interaction_count,
203        });
204
205        let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 {
206            per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
207            trace_height_constraints: trace_height_constraints.clone(),
208            log_up_pow_bits: log_up_security_params.log_up_pow_bits,
209        };
210        // To protect against weak Fiat-Shamir, we hash the "pre"-verifying key and include it in
211        // the final verifying key. This just needs to commit to the verifying key and does
212        // not need to be verified by the verifier, so we just use bincode to serialize it.
213        let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
214        tracing::info!("pre-vkey: {} bytes", vk_bytes.len());
215        // Purely to get type compatibility and convenience, we hash using pcs.commit as a single
216        // row
217        let vk_as_row = RowMajorMatrix::new_row(
218            vk_bytes
219                .into_iter()
220                .map(Val::<SC>::from_canonical_u8)
221                .collect(),
222        );
223        let pcs = self.config.pcs();
224        let deg_1_domain = pcs.natural_domain_for_degree(1);
225        let (vk_pre_hash, _) = pcs.commit(vec![(deg_1_domain, vk_as_row)]);
226
227        MultiStarkProvingKey {
228            per_air: pk_per_air,
229            trace_height_constraints,
230            max_constraint_degree: self.max_constraint_degree,
231            log_up_pow_bits: log_up_security_params.log_up_pow_bits,
232            vk_pre_hash,
233        }
234    }
235}
236
237impl<SC: StarkGenericConfig> AirKeygenBuilder<SC> {
238    fn new(pcs: &SC::Pcs, rap_phase_seq_kind: RapPhaseSeqKind, air: Arc<dyn AnyRap<SC>>) -> Self {
239        let prep_keygen_data = compute_prep_data_for_air(pcs, air.as_ref());
240        AirKeygenBuilder {
241            air,
242            rap_phase_seq_kind,
243            prep_keygen_data,
244        }
245    }
246
247    fn max_constraint_degree(&self) -> usize {
248        self.get_symbolic_builder(None)
249            .constraints()
250            .max_constraint_degree()
251    }
252
253    fn generate_pk(
254        self,
255        rap_partial_pk: RapPartialProvingKey<SC>,
256        max_constraint_degree: usize,
257    ) -> StarkProvingKey<SC> {
258        let air_name = self.air.name();
259
260        let symbolic_builder = self.get_symbolic_builder(Some(max_constraint_degree));
261        let params = symbolic_builder.params();
262        let symbolic_constraints = symbolic_builder.constraints();
263        let log_quotient_degree = symbolic_constraints.get_log_quotient_degree();
264        let quotient_degree = 1 << log_quotient_degree;
265
266        let Self {
267            prep_keygen_data:
268                PrepKeygenData {
269                    verifier_data: prep_verifier_data,
270                    prover_data: prep_prover_data,
271                },
272            ..
273        } = self;
274
275        let vk: StarkVerifyingKey<Val<SC>, Com<SC>> = StarkVerifyingKey {
276            preprocessed_data: prep_verifier_data,
277            params,
278            symbolic_constraints: symbolic_constraints.into(),
279            quotient_degree,
280            rap_phase_seq_kind: self.rap_phase_seq_kind,
281        };
282        StarkProvingKey {
283            air_name,
284            vk,
285            preprocessed_data: prep_prover_data,
286            rap_partial_pk,
287        }
288    }
289
290    fn get_symbolic_builder(
291        &self,
292        max_constraint_degree: Option<usize>,
293    ) -> SymbolicRapBuilder<Val<SC>> {
294        let width = TraceWidth {
295            preprocessed: self.prep_keygen_data.width(),
296            cached_mains: self.air.cached_main_widths(),
297            common_main: self.air.common_main_width(),
298            after_challenge: vec![],
299        };
300        get_symbolic_builder(
301            self.air.as_ref(),
302            &width,
303            &[],
304            &[],
305            SC::RapPhaseSeq::ID,
306            max_constraint_degree.unwrap_or(0),
307        )
308    }
309}
310
311pub(super) struct PrepKeygenData<SC: StarkGenericConfig> {
312    pub verifier_data: Option<VerifierSinglePreprocessedData<Com<SC>>>,
313    pub prover_data: Option<ProverOnlySinglePreprocessedData<SC>>,
314}
315
316impl<SC: StarkGenericConfig> PrepKeygenData<SC> {
317    pub fn width(&self) -> Option<usize> {
318        self.prover_data.as_ref().map(|d| d.trace.width())
319    }
320}
321
322fn compute_prep_data_for_air<SC: StarkGenericConfig>(
323    pcs: &SC::Pcs,
324    air: &dyn AnyRap<SC>,
325) -> PrepKeygenData<SC> {
326    let preprocessed_trace = air.preprocessed_trace();
327    let vpdata_opt = preprocessed_trace.map(|trace| {
328        let domain = pcs.natural_domain_for_degree(trace.height());
329        let (commit, data) = pcs.commit(vec![(domain, trace.clone())]);
330        let vdata = VerifierSinglePreprocessedData { commit };
331        let pdata = ProverOnlySinglePreprocessedData {
332            trace: Arc::new(trace),
333            data: Arc::new(data),
334        };
335        (vdata, pdata)
336    });
337    if let Some((vdata, pdata)) = vpdata_opt {
338        PrepKeygenData {
339            prover_data: Some(pdata),
340            verifier_data: Some(vdata),
341        }
342    } else {
343        PrepKeygenData {
344            prover_data: None,
345            verifier_data: None,
346        }
347    }
348}