openvm_stark_backend/keygen/
mod.rs1use std::{collections::HashMap, iter::zip, sync::Arc};
2
3use itertools::Itertools;
4use p3_commit::Pcs;
5use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
6use p3_matrix::{dense::RowMajorMatrix, Matrix};
7use tracing::instrument;
8use types::MultiStarkVerifyingKey0;
9
10use crate::{
11 air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder},
12 config::{Com, RapPartialProvingKey, StarkGenericConfig, Val},
13 interaction::{RapPhaseSeq, RapPhaseSeqKind},
14 keygen::types::{
15 LinearConstraint, MultiStarkProvingKey, ProverOnlySinglePreprocessedData, StarkProvingKey,
16 StarkVerifyingKey, TraceWidth, VerifierSinglePreprocessedData,
17 },
18 rap::AnyRap,
19};
20
21pub mod types;
22pub mod view;
23
24struct AirKeygenBuilder<SC: StarkGenericConfig> {
25 air: Arc<dyn AnyRap<SC>>,
26 rap_phase_seq_kind: RapPhaseSeqKind,
27 prep_keygen_data: PrepKeygenData<SC>,
28}
29
30pub struct MultiStarkKeygenBuilder<'a, SC: StarkGenericConfig> {
33 pub config: &'a SC,
34 partitioned_airs: Vec<AirKeygenBuilder<SC>>,
36 max_constraint_degree: usize,
37}
38
39impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> {
40 pub fn new(config: &'a SC) -> Self {
41 Self {
42 config,
43 partitioned_airs: vec![],
44 max_constraint_degree: 0,
45 }
46 }
47
48 pub fn set_max_constraint_degree(&mut self, max_constraint_degree: usize) {
53 self.max_constraint_degree = max_constraint_degree;
54 }
55
56 #[instrument(level = "debug", skip_all)]
59 pub fn add_air(&mut self, air: Arc<dyn AnyRap<SC>>) -> usize {
60 self.partitioned_airs.push(AirKeygenBuilder::new(
61 self.config.pcs(),
62 SC::RapPhaseSeq::ID,
63 air,
64 ));
65 self.partitioned_airs.len() - 1
66 }
67
68 pub fn generate_pk(mut self) -> MultiStarkProvingKey<SC> {
71 let air_max_constraint_degree = self
72 .partitioned_airs
73 .iter()
74 .map(|keygen_builder| {
75 let max_constraint_degree = keygen_builder.max_constraint_degree();
76 tracing::debug!(
77 "{} has constraint degree {}",
78 keygen_builder.air.name(),
79 max_constraint_degree
80 );
81 max_constraint_degree
82 })
83 .max()
84 .unwrap();
85 tracing::info!(
86 "Max constraint (excluding logup constraints) degree across all AIRs: {}",
87 air_max_constraint_degree
88 );
89 if self.max_constraint_degree != 0 && air_max_constraint_degree > self.max_constraint_degree
90 {
91 tracing::info!(
94 "Setting max_constraint_degree from {} to {air_max_constraint_degree}",
95 self.max_constraint_degree
96 );
97 self.max_constraint_degree = air_max_constraint_degree;
98 }
99 let symbolic_constraints_per_air = self
102 .partitioned_airs
103 .iter()
104 .map(|keygen_builder| keygen_builder.get_symbolic_builder(None).constraints())
105 .collect_vec();
106 let rap_partial_pk_per_air = self
110 .config
111 .rap_phase_seq()
112 .generate_pk_per_air(&symbolic_constraints_per_air, self.max_constraint_degree);
113 let pk_per_air: Vec<_> = zip(self.partitioned_airs, rap_partial_pk_per_air)
114 .map(|(keygen_builder, rap_partial_pk)| {
115 keygen_builder.generate_pk(rap_partial_pk, self.max_constraint_degree)
117 })
118 .collect();
119
120 for pk in pk_per_air.iter() {
121 let width = &pk.vk.params.width;
122 tracing::info!("{:<20} | Quotient Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | Perm Cols = {:<4} | {:4} Constraints | {:3} Interactions",
123 pk.air_name,
124 pk.vk.quotient_degree,
125 width.preprocessed.unwrap_or(0),
126 format!("{:?}",width.main_widths()),
127 format!("{:?}",width.after_challenge.iter().map(|&x| x * <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D).collect_vec()),
128 pk.vk.symbolic_constraints.constraints.constraint_idx.len(),
129 pk.vk.symbolic_constraints.interactions.len(),
130 );
131 tracing::debug!(
132 "On Buses {:?}",
133 pk.vk
134 .symbolic_constraints
135 .interactions
136 .iter()
137 .map(|i| i.bus_index)
138 .collect_vec()
139 );
140 #[cfg(feature = "metrics")]
141 {
142 let labels = [("air_name", pk.air_name.clone())];
143 metrics::counter!("quotient_deg", &labels).absolute(pk.vk.quotient_degree as u64);
144 metrics::counter!("constraints", &labels)
146 .absolute(pk.vk.symbolic_constraints.constraints.constraint_idx.len() as u64);
147 metrics::counter!("interactions", &labels)
148 .absolute(pk.vk.symbolic_constraints.interactions.len() as u64);
149 }
150 }
151
152 let num_airs = symbolic_constraints_per_air.len();
153 let base_order = Val::<SC>::order().to_u32_digits()[0];
154 let mut count_weight_per_air_per_bus_index = HashMap::new();
155
156 for (i, constraints_per_air) in symbolic_constraints_per_air.iter().enumerate() {
160 for interaction in &constraints_per_air.interactions {
161 let max_msg_len = self
164 .config
165 .rap_phase_seq()
166 .log_up_security_params()
167 .max_message_length();
168 let total_message_length = interaction.message.len() + 1;
170 assert!(
171 total_message_length <= max_msg_len,
172 "interaction message with bus has length {}, which is more than max {max_msg_len}",
173 total_message_length,
174 );
175
176 let b = interaction.bus_index;
177 let constraint = count_weight_per_air_per_bus_index
178 .entry(b)
179 .or_insert_with(|| LinearConstraint {
180 coefficients: vec![0; num_airs],
181 threshold: base_order,
182 });
183 constraint.coefficients[i] += interaction.count_weight;
184 }
185 }
186
187 let mut trace_height_constraints = count_weight_per_air_per_bus_index
189 .into_iter()
190 .sorted_by_key(|(bus_index, _)| *bus_index)
191 .map(|(_, constraint)| constraint)
192 .collect_vec();
193
194 let log_up_security_params = self.config.rap_phase_seq().log_up_security_params();
195
196 trace_height_constraints.push(LinearConstraint {
198 coefficients: symbolic_constraints_per_air
199 .iter()
200 .map(|c| c.interactions.len() as u32)
201 .collect(),
202 threshold: log_up_security_params.max_interaction_count,
203 });
204
205 let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 {
206 per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
207 trace_height_constraints: trace_height_constraints.clone(),
208 log_up_pow_bits: log_up_security_params.log_up_pow_bits,
209 };
210 let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
214 tracing::info!("pre-vkey: {} bytes", vk_bytes.len());
215 let vk_as_row = RowMajorMatrix::new_row(
218 vk_bytes
219 .into_iter()
220 .map(Val::<SC>::from_canonical_u8)
221 .collect(),
222 );
223 let pcs = self.config.pcs();
224 let deg_1_domain = pcs.natural_domain_for_degree(1);
225 let (vk_pre_hash, _) = pcs.commit(vec![(deg_1_domain, vk_as_row)]);
226
227 MultiStarkProvingKey {
228 per_air: pk_per_air,
229 trace_height_constraints,
230 max_constraint_degree: self.max_constraint_degree,
231 log_up_pow_bits: log_up_security_params.log_up_pow_bits,
232 vk_pre_hash,
233 }
234 }
235}
236
237impl<SC: StarkGenericConfig> AirKeygenBuilder<SC> {
238 fn new(pcs: &SC::Pcs, rap_phase_seq_kind: RapPhaseSeqKind, air: Arc<dyn AnyRap<SC>>) -> Self {
239 let prep_keygen_data = compute_prep_data_for_air(pcs, air.as_ref());
240 AirKeygenBuilder {
241 air,
242 rap_phase_seq_kind,
243 prep_keygen_data,
244 }
245 }
246
247 fn max_constraint_degree(&self) -> usize {
248 self.get_symbolic_builder(None)
249 .constraints()
250 .max_constraint_degree()
251 }
252
253 fn generate_pk(
254 self,
255 rap_partial_pk: RapPartialProvingKey<SC>,
256 max_constraint_degree: usize,
257 ) -> StarkProvingKey<SC> {
258 let air_name = self.air.name();
259
260 let symbolic_builder = self.get_symbolic_builder(Some(max_constraint_degree));
261 let params = symbolic_builder.params();
262 let symbolic_constraints = symbolic_builder.constraints();
263 let log_quotient_degree = symbolic_constraints.get_log_quotient_degree();
264 let quotient_degree = 1 << log_quotient_degree;
265
266 let Self {
267 prep_keygen_data:
268 PrepKeygenData {
269 verifier_data: prep_verifier_data,
270 prover_data: prep_prover_data,
271 },
272 ..
273 } = self;
274
275 let vk: StarkVerifyingKey<Val<SC>, Com<SC>> = StarkVerifyingKey {
276 preprocessed_data: prep_verifier_data,
277 params,
278 symbolic_constraints: symbolic_constraints.into(),
279 quotient_degree,
280 rap_phase_seq_kind: self.rap_phase_seq_kind,
281 };
282 StarkProvingKey {
283 air_name,
284 vk,
285 preprocessed_data: prep_prover_data,
286 rap_partial_pk,
287 }
288 }
289
290 fn get_symbolic_builder(
291 &self,
292 max_constraint_degree: Option<usize>,
293 ) -> SymbolicRapBuilder<Val<SC>> {
294 let width = TraceWidth {
295 preprocessed: self.prep_keygen_data.width(),
296 cached_mains: self.air.cached_main_widths(),
297 common_main: self.air.common_main_width(),
298 after_challenge: vec![],
299 };
300 get_symbolic_builder(
301 self.air.as_ref(),
302 &width,
303 &[],
304 &[],
305 SC::RapPhaseSeq::ID,
306 max_constraint_degree.unwrap_or(0),
307 )
308 }
309}
310
311pub(super) struct PrepKeygenData<SC: StarkGenericConfig> {
312 pub verifier_data: Option<VerifierSinglePreprocessedData<Com<SC>>>,
313 pub prover_data: Option<ProverOnlySinglePreprocessedData<SC>>,
314}
315
316impl<SC: StarkGenericConfig> PrepKeygenData<SC> {
317 pub fn width(&self) -> Option<usize> {
318 self.prover_data.as_ref().map(|d| d.trace.width())
319 }
320}
321
322fn compute_prep_data_for_air<SC: StarkGenericConfig>(
323 pcs: &SC::Pcs,
324 air: &dyn AnyRap<SC>,
325) -> PrepKeygenData<SC> {
326 let preprocessed_trace = air.preprocessed_trace();
327 let vpdata_opt = preprocessed_trace.map(|trace| {
328 let domain = pcs.natural_domain_for_degree(trace.height());
329 let (commit, data) = pcs.commit(vec![(domain, trace.clone())]);
330 let vdata = VerifierSinglePreprocessedData { commit };
331 let pdata = ProverOnlySinglePreprocessedData {
332 trace: Arc::new(trace),
333 data: Arc::new(data),
334 };
335 (vdata, pdata)
336 });
337 if let Some((vdata, pdata)) = vpdata_opt {
338 PrepKeygenData {
339 prover_data: Some(pdata),
340 verifier_data: Some(vdata),
341 }
342 } else {
343 PrepKeygenData {
344 prover_data: None,
345 verifier_data: None,
346 }
347 }
348}