// openvm_stark_backend/keygen/mod.rs
1use std::{collections::HashMap, iter::zip, sync::Arc};
2
3use itertools::Itertools;
4use p3_commit::Pcs;
5use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
6use p3_matrix::{dense::RowMajorMatrix, Matrix};
7use tracing::instrument;
8use types::MultiStarkVerifyingKey0;
9
10use crate::{
11 air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder},
12 config::{Com, RapPartialProvingKey, StarkGenericConfig, Val},
13 interaction::{RapPhaseSeq, RapPhaseSeqKind},
14 keygen::types::{
15 LinearConstraint, MultiStarkProvingKey, ProverOnlySinglePreprocessedData, StarkProvingKey,
16 StarkVerifyingKey, TraceWidth, VerifierSinglePreprocessedData,
17 },
18 rap::AnyRap,
19};
20
21pub mod types;
22pub(crate) mod view;
23
24struct AirKeygenBuilder<SC: StarkGenericConfig> {
25 air: Arc<dyn AnyRap<SC>>,
26 rap_phase_seq_kind: RapPhaseSeqKind,
27 prep_keygen_data: PrepKeygenData<SC>,
28}
29
30pub struct MultiStarkKeygenBuilder<'a, SC: StarkGenericConfig> {
33 pub config: &'a SC,
34 partitioned_airs: Vec<AirKeygenBuilder<SC>>,
36 max_constraint_degree: usize,
37}
38
39impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> {
40 pub fn new(config: &'a SC) -> Self {
41 Self {
42 config,
43 partitioned_airs: vec![],
44 max_constraint_degree: 0,
45 }
46 }
47
48 pub fn set_max_constraint_degree(&mut self, max_constraint_degree: usize) {
53 self.max_constraint_degree = max_constraint_degree;
54 }
55
56 #[instrument(level = "debug", skip_all)]
59 pub fn add_air(&mut self, air: Arc<dyn AnyRap<SC>>) -> usize {
60 self.partitioned_airs.push(AirKeygenBuilder::new(
61 self.config.pcs(),
62 SC::RapPhaseSeq::ID,
63 air,
64 ));
65 self.partitioned_airs.len() - 1
66 }
67
68 pub fn generate_pk(mut self) -> MultiStarkProvingKey<SC> {
71 let air_max_constraint_degree = self
72 .partitioned_airs
73 .iter()
74 .map(|keygen_builder| {
75 let max_constraint_degree = keygen_builder.max_constraint_degree();
76 tracing::debug!(
77 "{} has constraint degree {}",
78 keygen_builder.air.name(),
79 max_constraint_degree
80 );
81 max_constraint_degree
82 })
83 .max()
84 .unwrap();
85 tracing::info!(
86 "Max constraint (excluding logup constraints) degree across all AIRs: {}",
87 air_max_constraint_degree
88 );
89 if self.max_constraint_degree != 0 && air_max_constraint_degree > self.max_constraint_degree
90 {
91 tracing::info!(
94 "Setting max_constraint_degree from {} to {air_max_constraint_degree}",
95 self.max_constraint_degree
96 );
97 self.max_constraint_degree = air_max_constraint_degree;
98 }
99 let symbolic_constraints_per_air = self
101 .partitioned_airs
102 .iter()
103 .map(|keygen_builder| keygen_builder.get_symbolic_builder(None).constraints())
104 .collect_vec();
105 let rap_partial_pk_per_air = self
108 .config
109 .rap_phase_seq()
110 .generate_pk_per_air(&symbolic_constraints_per_air, self.max_constraint_degree);
111 let pk_per_air: Vec<_> = zip(self.partitioned_airs, rap_partial_pk_per_air)
112 .map(|(keygen_builder, rap_partial_pk)| {
113 keygen_builder.generate_pk(rap_partial_pk, self.max_constraint_degree)
115 })
116 .collect();
117
118 for pk in pk_per_air.iter() {
119 let width = &pk.vk.params.width;
120 tracing::info!("{:<20} | Quotient Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | Perm Cols = {:<4} | {:4} Constraints | {:3} Interactions On Buses {:?}",
121 pk.air_name,
122 pk.vk.quotient_degree,
123 width.preprocessed.unwrap_or(0),
124 format!("{:?}",width.main_widths()),
125 format!("{:?}",width.after_challenge.iter().map(|&x| x * <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D).collect_vec()),
126 pk.vk.symbolic_constraints.constraints.constraint_idx.len(),
127 pk.vk.symbolic_constraints.interactions.len(),
128 pk.vk
129 .symbolic_constraints
130 .interactions
131 .iter()
132 .map(|i| i.bus_index)
133 .collect_vec()
134 );
135 #[cfg(feature = "bench-metrics")]
136 {
137 let labels = [("air_name", pk.air_name.clone())];
138 metrics::counter!("quotient_deg", &labels).absolute(pk.vk.quotient_degree as u64);
139 metrics::counter!("constraints", &labels)
141 .absolute(pk.vk.symbolic_constraints.constraints.constraint_idx.len() as u64);
142 metrics::counter!("interactions", &labels)
143 .absolute(pk.vk.symbolic_constraints.interactions.len() as u64);
144 }
145 }
146
147 let num_airs = symbolic_constraints_per_air.len();
148 let base_order = Val::<SC>::order().to_u32_digits()[0];
149 let mut count_weight_per_air_per_bus_index = HashMap::new();
150
151 for (i, constraints_per_air) in symbolic_constraints_per_air.iter().enumerate() {
154 for interaction in &constraints_per_air.interactions {
155 let max_msg_len = self
158 .config
159 .rap_phase_seq()
160 .log_up_security_params()
161 .max_message_length();
162 let total_message_length = interaction.message.len() + 1;
164 assert!(
165 total_message_length <= max_msg_len,
166 "interaction message with bus has length {}, which is more than max {max_msg_len}",
167 total_message_length,
168 );
169
170 let b = interaction.bus_index;
171 let constraint = count_weight_per_air_per_bus_index
172 .entry(b)
173 .or_insert_with(|| LinearConstraint {
174 coefficients: vec![0; num_airs],
175 threshold: base_order,
176 });
177 constraint.coefficients[i] += interaction.count_weight;
178 }
179 }
180
181 let mut trace_height_constraints = count_weight_per_air_per_bus_index
183 .into_iter()
184 .sorted_by_key(|(bus_index, _)| *bus_index)
185 .map(|(_, constraint)| constraint)
186 .collect_vec();
187
188 let log_up_security_params = self.config.rap_phase_seq().log_up_security_params();
189
190 trace_height_constraints.push(LinearConstraint {
192 coefficients: symbolic_constraints_per_air
193 .iter()
194 .map(|c| c.interactions.len() as u32)
195 .collect(),
196 threshold: log_up_security_params.max_interaction_count,
197 });
198
199 let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 {
200 per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
201 trace_height_constraints: trace_height_constraints.clone(),
202 log_up_pow_bits: log_up_security_params.log_up_pow_bits,
203 };
204 let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
208 tracing::info!("pre-vkey: {} bytes", vk_bytes.len());
209 let vk_as_row = RowMajorMatrix::new_row(
211 vk_bytes
212 .into_iter()
213 .map(Val::<SC>::from_canonical_u8)
214 .collect(),
215 );
216 let pcs = self.config.pcs();
217 let deg_1_domain = pcs.natural_domain_for_degree(1);
218 let (vk_pre_hash, _) = pcs.commit(vec![(deg_1_domain, vk_as_row)]);
219
220 MultiStarkProvingKey {
221 per_air: pk_per_air,
222 trace_height_constraints,
223 max_constraint_degree: self.max_constraint_degree,
224 log_up_pow_bits: log_up_security_params.log_up_pow_bits,
225 vk_pre_hash,
226 }
227 }
228}
229
230impl<SC: StarkGenericConfig> AirKeygenBuilder<SC> {
231 fn new(pcs: &SC::Pcs, rap_phase_seq_kind: RapPhaseSeqKind, air: Arc<dyn AnyRap<SC>>) -> Self {
232 let prep_keygen_data = compute_prep_data_for_air(pcs, air.as_ref());
233 AirKeygenBuilder {
234 air,
235 rap_phase_seq_kind,
236 prep_keygen_data,
237 }
238 }
239
240 fn max_constraint_degree(&self) -> usize {
241 self.get_symbolic_builder(None)
242 .constraints()
243 .max_constraint_degree()
244 }
245
246 fn generate_pk(
247 self,
248 rap_partial_pk: RapPartialProvingKey<SC>,
249 max_constraint_degree: usize,
250 ) -> StarkProvingKey<SC> {
251 let air_name = self.air.name();
252
253 let symbolic_builder = self.get_symbolic_builder(Some(max_constraint_degree));
254 let params = symbolic_builder.params();
255 let symbolic_constraints = symbolic_builder.constraints();
256 let log_quotient_degree = symbolic_constraints.get_log_quotient_degree();
257 let quotient_degree = 1 << log_quotient_degree;
258
259 let Self {
260 prep_keygen_data:
261 PrepKeygenData {
262 verifier_data: prep_verifier_data,
263 prover_data: prep_prover_data,
264 },
265 ..
266 } = self;
267
268 let vk: StarkVerifyingKey<Val<SC>, Com<SC>> = StarkVerifyingKey {
269 preprocessed_data: prep_verifier_data,
270 params,
271 symbolic_constraints: symbolic_constraints.into(),
272 quotient_degree,
273 rap_phase_seq_kind: self.rap_phase_seq_kind,
274 };
275 StarkProvingKey {
276 air_name,
277 vk,
278 preprocessed_data: prep_prover_data,
279 rap_partial_pk,
280 }
281 }
282
283 fn get_symbolic_builder(
284 &self,
285 max_constraint_degree: Option<usize>,
286 ) -> SymbolicRapBuilder<Val<SC>> {
287 let width = TraceWidth {
288 preprocessed: self.prep_keygen_data.width(),
289 cached_mains: self.air.cached_main_widths(),
290 common_main: self.air.common_main_width(),
291 after_challenge: vec![],
292 };
293 get_symbolic_builder(
294 self.air.as_ref(),
295 &width,
296 &[],
297 &[],
298 SC::RapPhaseSeq::ID,
299 max_constraint_degree.unwrap_or(0),
300 )
301 }
302}
303
304pub(super) struct PrepKeygenData<SC: StarkGenericConfig> {
305 pub verifier_data: Option<VerifierSinglePreprocessedData<Com<SC>>>,
306 pub prover_data: Option<ProverOnlySinglePreprocessedData<SC>>,
307}
308
309impl<SC: StarkGenericConfig> PrepKeygenData<SC> {
310 pub fn width(&self) -> Option<usize> {
311 self.prover_data.as_ref().map(|d| d.trace.width())
312 }
313}
314
315fn compute_prep_data_for_air<SC: StarkGenericConfig>(
316 pcs: &SC::Pcs,
317 air: &dyn AnyRap<SC>,
318) -> PrepKeygenData<SC> {
319 let preprocessed_trace = air.preprocessed_trace();
320 let vpdata_opt = preprocessed_trace.map(|trace| {
321 let domain = pcs.natural_domain_for_degree(trace.height());
322 let (commit, data) = pcs.commit(vec![(domain, trace.clone())]);
323 let vdata = VerifierSinglePreprocessedData { commit };
324 let pdata = ProverOnlySinglePreprocessedData {
325 trace: Arc::new(trace),
326 data: Arc::new(data),
327 };
328 (vdata, pdata)
329 });
330 if let Some((vdata, pdata)) = vpdata_opt {
331 PrepKeygenData {
332 prover_data: Some(pdata),
333 verifier_data: Some(vdata),
334 }
335 } else {
336 PrepKeygenData {
337 prover_data: None,
338 verifier_data: None,
339 }
340 }
341}