1use std::{array, borrow::Borrow, cmp::max, iter::zip, marker::PhantomData, mem};
2
3use itertools::Itertools;
4use p3_air::ExtensionBuilder;
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_field::{ExtensionField, Field, FieldAlgebra};
7use p3_matrix::{dense::RowMajorMatrix, Matrix};
8use p3_maybe_rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11
12use super::{LogUpSecurityParameters, PairTraceView, SymbolicInteraction};
13use crate::{
14 air_builders::symbolic::{symbolic_expression::SymbolicEvaluator, SymbolicConstraints},
15 interaction::{
16 trace::Evaluator, utils::generate_betas, InteractionBuilder, RapPhaseProverData,
17 RapPhaseSeq, RapPhaseSeqKind, RapPhaseVerifierData,
18 },
19 parizip,
20 rap::PermutationAirBuilderWithExposedValues,
21 utils::metrics_span,
22};
23
24pub struct FriLogUpPhase<F, Challenge, Challenger> {
25 log_up_params: LogUpSecurityParameters,
26 _marker: PhantomData<(F, Challenge, Challenger)>,
27}
28
29impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger> {
30 pub fn new(log_up_params: LogUpSecurityParameters) -> Self {
31 Self {
32 log_up_params,
33 _marker: PhantomData,
34 }
35 }
36}
37
38#[derive(Error, Debug)]
39pub enum FriLogUpError {
40 #[error("non-zero cumulative sum")]
41 NonZeroCumulativeSum,
42 #[error("missing proof")]
43 MissingPartialProof,
44 #[error("invalid proof of work witness")]
45 InvalidPowWitness,
46}
47
48#[derive(Clone, Serialize, Deserialize)]
49pub struct FriLogUpPartialProof<Witness> {
50 pub logup_pow_witness: Witness,
51}
52
53#[derive(Clone, Default, Serialize, Deserialize)]
54pub struct FriLogUpProvingKey {
55 interaction_partitions: Vec<Vec<usize>>,
56}
57
58impl FriLogUpProvingKey {
59 pub fn interaction_partitions(self) -> Vec<Vec<usize>> {
60 self.interaction_partitions
61 }
62 pub fn num_chunks(&self) -> usize {
63 self.interaction_partitions.len()
64 }
65}
66
67impl<F: Field, Challenge, Challenger> RapPhaseSeq<F, Challenge, Challenger>
68 for FriLogUpPhase<F, Challenge, Challenger>
69where
70 F: Field,
71 Challenge: ExtensionField<F>,
72 Challenger: FieldChallenger<F> + GrindingChallenger<Witness = F>,
73{
74 type PartialProof = FriLogUpPartialProof<F>;
75 type PartialProvingKey = FriLogUpProvingKey;
76 type Error = FriLogUpError;
77 const ID: RapPhaseSeqKind = RapPhaseSeqKind::FriLogUp;
78
79 fn log_up_security_params(&self) -> &LogUpSecurityParameters {
80 &self.log_up_params
81 }
82
83 fn generate_pk_per_air(
84 &self,
85 symbolic_constraints_per_air: &[SymbolicConstraints<F>],
86 max_constraint_degree: usize,
87 ) -> Vec<Self::PartialProvingKey> {
88 symbolic_constraints_per_air
89 .iter()
90 .map(|constraints| {
91 find_interaction_chunks(&constraints.interactions, max_constraint_degree)
92 })
93 .collect()
94 }
95
96 fn partially_prove(
97 &self,
98 challenger: &mut Challenger,
99 constraints_per_air: &[&SymbolicConstraints<F>],
100 params_per_air: &[&FriLogUpProvingKey],
101 trace_view_per_air: &[PairTraceView<F>],
102 ) -> Option<(Self::PartialProof, RapPhaseProverData<Challenge>)> {
103 let has_any_interactions = constraints_per_air
104 .iter()
105 .any(|constraints| !constraints.interactions.is_empty());
106
107 if !has_any_interactions {
108 return None;
109 }
110
111 let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
113 let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
114 array::from_fn(|_| challenger.sample_ext_element::<Challenge>());
115
116 let after_challenge_trace_per_air = metrics_span("generate_perm_trace_time_ms", || {
117 Self::generate_after_challenge_traces_per_air(
118 &challenges,
119 constraints_per_air,
120 params_per_air,
121 trace_view_per_air,
122 )
123 });
124 let cumulative_sum_per_air = Self::extract_cumulative_sums(&after_challenge_trace_per_air);
125
126 for cumulative_sum in cumulative_sum_per_air.iter().flatten() {
128 challenger.observe_slice(cumulative_sum.as_base_slice());
129 }
130
131 let exposed_values_per_air = cumulative_sum_per_air
132 .iter()
133 .map(|csum| csum.map(|csum| vec![csum]))
134 .collect_vec();
135
136 Some((
137 FriLogUpPartialProof { logup_pow_witness },
138 RapPhaseProverData {
139 challenges: challenges.to_vec(),
140 after_challenge_trace_per_air,
141 exposed_values_per_air,
142 },
143 ))
144 }
145
146 fn partially_verify<Commitment: Clone>(
147 &self,
148 challenger: &mut Challenger,
149 partial_proof: Option<&Self::PartialProof>,
150 exposed_values_per_phase_per_air: &[Vec<Vec<Challenge>>],
151 commitment_per_phase: &[Commitment],
152 _permutation_opened_values: &[Vec<Vec<Vec<Challenge>>>],
153 ) -> (RapPhaseVerifierData<Challenge>, Result<(), Self::Error>)
154 where
155 Challenger: CanObserve<Commitment>,
156 {
157 if exposed_values_per_phase_per_air
158 .iter()
159 .all(|exposed_values_per_phase_per_air| exposed_values_per_phase_per_air.is_empty())
160 {
161 return (RapPhaseVerifierData::default(), Ok(()));
162 }
163
164 let partial_proof = match partial_proof {
165 Some(proof) => proof,
166 None => {
167 return (
168 RapPhaseVerifierData::default(),
169 Err(FriLogUpError::MissingPartialProof),
170 );
171 }
172 };
173
174 if !challenger.check_witness(
175 self.log_up_params.log_up_pow_bits,
176 partial_proof.logup_pow_witness,
177 ) {
178 return (
179 RapPhaseVerifierData::default(),
180 Err(FriLogUpError::InvalidPowWitness),
181 );
182 }
183
184 let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
185 array::from_fn(|_| challenger.sample_ext_element::<Challenge>());
186
187 for exposed_values_per_phase in exposed_values_per_phase_per_air.iter() {
188 if let Some(exposed_values) = exposed_values_per_phase.first() {
189 for exposed_value in exposed_values {
190 challenger.observe_slice(exposed_value.as_base_slice());
191 }
192 }
193 }
194
195 challenger.observe(commitment_per_phase[0].clone());
196
197 let cumulative_sums = exposed_values_per_phase_per_air
198 .iter()
199 .map(|exposed_values_per_phase| {
200 assert!(
201 exposed_values_per_phase.len() <= 1,
202 "Verifier does not support more than 1 challenge phase"
203 );
204 exposed_values_per_phase.first().map(|exposed_values| {
205 assert_eq!(
206 exposed_values.len(),
207 1,
208 "Only exposed value should be cumulative sum"
209 );
210 exposed_values[0]
211 })
212 })
213 .collect_vec();
214
215 let sum: Challenge = cumulative_sums
217 .into_iter()
218 .map(|c| c.unwrap_or(Challenge::ZERO))
219 .sum();
220
221 let result = if sum == Challenge::ZERO {
222 Ok(())
223 } else {
224 Err(Self::Error::NonZeroCumulativeSum)
225 };
226 let verifier_data = RapPhaseVerifierData {
227 challenges_per_phase: vec![challenges.to_vec()],
228 };
229 (verifier_data, result)
230 }
231}
232
/// Number of extension-field challenges sampled for the LogUp phase:
/// `alpha`, which is added to every denominator, and `beta`, from which the per-field
/// coefficients are generated.
pub const STARK_LU_NUM_CHALLENGES: usize = 2;
/// Number of values each AIR with interactions exposes in the LogUp phase (its cumulative sum).
pub const STARK_LU_NUM_EXPOSED_VALUES: usize = 1;
235
236impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger>
237where
238 F: Field,
239 Challenge: ExtensionField<F>,
240 Challenger: FieldChallenger<F>,
241{
242 fn generate_after_challenge_traces_per_air(
244 challenges: &[Challenge; STARK_LU_NUM_CHALLENGES],
245 constraints_per_air: &[&SymbolicConstraints<F>],
246 params_per_air: &[&FriLogUpProvingKey],
247 trace_view_per_air: &[PairTraceView<F>],
248 ) -> Vec<Option<RowMajorMatrix<Challenge>>> {
249 parizip!(constraints_per_air, trace_view_per_air, params_per_air)
250 .map(|(constraints, trace_view, params)| {
251 Self::generate_after_challenge_trace(
252 &constraints.interactions,
253 trace_view,
254 challenges,
255 ¶ms.interaction_partitions,
256 )
257 })
258 .collect::<Vec<_>>()
259 }
260
261 fn extract_cumulative_sums(
262 perm_traces: &[Option<RowMajorMatrix<Challenge>>],
263 ) -> Vec<Option<Challenge>> {
264 perm_traces
265 .iter()
266 .map(|perm_trace| {
267 perm_trace.as_ref().map(|perm_trace| {
268 *perm_trace
269 .row_slice(perm_trace.height() - 1)
270 .last()
271 .unwrap()
272 })
273 })
274 .collect()
275 }
276
277 pub fn generate_after_challenge_trace(
289 all_interactions: &[SymbolicInteraction<F>],
290 trace_view: &PairTraceView<F>,
291 permutation_randomness: &[Challenge; STARK_LU_NUM_CHALLENGES],
292 interaction_partitions: &[Vec<usize>],
293 ) -> Option<RowMajorMatrix<Challenge>>
294 where
295 F: Field,
296 Challenge: ExtensionField<F>,
297 {
298 if all_interactions.is_empty() {
299 return None;
300 }
301 let &[alpha, beta] = permutation_randomness;
302
303 let betas = generate_betas(beta, all_interactions);
304
305 let num_interactions = all_interactions.len();
322 let height = trace_view.partitioned_main[0].height();
323
324 let perm_width = interaction_partitions.len() + 1;
329 let mut perm_values = Challenge::zero_vec(height * perm_width);
330 debug_assert!(
331 trace_view
332 .partitioned_main
333 .iter()
334 .all(|m| m.height() == height),
335 "All main trace parts must have same height"
336 );
337
338 #[cfg(feature = "parallel")]
342 let num_threads = rayon::current_num_threads();
343 #[cfg(not(feature = "parallel"))]
344 let num_threads = 1;
345
346 let preprocessed = trace_view.preprocessed.as_ref().map(|m| m.as_view());
347 let partitioned_main = trace_view
348 .partitioned_main
349 .iter()
350 .map(|m| m.as_view())
351 .collect_vec();
352 let evaluator = |local_index: usize| Evaluator {
353 preprocessed: &preprocessed,
354 partitioned_main: &partitioned_main,
355 public_values: &trace_view.public_values,
356 height,
357 local_index,
358 };
359 let height_per_thread = height.div_ceil(num_threads);
360 perm_values
361 .par_chunks_mut(height_per_thread * perm_width)
362 .enumerate()
363 .for_each(|(thread_idx, perm_values)| {
364 let num_rows = perm_values.len() / perm_width;
366 let mut denoms = Challenge::zero_vec(num_rows * num_interactions);
369 let row_offset = thread_idx * height_per_thread;
370 for (n, denom_row) in denoms.chunks_exact_mut(num_interactions).enumerate() {
372 let evaluator = evaluator(row_offset + n);
373 for (denom, interaction) in denom_row.iter_mut().zip(all_interactions.iter()) {
374 debug_assert!(interaction.message.len() <= betas.len());
375 let b = F::from_canonical_u32(interaction.bus_index as u32 + 1);
376 let mut fields = interaction.message.iter();
377 *denom = alpha
378 + evaluator
379 .eval_expr(fields.next().expect("fields should not be empty"));
380 for (expr, &beta) in fields.zip(betas.iter().skip(1)) {
381 *denom += beta * evaluator.eval_expr(expr);
382 }
383 *denom += betas[interaction.message.len()] * b;
384 }
385 }
386
387 let reciprocals = p3_field::batch_multiplicative_inverse(&denoms);
391 drop(denoms);
392 perm_values
396 .par_chunks_exact_mut(perm_width)
397 .zip(reciprocals.par_chunks_exact(num_interactions))
398 .enumerate()
399 .for_each(|(n, (perm_row, reciprocals))| {
400 debug_assert_eq!(perm_row.len(), perm_width);
401 debug_assert_eq!(reciprocals.len(), num_interactions);
402
403 let evaluator = evaluator(row_offset + n);
404 let mut row_sum = Challenge::ZERO;
405 for (part, perm_val) in zip(interaction_partitions, perm_row.iter_mut()) {
406 for &interaction_idx in part {
407 let interaction = &all_interactions[interaction_idx];
408 let interaction_val = reciprocals[interaction_idx]
409 * evaluator.eval_expr(&interaction.count);
410 *perm_val += interaction_val;
411 }
412 row_sum += *perm_val;
413 }
414
415 perm_row[perm_width - 1] = row_sum;
416 });
417 });
418
419 tracing::trace_span!("compute logup partial sums").in_scope(|| {
422 let mut phi = Challenge::ZERO;
423 for perm_chunk in perm_values.chunks_exact_mut(perm_width) {
424 phi += *perm_chunk.last().unwrap();
425 *perm_chunk.last_mut().unwrap() = phi;
426 }
427 });
428
429 Some(RowMajorMatrix::new(perm_values, perm_width))
430 }
431}
432
433pub fn eval_fri_log_up_phase<AB>(
440 builder: &mut AB,
441 symbolic_interactions: &[SymbolicInteraction<AB::F>],
442 max_constraint_degree: usize,
443) where
444 AB: InteractionBuilder + PermutationAirBuilderWithExposedValues,
445{
446 let exposed_values = builder.permutation_exposed_values();
447 assert_eq!(
449 exposed_values.len(),
450 1,
451 "Should have one exposed value for cumulative_sum"
452 );
453 let cumulative_sum = exposed_values[0];
454
455 let rand_elems = builder.permutation_randomness();
456
457 let perm = builder.permutation();
458 let (perm_local, perm_next) = (perm.row_slice(0), perm.row_slice(1));
459 let perm_local: &[AB::VarEF] = (*perm_local).borrow();
460 let perm_next: &[AB::VarEF] = (*perm_next).borrow();
461
462 let all_interactions = builder.all_interactions().to_vec();
463 let FriLogUpProvingKey {
464 interaction_partitions,
465 } = find_interaction_chunks(symbolic_interactions, max_constraint_degree);
466 let num_chunks = interaction_partitions.len();
467 debug_assert_eq!(num_chunks + 1, perm_local.len());
468
469 let phi_local = *perm_local.last().unwrap();
470 let phi_next = *perm_next.last().unwrap();
471
472 let alpha = rand_elems[0];
473 let betas = generate_betas(rand_elems[1].into(), &all_interactions);
474
475 let phi_lhs = phi_next.into() - phi_local.into();
476 let mut phi_rhs = AB::ExprEF::ZERO;
477 let mut phi_0 = AB::ExprEF::ZERO;
478
479 for (chunk_idx, part) in interaction_partitions.iter().enumerate() {
480 let denoms_per_chunk = part
481 .iter()
482 .map(|&interaction_idx| {
483 let interaction = &all_interactions[interaction_idx];
484 assert!(
485 !interaction.message.is_empty(),
486 "fields should not be empty"
487 );
488 let mut field_hash = AB::ExprEF::ZERO;
489 let b = AB::Expr::from_canonical_u32(interaction.bus_index as u32 + 1);
490 for (field, beta) in interaction.message.iter().chain([&b]).zip(&betas) {
491 field_hash += beta.clone() * field.clone();
492 }
493 field_hash + alpha.into()
494 })
495 .collect_vec();
496
497 let mut row_lhs: AB::ExprEF = perm_local[chunk_idx].into();
498 for denom in denoms_per_chunk.iter() {
499 row_lhs *= denom.clone();
500 }
501
502 let mut row_rhs = AB::ExprEF::ZERO;
503 for (i, &interaction_idx) in part.iter().enumerate() {
504 let interaction = &all_interactions[interaction_idx];
505 let mut term: AB::ExprEF = interaction.count.clone().into();
506 for (j, denom) in denoms_per_chunk.iter().enumerate() {
507 if i != j {
508 term *= denom.clone();
509 }
510 }
511 row_rhs += term;
512 }
513
514 builder.assert_eq_ext(row_lhs, row_rhs);
522
523 phi_0 += perm_local[chunk_idx].into();
524 phi_rhs += perm_next[chunk_idx].into();
525 }
526
527 builder.when_transition().assert_eq_ext(phi_lhs, phi_rhs);
529 builder
530 .when_first_row()
531 .assert_eq_ext(*perm_local.last().unwrap(), phi_0);
532 builder
533 .when_last_row()
534 .assert_eq_ext(*perm_local.last().unwrap(), cumulative_sum);
535}
536
537pub(crate) fn find_interaction_chunks<F: Field>(
565 interactions: &[SymbolicInteraction<F>],
566 max_constraint_degree: usize,
567) -> FriLogUpProvingKey {
568 if interactions.is_empty() {
569 return FriLogUpProvingKey::default();
570 }
571 let max_field_degree = |i: usize| {
573 interactions[i]
574 .message
575 .iter()
576 .map(|f| f.degree_multiple())
577 .max()
578 .unwrap_or(0)
579 };
580 let mut interaction_idxs = (0..interactions.len()).collect_vec();
581 interaction_idxs.sort_by(|&i, &j| {
582 let field_cmp = max_field_degree(i).cmp(&max_field_degree(j));
583 if field_cmp == std::cmp::Ordering::Equal {
584 interactions[i]
585 .count
586 .degree_multiple()
587 .cmp(&interactions[j].count.degree_multiple())
588 } else {
589 field_cmp
590 }
591 });
592 let mut running_sum_field_degree = 0;
594 let mut numerator_max_degree = 0;
595 let mut interaction_partitions = vec![];
596 let mut cur_chunk = vec![];
597 for interaction_idx in interaction_idxs {
598 let field_degree = max_field_degree(interaction_idx);
599 let count_degree = interactions[interaction_idx].count.degree_multiple();
600 let new_num_max_degree = max(
602 numerator_max_degree + field_degree,
603 count_degree + running_sum_field_degree,
604 );
605 let new_denom_degree = running_sum_field_degree + field_degree;
606 if max(new_num_max_degree, new_denom_degree + 1) <= max_constraint_degree {
607 cur_chunk.push(interaction_idx);
609 numerator_max_degree = new_num_max_degree;
610 running_sum_field_degree += field_degree;
611 } else {
612 if !cur_chunk.is_empty() {
614 interaction_partitions.push(mem::take(&mut cur_chunk));
616 }
617 cur_chunk.push(interaction_idx);
618 numerator_max_degree = count_degree;
619 running_sum_field_degree = field_degree;
620 if max_constraint_degree > 0
621 && max(count_degree, field_degree + 1) > max_constraint_degree
622 {
623 panic!("Interaction with field_degree={field_degree}, count_degree={count_degree} exceeds max_constraint_degree={max_constraint_degree}");
624 }
625 }
626 }
627 assert!(!cur_chunk.is_empty());
629 interaction_partitions.push(cur_chunk);
630
631 FriLogUpProvingKey {
632 interaction_partitions,
633 }
634}