openvm_stark_backend/keygen/mod.rs

1use std::{collections::HashMap, iter::zip, sync::Arc};
2
3use itertools::Itertools;
4use p3_commit::Pcs;
5use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing};
6use p3_matrix::{dense::RowMajorMatrix, Matrix};
7use tracing::instrument;
8use types::MultiStarkVerifyingKey0;
9
10use crate::{
11    air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder},
12    config::{Com, RapPartialProvingKey, StarkGenericConfig, Val},
13    interaction::{RapPhaseSeq, RapPhaseSeqKind},
14    keygen::types::{
15        LinearConstraint, MultiStarkProvingKey, ProverOnlySinglePreprocessedData, StarkProvingKey,
16        StarkVerifyingKey, TraceWidth, VerifierSinglePreprocessedData,
17    },
18    rap::AnyRap,
19};
20
21pub mod types;
22pub mod view;
23
24struct AirKeygenBuilder<SC: StarkGenericConfig> {
25    air: Arc<dyn AnyRap<SC>>,
26    rap_phase_seq_kind: RapPhaseSeqKind,
27    prep_keygen_data: PrepKeygenData<SC>,
28}
29
30/// Stateful builder to create multi-stark proving and verifying keys
31/// for system of multiple RAPs with multiple multi-matrix commitments
32pub struct MultiStarkKeygenBuilder<'a, SC: StarkGenericConfig> {
33    pub config: &'a SC,
34    pub max_batch_size: Option<usize>,
35    pub max_num_constraints: Option<usize>,
36    /// Information for partitioned AIRs.
37    partitioned_airs: Vec<AirKeygenBuilder<SC>>,
38    max_constraint_degree: usize,
39}
40
41impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> {
42    pub fn new(config: &'a SC) -> Self {
43        Self {
44            config,
45            max_batch_size: None,
46            max_num_constraints: None,
47            partitioned_airs: vec![],
48            max_constraint_degree: 0,
49        }
50    }
51
52    /// The builder will **try** to keep the max constraint degree across all AIRs below this value.
53    /// If it is given AIRs that exceed this value, it will still include them.
54    ///
55    /// Currently this is only used for interaction chunking in FRI logup.
56    pub fn set_max_constraint_degree(&mut self, max_constraint_degree: usize) {
57        self.max_constraint_degree = max_constraint_degree;
58    }
59
60    /// Default way to add a single Interactive AIR.
61    /// Returns `air_id`
62    #[instrument(level = "debug", skip_all)]
63    pub fn add_air(&mut self, air: Arc<dyn AnyRap<SC>>) -> usize {
64        self.partitioned_airs.push(AirKeygenBuilder::new(
65            self.config.pcs(),
66            SC::RapPhaseSeq::ID,
67            air,
68        ));
69        self.partitioned_airs.len() - 1
70    }
71
72    /// Consume the builder and generate proving key.
73    /// The verifying key can be obtained from the proving key.
74    pub fn generate_pk(mut self) -> MultiStarkProvingKey<SC> {
75        let air_max_constraint_degree = self
76            .partitioned_airs
77            .iter()
78            .map(|keygen_builder| {
79                let max_constraint_degree = keygen_builder.max_constraint_degree();
80                tracing::debug!(
81                    "{} has constraint degree {}",
82                    keygen_builder.air.name(),
83                    max_constraint_degree
84                );
85                max_constraint_degree
86            })
87            .max()
88            .unwrap();
89        tracing::info!(
90            "Max constraint (excluding logup constraints) degree across all AIRs: {}",
91            air_max_constraint_degree
92        );
93        if self.max_constraint_degree != 0 && air_max_constraint_degree > self.max_constraint_degree
94        {
95            // This means the quotient polynomial is already going to be higher degree, so we
96            // might as well use it.
97            tracing::info!(
98                "Setting max_constraint_degree from {} to {air_max_constraint_degree}",
99                self.max_constraint_degree
100            );
101            self.max_constraint_degree = air_max_constraint_degree;
102        }
103        // First pass: get symbolic constraints and interactions but RAP phase constraints are not
104        // final
105        let symbolic_constraints_per_air = self
106            .partitioned_airs
107            .iter()
108            .map(|keygen_builder| keygen_builder.get_symbolic_builder(None).constraints())
109            .collect_vec();
110        // Note: due to the need to go through a trait, there is some duplicate computation
111        // (e.g., FRI logup will calculate the interaction chunking both here and in the second pass
112        // below)
113        let rap_partial_pk_per_air = self
114            .config
115            .rap_phase_seq()
116            .generate_pk_per_air(&symbolic_constraints_per_air, self.max_constraint_degree);
117        let pk_per_air: Vec<_> = zip(self.partitioned_airs, rap_partial_pk_per_air)
118            .map(|(keygen_builder, rap_partial_pk)| {
119                // Second pass: get final constraints, where RAP phase constraints may have changed
120                keygen_builder.generate_pk(rap_partial_pk, self.max_constraint_degree)
121            })
122            .collect();
123
124        for pk in pk_per_air.iter() {
125            let width = &pk.vk.params.width;
126            tracing::info!("{:<20} | Quotient Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | Perm Cols = {:<4} | {:4} Constraints | {:3} Interactions",
127                pk.air_name,
128                pk.vk.quotient_degree,
129                width.preprocessed.unwrap_or(0),
130                format!("{:?}",width.main_widths()),
131                format!("{:?}",width.after_challenge.iter().map(|&x| x * <SC::Challenge as BasedVectorSpace<Val<SC>>>::DIMENSION).collect_vec()),
132                pk.vk.symbolic_constraints.constraints.constraint_idx.len(),
133                pk.vk.symbolic_constraints.interactions.len(),
134            );
135            tracing::debug!(
136                "On Buses {:?}",
137                pk.vk
138                    .symbolic_constraints
139                    .interactions
140                    .iter()
141                    .map(|i| i.bus_index)
142                    .collect_vec()
143            );
144            #[cfg(feature = "metrics")]
145            {
146                let labels = [("air_name", pk.air_name.clone())];
147                metrics::counter!("quotient_deg", &labels).absolute(pk.vk.quotient_degree as u64);
148                // column info will be logged by prover later
149                metrics::counter!("constraints", &labels)
150                    .absolute(pk.vk.symbolic_constraints.constraints.constraint_idx.len() as u64);
151                metrics::counter!("interactions", &labels)
152                    .absolute(pk.vk.symbolic_constraints.interactions.len() as u64);
153            }
154        }
155
156        let num_airs = symbolic_constraints_per_air.len();
157        let base_order = Val::<SC>::order().to_u32_digits()[0];
158        let mut count_weight_per_air_per_bus_index = HashMap::new();
159
160        // We compute the a_i's for the constraints of the form a_0 n_0 + ... + a_{k-1} n_{k-1} <
161        // a_k, First the constraints that the total number of interactions on each bus is
162        // at most the base field order.
163        for (i, constraints_per_air) in symbolic_constraints_per_air.iter().enumerate() {
164            for interaction in &constraints_per_air.interactions {
165                // Also make sure that this of interaction is valid given the security params.
166                // +1 because of the bus
167                let max_msg_len = self
168                    .config
169                    .rap_phase_seq()
170                    .log_up_security_params()
171                    .max_message_length();
172                // plus one because of the bus
173                let total_message_length = interaction.message.len() + 1;
174                assert!(
175                    total_message_length <= max_msg_len,
176                    "interaction message with bus has length {}, which is more than max {max_msg_len}",
177                    total_message_length,
178                );
179
180                let b = interaction.bus_index;
181                let constraint = count_weight_per_air_per_bus_index
182                    .entry(b)
183                    .or_insert_with(|| LinearConstraint {
184                        coefficients: vec![0; num_airs],
185                        threshold: base_order,
186                    });
187                constraint.coefficients[i] += interaction.count_weight;
188            }
189        }
190
191        // Sorting by bus index is not necessary, but makes debugging/testing easier.
192        let mut trace_height_constraints = count_weight_per_air_per_bus_index
193            .into_iter()
194            .sorted_by_key(|(bus_index, _)| *bus_index)
195            .map(|(_, constraint)| constraint)
196            .collect_vec();
197
198        let log_up_security_params = self.config.rap_phase_seq().log_up_security_params();
199
200        // Add a constraint for the total number of interactions.
201        trace_height_constraints.push(LinearConstraint {
202            coefficients: symbolic_constraints_per_air
203                .iter()
204                .map(|c| c.interactions.len() as u32)
205                .collect(),
206            threshold: log_up_security_params.max_interaction_count,
207        });
208
209        let deep_pow_bits = self.config.deep_ali_params().deep_pow_bits;
210        let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 {
211            per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
212            trace_height_constraints: trace_height_constraints.clone(),
213            log_up_pow_bits: log_up_security_params.log_up_pow_bits,
214            deep_pow_bits,
215        };
216        if let Some(max_batch_size) = self.max_batch_size {
217            let ext_degree = <SC::Challenge as BasedVectorSpace<Val<SC>>>::DIMENSION;
218            let mut total_width = pre_vk
219                .per_air
220                .iter()
221                .map(|vk| vk.params.width.total_width(ext_degree))
222                .sum::<usize>();
223            let quotient_deg = pre_vk
224                .per_air
225                .iter()
226                .map(|vk| vk.quotient_degree as usize)
227                .max()
228                .unwrap_or_default();
229            // Quotient polynomial contribution
230            total_width += quotient_deg * ext_degree;
231            tracing::info!(%total_width);
232            // x2 for rotation opening
233            assert!(
234                total_width * 2 <= max_batch_size,
235                "Maximum number of AIR columns exceeded for desired security level"
236            );
237        }
238        if let Some(max_num_constraints) = self.max_num_constraints {
239            let total_constraint_count = pre_vk
240                .per_air
241                .iter()
242                .map(|vk| vk.symbolic_constraints.constraints.num_constraints())
243                .sum::<usize>();
244            tracing::info!(%total_constraint_count);
245            assert!(
246                total_constraint_count <= max_num_constraints,
247                "Maximum number of constraints exceeded for desired security level"
248            );
249        }
250
251        // To protect against weak Fiat-Shamir, we hash the "pre"-verifying key and include it in
252        // the final verifying key. This just needs to commit to the verifying key and does
253        // not need to be verified by the verifier, so we just use bincode to serialize it.
254        let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
255        tracing::info!("pre-vkey: {} bytes", vk_bytes.len());
256        // Purely to get type compatibility and convenience, we hash using pcs.commit as a single
257        // row
258        let vk_as_row =
259            RowMajorMatrix::new_row(vk_bytes.into_iter().map(Val::<SC>::from_u8).collect());
260        let pcs = self.config.pcs();
261        let deg_1_domain = pcs.natural_domain_for_degree(1);
262        let (vk_pre_hash, _) = pcs.commit(vec![(deg_1_domain, vk_as_row)]);
263
264        MultiStarkProvingKey {
265            per_air: pk_per_air,
266            trace_height_constraints,
267            max_constraint_degree: self.max_constraint_degree,
268            log_up_pow_bits: log_up_security_params.log_up_pow_bits,
269            deep_pow_bits,
270            vk_pre_hash,
271        }
272    }
273}
274
275impl<SC: StarkGenericConfig> AirKeygenBuilder<SC> {
276    fn new(pcs: &SC::Pcs, rap_phase_seq_kind: RapPhaseSeqKind, air: Arc<dyn AnyRap<SC>>) -> Self {
277        let prep_keygen_data = compute_prep_data_for_air(pcs, air.as_ref());
278        AirKeygenBuilder {
279            air,
280            rap_phase_seq_kind,
281            prep_keygen_data,
282        }
283    }
284
285    fn max_constraint_degree(&self) -> usize {
286        self.get_symbolic_builder(None)
287            .constraints()
288            .max_constraint_degree()
289    }
290
291    fn generate_pk(
292        self,
293        rap_partial_pk: RapPartialProvingKey<SC>,
294        max_constraint_degree: usize,
295    ) -> StarkProvingKey<SC> {
296        let air_name = self.air.name();
297
298        let symbolic_builder = self.get_symbolic_builder(Some(max_constraint_degree));
299        let params = symbolic_builder.params();
300        let symbolic_constraints = symbolic_builder.constraints();
301        let log_quotient_degree = symbolic_constraints.get_log_quotient_degree();
302        let quotient_degree = 1 << log_quotient_degree;
303
304        let Self {
305            prep_keygen_data:
306                PrepKeygenData {
307                    verifier_data: prep_verifier_data,
308                    prover_data: prep_prover_data,
309                },
310            ..
311        } = self;
312
313        let vk: StarkVerifyingKey<Val<SC>, Com<SC>> = StarkVerifyingKey {
314            preprocessed_data: prep_verifier_data,
315            params,
316            symbolic_constraints: symbolic_constraints.into(),
317            quotient_degree,
318            rap_phase_seq_kind: self.rap_phase_seq_kind,
319        };
320        StarkProvingKey {
321            air_name,
322            vk,
323            preprocessed_data: prep_prover_data,
324            rap_partial_pk,
325        }
326    }
327
328    fn get_symbolic_builder(
329        &self,
330        max_constraint_degree: Option<usize>,
331    ) -> SymbolicRapBuilder<Val<SC>> {
332        let width = TraceWidth {
333            preprocessed: self.prep_keygen_data.width(),
334            cached_mains: self.air.cached_main_widths(),
335            common_main: self.air.common_main_width(),
336            after_challenge: vec![],
337        };
338        get_symbolic_builder(
339            self.air.as_ref(),
340            &width,
341            &[],
342            &[],
343            SC::RapPhaseSeq::ID,
344            max_constraint_degree.unwrap_or(0),
345        )
346    }
347}
348
349pub(super) struct PrepKeygenData<SC: StarkGenericConfig> {
350    pub verifier_data: Option<VerifierSinglePreprocessedData<Com<SC>>>,
351    pub prover_data: Option<ProverOnlySinglePreprocessedData<SC>>,
352}
353
354impl<SC: StarkGenericConfig> PrepKeygenData<SC> {
355    pub fn width(&self) -> Option<usize> {
356        self.prover_data.as_ref().map(|d| d.trace.width())
357    }
358}
359
360fn compute_prep_data_for_air<SC: StarkGenericConfig>(
361    pcs: &SC::Pcs,
362    air: &dyn AnyRap<SC>,
363) -> PrepKeygenData<SC> {
364    let preprocessed_trace = air.preprocessed_trace();
365    let vpdata_opt = preprocessed_trace.map(|trace| {
366        let domain = pcs.natural_domain_for_degree(trace.height());
367        let (commit, data) = pcs.commit(vec![(domain, trace.clone())]);
368        let vdata = VerifierSinglePreprocessedData { commit };
369        let pdata = ProverOnlySinglePreprocessedData {
370            trace: Arc::new(trace),
371            data: Arc::new(data),
372        };
373        (vdata, pdata)
374    });
375    if let Some((vdata, pdata)) = vpdata_opt {
376        PrepKeygenData {
377            prover_data: Some(pdata),
378            verifier_data: Some(vdata),
379        }
380    } else {
381        PrepKeygenData {
382            prover_data: None,
383            verifier_data: None,
384        }
385    }
386}