openvm_stark_backend/keygen/
mod.rs1use std::{collections::HashMap, iter::zip, sync::Arc};
2
3use itertools::Itertools;
4use p3_commit::Pcs;
5use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing};
6use p3_matrix::{dense::RowMajorMatrix, Matrix};
7use tracing::instrument;
8use types::MultiStarkVerifyingKey0;
9
10use crate::{
11 air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder},
12 config::{Com, RapPartialProvingKey, StarkGenericConfig, Val},
13 interaction::{RapPhaseSeq, RapPhaseSeqKind},
14 keygen::types::{
15 LinearConstraint, MultiStarkProvingKey, ProverOnlySinglePreprocessedData, StarkProvingKey,
16 StarkVerifyingKey, TraceWidth, VerifierSinglePreprocessedData,
17 },
18 rap::AnyRap,
19};
20
21pub mod types;
22pub mod view;
23
24struct AirKeygenBuilder<SC: StarkGenericConfig> {
25 air: Arc<dyn AnyRap<SC>>,
26 rap_phase_seq_kind: RapPhaseSeqKind,
27 prep_keygen_data: PrepKeygenData<SC>,
28}
29
30pub struct MultiStarkKeygenBuilder<'a, SC: StarkGenericConfig> {
33 pub config: &'a SC,
34 pub max_batch_size: Option<usize>,
35 pub max_num_constraints: Option<usize>,
36 partitioned_airs: Vec<AirKeygenBuilder<SC>>,
38 max_constraint_degree: usize,
39}
40
41impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> {
42 pub fn new(config: &'a SC) -> Self {
43 Self {
44 config,
45 max_batch_size: None,
46 max_num_constraints: None,
47 partitioned_airs: vec![],
48 max_constraint_degree: 0,
49 }
50 }
51
52 pub fn set_max_constraint_degree(&mut self, max_constraint_degree: usize) {
57 self.max_constraint_degree = max_constraint_degree;
58 }
59
60 #[instrument(level = "debug", skip_all)]
63 pub fn add_air(&mut self, air: Arc<dyn AnyRap<SC>>) -> usize {
64 self.partitioned_airs.push(AirKeygenBuilder::new(
65 self.config.pcs(),
66 SC::RapPhaseSeq::ID,
67 air,
68 ));
69 self.partitioned_airs.len() - 1
70 }
71
72 pub fn generate_pk(mut self) -> MultiStarkProvingKey<SC> {
75 let air_max_constraint_degree = self
76 .partitioned_airs
77 .iter()
78 .map(|keygen_builder| {
79 let max_constraint_degree = keygen_builder.max_constraint_degree();
80 tracing::debug!(
81 "{} has constraint degree {}",
82 keygen_builder.air.name(),
83 max_constraint_degree
84 );
85 max_constraint_degree
86 })
87 .max()
88 .unwrap();
89 tracing::info!(
90 "Max constraint (excluding logup constraints) degree across all AIRs: {}",
91 air_max_constraint_degree
92 );
93 if self.max_constraint_degree != 0 && air_max_constraint_degree > self.max_constraint_degree
94 {
95 tracing::info!(
98 "Setting max_constraint_degree from {} to {air_max_constraint_degree}",
99 self.max_constraint_degree
100 );
101 self.max_constraint_degree = air_max_constraint_degree;
102 }
103 let symbolic_constraints_per_air = self
106 .partitioned_airs
107 .iter()
108 .map(|keygen_builder| keygen_builder.get_symbolic_builder(None).constraints())
109 .collect_vec();
110 let rap_partial_pk_per_air = self
114 .config
115 .rap_phase_seq()
116 .generate_pk_per_air(&symbolic_constraints_per_air, self.max_constraint_degree);
117 let pk_per_air: Vec<_> = zip(self.partitioned_airs, rap_partial_pk_per_air)
118 .map(|(keygen_builder, rap_partial_pk)| {
119 keygen_builder.generate_pk(rap_partial_pk, self.max_constraint_degree)
121 })
122 .collect();
123
124 for pk in pk_per_air.iter() {
125 let width = &pk.vk.params.width;
126 tracing::info!("{:<20} | Quotient Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | Perm Cols = {:<4} | {:4} Constraints | {:3} Interactions",
127 pk.air_name,
128 pk.vk.quotient_degree,
129 width.preprocessed.unwrap_or(0),
130 format!("{:?}",width.main_widths()),
131 format!("{:?}",width.after_challenge.iter().map(|&x| x * <SC::Challenge as BasedVectorSpace<Val<SC>>>::DIMENSION).collect_vec()),
132 pk.vk.symbolic_constraints.constraints.constraint_idx.len(),
133 pk.vk.symbolic_constraints.interactions.len(),
134 );
135 tracing::debug!(
136 "On Buses {:?}",
137 pk.vk
138 .symbolic_constraints
139 .interactions
140 .iter()
141 .map(|i| i.bus_index)
142 .collect_vec()
143 );
144 #[cfg(feature = "metrics")]
145 {
146 let labels = [("air_name", pk.air_name.clone())];
147 metrics::counter!("quotient_deg", &labels).absolute(pk.vk.quotient_degree as u64);
148 metrics::counter!("constraints", &labels)
150 .absolute(pk.vk.symbolic_constraints.constraints.constraint_idx.len() as u64);
151 metrics::counter!("interactions", &labels)
152 .absolute(pk.vk.symbolic_constraints.interactions.len() as u64);
153 }
154 }
155
156 let num_airs = symbolic_constraints_per_air.len();
157 let base_order = Val::<SC>::order().to_u32_digits()[0];
158 let mut count_weight_per_air_per_bus_index = HashMap::new();
159
160 for (i, constraints_per_air) in symbolic_constraints_per_air.iter().enumerate() {
164 for interaction in &constraints_per_air.interactions {
165 let max_msg_len = self
168 .config
169 .rap_phase_seq()
170 .log_up_security_params()
171 .max_message_length();
172 let total_message_length = interaction.message.len() + 1;
174 assert!(
175 total_message_length <= max_msg_len,
176 "interaction message with bus has length {}, which is more than max {max_msg_len}",
177 total_message_length,
178 );
179
180 let b = interaction.bus_index;
181 let constraint = count_weight_per_air_per_bus_index
182 .entry(b)
183 .or_insert_with(|| LinearConstraint {
184 coefficients: vec![0; num_airs],
185 threshold: base_order,
186 });
187 constraint.coefficients[i] += interaction.count_weight;
188 }
189 }
190
191 let mut trace_height_constraints = count_weight_per_air_per_bus_index
193 .into_iter()
194 .sorted_by_key(|(bus_index, _)| *bus_index)
195 .map(|(_, constraint)| constraint)
196 .collect_vec();
197
198 let log_up_security_params = self.config.rap_phase_seq().log_up_security_params();
199
200 trace_height_constraints.push(LinearConstraint {
202 coefficients: symbolic_constraints_per_air
203 .iter()
204 .map(|c| c.interactions.len() as u32)
205 .collect(),
206 threshold: log_up_security_params.max_interaction_count,
207 });
208
209 let deep_pow_bits = self.config.deep_ali_params().deep_pow_bits;
210 let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 {
211 per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
212 trace_height_constraints: trace_height_constraints.clone(),
213 log_up_pow_bits: log_up_security_params.log_up_pow_bits,
214 deep_pow_bits,
215 };
216 if let Some(max_batch_size) = self.max_batch_size {
217 let ext_degree = <SC::Challenge as BasedVectorSpace<Val<SC>>>::DIMENSION;
218 let mut total_width = pre_vk
219 .per_air
220 .iter()
221 .map(|vk| vk.params.width.total_width(ext_degree))
222 .sum::<usize>();
223 let quotient_deg = pre_vk
224 .per_air
225 .iter()
226 .map(|vk| vk.quotient_degree as usize)
227 .max()
228 .unwrap_or_default();
229 total_width += quotient_deg * ext_degree;
231 tracing::info!(%total_width);
232 assert!(
234 total_width * 2 <= max_batch_size,
235 "Maximum number of AIR columns exceeded for desired security level"
236 );
237 }
238 if let Some(max_num_constraints) = self.max_num_constraints {
239 let total_constraint_count = pre_vk
240 .per_air
241 .iter()
242 .map(|vk| vk.symbolic_constraints.constraints.num_constraints())
243 .sum::<usize>();
244 tracing::info!(%total_constraint_count);
245 assert!(
246 total_constraint_count <= max_num_constraints,
247 "Maximum number of constraints exceeded for desired security level"
248 );
249 }
250
251 let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
255 tracing::info!("pre-vkey: {} bytes", vk_bytes.len());
256 let vk_as_row =
259 RowMajorMatrix::new_row(vk_bytes.into_iter().map(Val::<SC>::from_u8).collect());
260 let pcs = self.config.pcs();
261 let deg_1_domain = pcs.natural_domain_for_degree(1);
262 let (vk_pre_hash, _) = pcs.commit(vec![(deg_1_domain, vk_as_row)]);
263
264 MultiStarkProvingKey {
265 per_air: pk_per_air,
266 trace_height_constraints,
267 max_constraint_degree: self.max_constraint_degree,
268 log_up_pow_bits: log_up_security_params.log_up_pow_bits,
269 deep_pow_bits,
270 vk_pre_hash,
271 }
272 }
273}
274
275impl<SC: StarkGenericConfig> AirKeygenBuilder<SC> {
276 fn new(pcs: &SC::Pcs, rap_phase_seq_kind: RapPhaseSeqKind, air: Arc<dyn AnyRap<SC>>) -> Self {
277 let prep_keygen_data = compute_prep_data_for_air(pcs, air.as_ref());
278 AirKeygenBuilder {
279 air,
280 rap_phase_seq_kind,
281 prep_keygen_data,
282 }
283 }
284
285 fn max_constraint_degree(&self) -> usize {
286 self.get_symbolic_builder(None)
287 .constraints()
288 .max_constraint_degree()
289 }
290
291 fn generate_pk(
292 self,
293 rap_partial_pk: RapPartialProvingKey<SC>,
294 max_constraint_degree: usize,
295 ) -> StarkProvingKey<SC> {
296 let air_name = self.air.name();
297
298 let symbolic_builder = self.get_symbolic_builder(Some(max_constraint_degree));
299 let params = symbolic_builder.params();
300 let symbolic_constraints = symbolic_builder.constraints();
301 let log_quotient_degree = symbolic_constraints.get_log_quotient_degree();
302 let quotient_degree = 1 << log_quotient_degree;
303
304 let Self {
305 prep_keygen_data:
306 PrepKeygenData {
307 verifier_data: prep_verifier_data,
308 prover_data: prep_prover_data,
309 },
310 ..
311 } = self;
312
313 let vk: StarkVerifyingKey<Val<SC>, Com<SC>> = StarkVerifyingKey {
314 preprocessed_data: prep_verifier_data,
315 params,
316 symbolic_constraints: symbolic_constraints.into(),
317 quotient_degree,
318 rap_phase_seq_kind: self.rap_phase_seq_kind,
319 };
320 StarkProvingKey {
321 air_name,
322 vk,
323 preprocessed_data: prep_prover_data,
324 rap_partial_pk,
325 }
326 }
327
328 fn get_symbolic_builder(
329 &self,
330 max_constraint_degree: Option<usize>,
331 ) -> SymbolicRapBuilder<Val<SC>> {
332 let width = TraceWidth {
333 preprocessed: self.prep_keygen_data.width(),
334 cached_mains: self.air.cached_main_widths(),
335 common_main: self.air.common_main_width(),
336 after_challenge: vec![],
337 };
338 get_symbolic_builder(
339 self.air.as_ref(),
340 &width,
341 &[],
342 &[],
343 SC::RapPhaseSeq::ID,
344 max_constraint_degree.unwrap_or(0),
345 )
346 }
347}
348
349pub(super) struct PrepKeygenData<SC: StarkGenericConfig> {
350 pub verifier_data: Option<VerifierSinglePreprocessedData<Com<SC>>>,
351 pub prover_data: Option<ProverOnlySinglePreprocessedData<SC>>,
352}
353
354impl<SC: StarkGenericConfig> PrepKeygenData<SC> {
355 pub fn width(&self) -> Option<usize> {
356 self.prover_data.as_ref().map(|d| d.trace.width())
357 }
358}
359
360fn compute_prep_data_for_air<SC: StarkGenericConfig>(
361 pcs: &SC::Pcs,
362 air: &dyn AnyRap<SC>,
363) -> PrepKeygenData<SC> {
364 let preprocessed_trace = air.preprocessed_trace();
365 let vpdata_opt = preprocessed_trace.map(|trace| {
366 let domain = pcs.natural_domain_for_degree(trace.height());
367 let (commit, data) = pcs.commit(vec![(domain, trace.clone())]);
368 let vdata = VerifierSinglePreprocessedData { commit };
369 let pdata = ProverOnlySinglePreprocessedData {
370 trace: Arc::new(trace),
371 data: Arc::new(data),
372 };
373 (vdata, pdata)
374 });
375 if let Some((vdata, pdata)) = vpdata_opt {
376 PrepKeygenData {
377 prover_data: Some(pdata),
378 verifier_data: Some(vdata),
379 }
380 } else {
381 PrepKeygenData {
382 prover_data: None,
383 verifier_data: None,
384 }
385 }
386}