openvm_stark_backend/keygen/mod.rs

1use std::{collections::HashMap, iter::zip, sync::Arc};
2
3use itertools::Itertools;
4use p3_commit::Pcs;
5use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
6use p3_matrix::{dense::RowMajorMatrix, Matrix};
7use tracing::instrument;
8use types::MultiStarkVerifyingKey0;
9
10use crate::{
11    air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder},
12    config::{Com, RapPartialProvingKey, StarkGenericConfig, Val},
13    interaction::{RapPhaseSeq, RapPhaseSeqKind},
14    keygen::types::{
15        LinearConstraint, MultiStarkProvingKey, ProverOnlySinglePreprocessedData, StarkProvingKey,
16        StarkVerifyingKey, TraceWidth, VerifierSinglePreprocessedData,
17    },
18    rap::AnyRap,
19};
20
21pub mod types;
22pub(crate) mod view;
23
24struct AirKeygenBuilder<SC: StarkGenericConfig> {
25    air: Arc<dyn AnyRap<SC>>,
26    rap_phase_seq_kind: RapPhaseSeqKind,
27    prep_keygen_data: PrepKeygenData<SC>,
28}
29
30/// Stateful builder to create multi-stark proving and verifying keys
31/// for system of multiple RAPs with multiple multi-matrix commitments
32pub struct MultiStarkKeygenBuilder<'a, SC: StarkGenericConfig> {
33    pub config: &'a SC,
34    /// Information for partitioned AIRs.
35    partitioned_airs: Vec<AirKeygenBuilder<SC>>,
36    max_constraint_degree: usize,
37}
38
39impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> {
40    pub fn new(config: &'a SC) -> Self {
41        Self {
42            config,
43            partitioned_airs: vec![],
44            max_constraint_degree: 0,
45        }
46    }
47
48    /// The builder will **try** to keep the max constraint degree across all AIRs below this value.
49    /// If it is given AIRs that exceed this value, it will still include them.
50    ///
51    /// Currently this is only used for interaction chunking in FRI logup.
52    pub fn set_max_constraint_degree(&mut self, max_constraint_degree: usize) {
53        self.max_constraint_degree = max_constraint_degree;
54    }
55
56    /// Default way to add a single Interactive AIR.
57    /// Returns `air_id`
58    #[instrument(level = "debug", skip_all)]
59    pub fn add_air(&mut self, air: Arc<dyn AnyRap<SC>>) -> usize {
60        self.partitioned_airs.push(AirKeygenBuilder::new(
61            self.config.pcs(),
62            SC::RapPhaseSeq::ID,
63            air,
64        ));
65        self.partitioned_airs.len() - 1
66    }
67
68    /// Consume the builder and generate proving key.
69    /// The verifying key can be obtained from the proving key.
70    pub fn generate_pk(mut self) -> MultiStarkProvingKey<SC> {
71        let air_max_constraint_degree = self
72            .partitioned_airs
73            .iter()
74            .map(|keygen_builder| {
75                let max_constraint_degree = keygen_builder.max_constraint_degree();
76                tracing::debug!(
77                    "{} has constraint degree {}",
78                    keygen_builder.air.name(),
79                    max_constraint_degree
80                );
81                max_constraint_degree
82            })
83            .max()
84            .unwrap();
85        tracing::info!(
86            "Max constraint (excluding logup constraints) degree across all AIRs: {}",
87            air_max_constraint_degree
88        );
89        if self.max_constraint_degree != 0 && air_max_constraint_degree > self.max_constraint_degree
90        {
91            // This means the quotient polynomial is already going to be higher degree, so we
92            // might as well use it.
93            tracing::info!(
94                "Setting max_constraint_degree from {} to {air_max_constraint_degree}",
95                self.max_constraint_degree
96            );
97            self.max_constraint_degree = air_max_constraint_degree;
98        }
99        // First pass: get symbolic constraints and interactions but RAP phase constraints are not final
100        let symbolic_constraints_per_air = self
101            .partitioned_airs
102            .iter()
103            .map(|keygen_builder| keygen_builder.get_symbolic_builder(None).constraints())
104            .collect_vec();
105        // Note: due to the need to go through a trait, there is some duplicate computation
106        // (e.g., FRI logup will calculate the interaction chunking both here and in the second pass below)
107        let rap_partial_pk_per_air = self
108            .config
109            .rap_phase_seq()
110            .generate_pk_per_air(&symbolic_constraints_per_air, self.max_constraint_degree);
111        let pk_per_air: Vec<_> = zip(self.partitioned_airs, rap_partial_pk_per_air)
112            .map(|(keygen_builder, rap_partial_pk)| {
113                // Second pass: get final constraints, where RAP phase constraints may have changed
114                keygen_builder.generate_pk(rap_partial_pk, self.max_constraint_degree)
115            })
116            .collect();
117
118        for pk in pk_per_air.iter() {
119            let width = &pk.vk.params.width;
120            tracing::info!("{:<20} | Quotient Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | Perm Cols = {:<4} | {:4} Constraints | {:3} Interactions On Buses {:?}",
121                pk.air_name,
122                pk.vk.quotient_degree,
123                width.preprocessed.unwrap_or(0),
124                format!("{:?}",width.main_widths()),
125                format!("{:?}",width.after_challenge.iter().map(|&x| x * <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D).collect_vec()),
126                pk.vk.symbolic_constraints.constraints.constraint_idx.len(),
127                pk.vk.symbolic_constraints.interactions.len(),
128                pk.vk
129                    .symbolic_constraints
130                    .interactions
131                    .iter()
132                    .map(|i| i.bus_index)
133                    .collect_vec()
134            );
135            #[cfg(feature = "bench-metrics")]
136            {
137                let labels = [("air_name", pk.air_name.clone())];
138                metrics::counter!("quotient_deg", &labels).absolute(pk.vk.quotient_degree as u64);
139                // column info will be logged by prover later
140                metrics::counter!("constraints", &labels)
141                    .absolute(pk.vk.symbolic_constraints.constraints.constraint_idx.len() as u64);
142                metrics::counter!("interactions", &labels)
143                    .absolute(pk.vk.symbolic_constraints.interactions.len() as u64);
144            }
145        }
146
147        let num_airs = symbolic_constraints_per_air.len();
148        let base_order = Val::<SC>::order().to_u32_digits()[0];
149        let mut count_weight_per_air_per_bus_index = HashMap::new();
150
151        // We compute the a_i's for the constraints of the form a_0 n_0 + ... + a_{k-1} n_{k-1} < a_k,
152        // First the constraints that the total number of interactions on each bus is at most the base field order.
153        for (i, constraints_per_air) in symbolic_constraints_per_air.iter().enumerate() {
154            for interaction in &constraints_per_air.interactions {
155                // Also make sure that this of interaction is valid given the security params.
156                // +1 because of the bus
157                let max_msg_len = self
158                    .config
159                    .rap_phase_seq()
160                    .log_up_security_params()
161                    .max_message_length();
162                // plus one because of the bus
163                let total_message_length = interaction.message.len() + 1;
164                assert!(
165                    total_message_length <= max_msg_len,
166                    "interaction message with bus has length {}, which is more than max {max_msg_len}",
167                    total_message_length,
168                );
169
170                let b = interaction.bus_index;
171                let constraint = count_weight_per_air_per_bus_index
172                    .entry(b)
173                    .or_insert_with(|| LinearConstraint {
174                        coefficients: vec![0; num_airs],
175                        threshold: base_order,
176                    });
177                constraint.coefficients[i] += interaction.count_weight;
178            }
179        }
180
181        // Sorting by bus index is not necessary, but makes debugging/testing easier.
182        let mut trace_height_constraints = count_weight_per_air_per_bus_index
183            .into_iter()
184            .sorted_by_key(|(bus_index, _)| *bus_index)
185            .map(|(_, constraint)| constraint)
186            .collect_vec();
187
188        let log_up_security_params = self.config.rap_phase_seq().log_up_security_params();
189
190        // Add a constraint for the total number of interactions.
191        trace_height_constraints.push(LinearConstraint {
192            coefficients: symbolic_constraints_per_air
193                .iter()
194                .map(|c| c.interactions.len() as u32)
195                .collect(),
196            threshold: log_up_security_params.max_interaction_count,
197        });
198
199        let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 {
200            per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
201            trace_height_constraints: trace_height_constraints.clone(),
202            log_up_pow_bits: log_up_security_params.log_up_pow_bits,
203        };
204        // To protect against weak Fiat-Shamir, we hash the "pre"-verifying key and include it in the
205        // final verifying key. This just needs to commit to the verifying key and does not need to be
206        // verified by the verifier, so we just use bincode to serialize it.
207        let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
208        tracing::info!("pre-vkey: {} bytes", vk_bytes.len());
209        // Purely to get type compatibility and convenience, we hash using pcs.commit as a single row
210        let vk_as_row = RowMajorMatrix::new_row(
211            vk_bytes
212                .into_iter()
213                .map(Val::<SC>::from_canonical_u8)
214                .collect(),
215        );
216        let pcs = self.config.pcs();
217        let deg_1_domain = pcs.natural_domain_for_degree(1);
218        let (vk_pre_hash, _) = pcs.commit(vec![(deg_1_domain, vk_as_row)]);
219
220        MultiStarkProvingKey {
221            per_air: pk_per_air,
222            trace_height_constraints,
223            max_constraint_degree: self.max_constraint_degree,
224            log_up_pow_bits: log_up_security_params.log_up_pow_bits,
225            vk_pre_hash,
226        }
227    }
228}
229
230impl<SC: StarkGenericConfig> AirKeygenBuilder<SC> {
231    fn new(pcs: &SC::Pcs, rap_phase_seq_kind: RapPhaseSeqKind, air: Arc<dyn AnyRap<SC>>) -> Self {
232        let prep_keygen_data = compute_prep_data_for_air(pcs, air.as_ref());
233        AirKeygenBuilder {
234            air,
235            rap_phase_seq_kind,
236            prep_keygen_data,
237        }
238    }
239
240    fn max_constraint_degree(&self) -> usize {
241        self.get_symbolic_builder(None)
242            .constraints()
243            .max_constraint_degree()
244    }
245
246    fn generate_pk(
247        self,
248        rap_partial_pk: RapPartialProvingKey<SC>,
249        max_constraint_degree: usize,
250    ) -> StarkProvingKey<SC> {
251        let air_name = self.air.name();
252
253        let symbolic_builder = self.get_symbolic_builder(Some(max_constraint_degree));
254        let params = symbolic_builder.params();
255        let symbolic_constraints = symbolic_builder.constraints();
256        let log_quotient_degree = symbolic_constraints.get_log_quotient_degree();
257        let quotient_degree = 1 << log_quotient_degree;
258
259        let Self {
260            prep_keygen_data:
261                PrepKeygenData {
262                    verifier_data: prep_verifier_data,
263                    prover_data: prep_prover_data,
264                },
265            ..
266        } = self;
267
268        let vk: StarkVerifyingKey<Val<SC>, Com<SC>> = StarkVerifyingKey {
269            preprocessed_data: prep_verifier_data,
270            params,
271            symbolic_constraints: symbolic_constraints.into(),
272            quotient_degree,
273            rap_phase_seq_kind: self.rap_phase_seq_kind,
274        };
275        StarkProvingKey {
276            air_name,
277            vk,
278            preprocessed_data: prep_prover_data,
279            rap_partial_pk,
280        }
281    }
282
283    fn get_symbolic_builder(
284        &self,
285        max_constraint_degree: Option<usize>,
286    ) -> SymbolicRapBuilder<Val<SC>> {
287        let width = TraceWidth {
288            preprocessed: self.prep_keygen_data.width(),
289            cached_mains: self.air.cached_main_widths(),
290            common_main: self.air.common_main_width(),
291            after_challenge: vec![],
292        };
293        get_symbolic_builder(
294            self.air.as_ref(),
295            &width,
296            &[],
297            &[],
298            SC::RapPhaseSeq::ID,
299            max_constraint_degree.unwrap_or(0),
300        )
301    }
302}
303
304pub(super) struct PrepKeygenData<SC: StarkGenericConfig> {
305    pub verifier_data: Option<VerifierSinglePreprocessedData<Com<SC>>>,
306    pub prover_data: Option<ProverOnlySinglePreprocessedData<SC>>,
307}
308
309impl<SC: StarkGenericConfig> PrepKeygenData<SC> {
310    pub fn width(&self) -> Option<usize> {
311        self.prover_data.as_ref().map(|d| d.trace.width())
312    }
313}
314
315fn compute_prep_data_for_air<SC: StarkGenericConfig>(
316    pcs: &SC::Pcs,
317    air: &dyn AnyRap<SC>,
318) -> PrepKeygenData<SC> {
319    let preprocessed_trace = air.preprocessed_trace();
320    let vpdata_opt = preprocessed_trace.map(|trace| {
321        let domain = pcs.natural_domain_for_degree(trace.height());
322        let (commit, data) = pcs.commit(vec![(domain, trace.clone())]);
323        let vdata = VerifierSinglePreprocessedData { commit };
324        let pdata = ProverOnlySinglePreprocessedData {
325            trace: Arc::new(trace),
326            data: Arc::new(data),
327        };
328        (vdata, pdata)
329    });
330    if let Some((vdata, pdata)) = vpdata_opt {
331        PrepKeygenData {
332            prover_data: Some(pdata),
333            verifier_data: Some(vdata),
334        }
335    } else {
336        PrepKeygenData {
337            prover_data: None,
338            verifier_data: None,
339        }
340    }
341}