1use std::{array, borrow::Borrow, cmp::max, iter::zip, marker::PhantomData, mem};
2
3use itertools::Itertools;
4use p3_air::ExtensionBuilder;
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_field::{ExtensionField, Field, FieldAlgebra};
7use p3_matrix::{dense::RowMajorMatrix, Matrix};
8use p3_maybe_rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11use tracing::info_span;
12
13use super::{LogUpSecurityParameters, PairTraceView, SymbolicInteraction};
14use crate::{
15 air_builders::symbolic::{symbolic_expression::SymbolicEvaluator, SymbolicConstraints},
16 interaction::{
17 trace::Evaluator, utils::generate_betas, InteractionBuilder, RapPhaseProverData,
18 RapPhaseSeq, RapPhaseSeqKind, RapPhaseVerifierData,
19 },
20 parizip,
21 rap::PermutationAirBuilderWithExposedValues,
22 utils::parallelize_chunks,
23};
24
/// RAP phase that proves/verifies interactions with a LogUp argument.
///
/// The type parameters are not stored; they only select the base field `F`,
/// the extension field `Challenge` the challenges are sampled from, and the
/// Fiat-Shamir `Challenger` used by the trait implementation.
pub struct FriLogUpPhase<F, Challenge, Challenger> {
    /// LogUp security parameters. Its `log_up_pow_bits` is the number of
    /// proof-of-work bits ground by the prover and checked by the verifier.
    log_up_params: LogUpSecurityParameters,
    /// Forwarded to trace generation, where the permutation trace buffer is
    /// allocated with `len << extra_capacity_bits` elements and then truncated
    /// back to `len` (so only the capacity is enlarged).
    extra_capacity_bits: usize,
    /// Ties the otherwise-unused type parameters to the struct.
    _marker: PhantomData<(F, Challenge, Challenger)>,
}
32
33impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger> {
34 pub fn new(log_up_params: LogUpSecurityParameters, extra_capacity_bits: usize) -> Self {
35 Self {
36 log_up_params,
37 extra_capacity_bits,
38 _marker: PhantomData,
39 }
40 }
41}
42
/// Failures reported by the verifier side of the LogUp phase.
#[derive(Error, Debug)]
pub enum FriLogUpError {
    /// The exposed cumulative sums of all AIRs did not add up to zero.
    #[error("non-zero cumulative sum")]
    NonZeroCumulativeSum,
    /// Exposed values were present but no partial proof was supplied.
    #[error("missing proof")]
    MissingPartialProof,
    /// The challenger rejected the proof-of-work witness.
    #[error("invalid proof of work witness")]
    InvalidPowWitness,
}
52
/// The part of the proof contributed by the LogUp phase.
#[derive(Clone, Serialize, Deserialize)]
pub struct FriLogUpPartialProof<Witness> {
    /// Proof-of-work witness: produced by `challenger.grind(..)` on the prover
    /// side and checked with `challenger.check_witness(..)` by the verifier.
    pub logup_pow_witness: Witness,
}
57
/// Per-AIR proving-key data for the LogUp phase.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct FriLogUpProvingKey {
    /// Partition of the AIR's interaction indices into chunks. Each inner
    /// vector lists the interactions accumulated into one column of the
    /// permutation trace; the default (empty) value means no chunks.
    interaction_partitions: Vec<Vec<usize>>,
}
62
63impl FriLogUpProvingKey {
64 pub fn interaction_partitions(self) -> Vec<Vec<usize>> {
65 self.interaction_partitions
66 }
67 pub fn num_chunks(&self) -> usize {
68 self.interaction_partitions.len()
69 }
70}
71
72impl<F: Field, Challenge, Challenger> RapPhaseSeq<F, Challenge, Challenger>
73 for FriLogUpPhase<F, Challenge, Challenger>
74where
75 F: Field,
76 Challenge: ExtensionField<F>,
77 Challenger: FieldChallenger<F> + GrindingChallenger<Witness = F>,
78{
79 type PartialProof = FriLogUpPartialProof<F>;
80 type PartialProvingKey = FriLogUpProvingKey;
81 type Error = FriLogUpError;
82 const ID: RapPhaseSeqKind = RapPhaseSeqKind::FriLogUp;
83
84 fn log_up_security_params(&self) -> &LogUpSecurityParameters {
85 &self.log_up_params
86 }
87
88 fn generate_pk_per_air(
89 &self,
90 symbolic_constraints_per_air: &[SymbolicConstraints<F>],
91 max_constraint_degree: usize,
92 ) -> Vec<Self::PartialProvingKey> {
93 symbolic_constraints_per_air
94 .iter()
95 .map(|constraints| {
96 find_interaction_chunks(&constraints.interactions, max_constraint_degree)
97 })
98 .collect()
99 }
100
101 fn partially_prove(
103 &self,
104 challenger: &mut Challenger,
105 constraints_per_air: &[&SymbolicConstraints<F>],
106 params_per_air: &[&FriLogUpProvingKey],
107 trace_view_per_air: Vec<PairTraceView<F>>,
108 ) -> Option<(Self::PartialProof, RapPhaseProverData<Challenge>)> {
109 let has_any_interactions = constraints_per_air
110 .iter()
111 .any(|constraints| !constraints.interactions.is_empty());
112
113 if !has_any_interactions {
114 return None;
115 }
116
117 let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
119 let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
120 array::from_fn(|_| challenger.sample_ext_element::<Challenge>());
121
122 let after_challenge_trace_per_air = info_span!("generate_perm_trace").in_scope(|| {
123 Self::generate_after_challenge_traces_per_air(
124 &challenges,
125 constraints_per_air,
126 params_per_air,
127 trace_view_per_air,
128 self.extra_capacity_bits,
129 )
130 });
131 let cumulative_sum_per_air = Self::extract_cumulative_sums(&after_challenge_trace_per_air);
132
133 for cumulative_sum in cumulative_sum_per_air.iter().flatten() {
135 challenger.observe_slice(cumulative_sum.as_base_slice());
136 }
137
138 let exposed_values_per_air = cumulative_sum_per_air
139 .iter()
140 .map(|csum| csum.map(|csum| vec![csum]))
141 .collect_vec();
142
143 Some((
144 FriLogUpPartialProof { logup_pow_witness },
145 RapPhaseProverData {
146 challenges: challenges.to_vec(),
147 after_challenge_trace_per_air,
148 exposed_values_per_air,
149 },
150 ))
151 }
152
153 fn partially_verify<Commitment: Clone>(
154 &self,
155 challenger: &mut Challenger,
156 partial_proof: Option<&Self::PartialProof>,
157 exposed_values_per_phase_per_air: &[Vec<Vec<Challenge>>],
158 commitment_per_phase: &[Commitment],
159 _permutation_opened_values: &[Vec<Vec<Vec<Challenge>>>],
160 ) -> (RapPhaseVerifierData<Challenge>, Result<(), Self::Error>)
161 where
162 Challenger: CanObserve<Commitment>,
163 {
164 if exposed_values_per_phase_per_air
165 .iter()
166 .all(|exposed_values_per_phase_per_air| exposed_values_per_phase_per_air.is_empty())
167 {
168 return (RapPhaseVerifierData::default(), Ok(()));
169 }
170
171 let partial_proof = match partial_proof {
172 Some(proof) => proof,
173 None => {
174 return (
175 RapPhaseVerifierData::default(),
176 Err(FriLogUpError::MissingPartialProof),
177 );
178 }
179 };
180
181 if !challenger.check_witness(
182 self.log_up_params.log_up_pow_bits,
183 partial_proof.logup_pow_witness,
184 ) {
185 return (
186 RapPhaseVerifierData::default(),
187 Err(FriLogUpError::InvalidPowWitness),
188 );
189 }
190
191 let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
192 array::from_fn(|_| challenger.sample_ext_element::<Challenge>());
193
194 for exposed_values_per_phase in exposed_values_per_phase_per_air.iter() {
195 if let Some(exposed_values) = exposed_values_per_phase.first() {
196 for exposed_value in exposed_values {
197 challenger.observe_slice(exposed_value.as_base_slice());
198 }
199 }
200 }
201
202 challenger.observe(commitment_per_phase[0].clone());
203
204 let cumulative_sums = exposed_values_per_phase_per_air
205 .iter()
206 .map(|exposed_values_per_phase| {
207 assert!(
208 exposed_values_per_phase.len() <= 1,
209 "Verifier does not support more than 1 challenge phase"
210 );
211 exposed_values_per_phase.first().map(|exposed_values| {
212 assert_eq!(
213 exposed_values.len(),
214 1,
215 "Only exposed value should be cumulative sum"
216 );
217 exposed_values[0]
218 })
219 })
220 .collect_vec();
221
222 let sum: Challenge = cumulative_sums
224 .into_iter()
225 .map(|c| c.unwrap_or(Challenge::ZERO))
226 .sum();
227
228 let result = if sum == Challenge::ZERO {
229 Ok(())
230 } else {
231 Err(Self::Error::NonZeroCumulativeSum)
232 };
233 let verifier_data = RapPhaseVerifierData {
234 challenges_per_phase: vec![challenges.to_vec()],
235 };
236 (verifier_data, result)
237 }
238}
239
/// Number of extension-field challenges sampled for the LogUp phase; trace
/// generation destructures them as `[alpha, beta]`.
pub const STARK_LU_NUM_CHALLENGES: usize = 2;
/// Number of values exposed per AIR by the LogUp phase.
// NOTE(review): presumably the single cumulative sum — not referenced in this part of the file; confirm.
pub const STARK_LU_NUM_EXPOSED_VALUES: usize = 1;
242
243impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger>
244where
245 F: Field,
246 Challenge: ExtensionField<F>,
247 Challenger: FieldChallenger<F>,
248{
249 fn generate_after_challenge_traces_per_air(
253 challenges: &[Challenge; STARK_LU_NUM_CHALLENGES],
254 constraints_per_air: &[&SymbolicConstraints<F>],
255 params_per_air: &[&FriLogUpProvingKey],
256 trace_view_per_air: Vec<PairTraceView<F>>,
257 extra_capacity_bits: usize,
258 ) -> Vec<Option<RowMajorMatrix<Challenge>>> {
259 parizip!(constraints_per_air, trace_view_per_air, params_per_air)
260 .map(|(constraints, trace_view, params)| {
261 Self::generate_after_challenge_trace(
262 &constraints.interactions,
263 trace_view,
264 challenges,
265 ¶ms.interaction_partitions,
266 extra_capacity_bits,
267 )
268 })
269 .collect::<Vec<_>>()
270 }
271
272 fn extract_cumulative_sums(
273 perm_traces: &[Option<RowMajorMatrix<Challenge>>],
274 ) -> Vec<Option<Challenge>> {
275 perm_traces
276 .iter()
277 .map(|perm_trace| {
278 perm_trace.as_ref().map(|perm_trace| {
279 *perm_trace
280 .row_slice(perm_trace.height() - 1)
281 .last()
282 .unwrap()
283 })
284 })
285 .collect()
286 }
287
288 pub fn generate_after_challenge_trace(
300 all_interactions: &[SymbolicInteraction<F>],
301 trace_view: PairTraceView<F>,
302 permutation_randomness: &[Challenge; STARK_LU_NUM_CHALLENGES],
303 interaction_partitions: &[Vec<usize>],
304 extra_capacity_bits: usize,
305 ) -> Option<RowMajorMatrix<Challenge>>
306 where
307 F: Field,
308 Challenge: ExtensionField<F>,
309 {
310 if all_interactions.is_empty() {
311 return None;
312 }
313 let &[alpha, beta] = permutation_randomness;
314
315 let betas = generate_betas(beta, all_interactions);
316
317 let num_interactions = all_interactions.len();
334 let height = trace_view.partitioned_main[0].height();
335
336 let perm_width = interaction_partitions.len() + 1;
341 let perm_trace_len = height * perm_width;
344 let mut perm_values = Challenge::zero_vec(perm_trace_len << extra_capacity_bits);
345 perm_values.truncate(perm_trace_len);
346 debug_assert!(
347 trace_view
348 .partitioned_main
349 .iter()
350 .all(|m| m.height() == height),
351 "All main trace parts must have same height"
352 );
353
354 let preprocessed = trace_view.preprocessed.as_ref().map(|m| m.as_view());
355 let partitioned_main = trace_view
356 .partitioned_main
357 .iter()
358 .map(|m| m.as_view())
359 .collect_vec();
360 let evaluator = |local_index: usize| Evaluator {
361 preprocessed: &preprocessed,
362 partitioned_main: &partitioned_main,
363 public_values: &trace_view.public_values,
364 height,
365 local_index,
366 };
367 parallelize_chunks(&mut perm_values, perm_width, |perm_values, idx| {
368 debug_assert_eq!(perm_values.len() % perm_width, 0);
369 debug_assert_eq!(idx % perm_width, 0);
370 let num_rows = perm_values.len() / perm_width;
372 let mut denoms = Challenge::zero_vec(num_rows * num_interactions);
375 let row_offset = idx / perm_width;
376 for (n, denom_row) in denoms.chunks_exact_mut(num_interactions).enumerate() {
378 let evaluator = evaluator(row_offset + n);
379 for (denom, interaction) in denom_row.iter_mut().zip(all_interactions.iter()) {
380 debug_assert!(interaction.message.len() <= betas.len());
381 let b = F::from_canonical_u32(interaction.bus_index as u32 + 1);
382 let mut fields = interaction.message.iter();
383 *denom = alpha
384 + evaluator.eval_expr(fields.next().expect("fields should not be empty"));
385 for (expr, &beta) in fields.zip(betas.iter().skip(1)) {
386 *denom += beta * evaluator.eval_expr(expr);
387 }
388 *denom += betas[interaction.message.len()] * b;
389 }
390 }
391
392 let reciprocals = p3_field::batch_multiplicative_inverse(&denoms);
396 drop(denoms);
397 perm_values
401 .par_chunks_exact_mut(perm_width)
402 .zip(reciprocals.par_chunks_exact(num_interactions))
403 .enumerate()
404 .for_each(|(n, (perm_row, reciprocals))| {
405 debug_assert_eq!(perm_row.len(), perm_width);
406 debug_assert_eq!(reciprocals.len(), num_interactions);
407
408 let evaluator = evaluator(row_offset + n);
409 let mut row_sum = Challenge::ZERO;
410 for (part, perm_val) in zip(interaction_partitions, perm_row.iter_mut()) {
411 for &interaction_idx in part {
412 let interaction = &all_interactions[interaction_idx];
413 let interaction_val = reciprocals[interaction_idx]
414 * evaluator.eval_expr(&interaction.count);
415 *perm_val += interaction_val;
416 }
417 row_sum += *perm_val;
418 }
419
420 perm_row[perm_width - 1] = row_sum;
421 });
422 });
423 drop(trace_view);
425
426 tracing::trace_span!("compute logup partial sums").in_scope(|| {
429 let mut phi = Challenge::ZERO;
430 for perm_chunk in perm_values.chunks_exact_mut(perm_width) {
431 phi += *perm_chunk.last().unwrap();
432 *perm_chunk.last_mut().unwrap() = phi;
433 }
434 });
435
436 Some(RowMajorMatrix::new(perm_values, perm_width))
437 }
438}
439
/// Adds the LogUp constraints for one AIR to `builder`.
///
/// `symbolic_interactions` is chunked with [`find_interaction_chunks`] under
/// `max_constraint_degree`; the resulting indices are then used to index
/// `builder.all_interactions()`.
/// NOTE(review): this assumes both describe the same interactions in the same
/// order — confirm against callers.
///
/// For every chunk `k` the function asserts
/// `perm[k] * prod(denoms) == sum_i count_i * prod_{j != i} denom_j`,
/// and for the last permutation column `phi` it asserts
/// * on transition rows: `phi_next - phi_local == sum_k perm_next[k]`,
/// * on the first row: `phi_local == sum_k perm_local[k]`,
/// * on the last row: `phi_local == cumulative_sum` (the exposed value).
///
/// # Panics
///
/// Panics if the builder does not expose exactly one permutation value or if
/// an interaction has an empty message.
pub fn eval_fri_log_up_phase<AB>(
    builder: &mut AB,
    symbolic_interactions: &[SymbolicInteraction<AB::F>],
    max_constraint_degree: usize,
) where
    AB: InteractionBuilder + PermutationAirBuilderWithExposedValues,
{
    let exposed_values = builder.permutation_exposed_values();
    assert_eq!(
        exposed_values.len(),
        1,
        "Should have one exposed value for cumulative_sum"
    );
    let cumulative_sum = exposed_values[0];

    let rand_elems = builder.permutation_randomness();

    // Current and next row of the permutation trace.
    let perm = builder.permutation();
    let (perm_local, perm_next) = (perm.row_slice(0), perm.row_slice(1));
    let perm_local: &[AB::VarEF] = (*perm_local).borrow();
    let perm_next: &[AB::VarEF] = (*perm_next).borrow();

    let all_interactions = builder.all_interactions().to_vec();
    let FriLogUpProvingKey {
        interaction_partitions,
    } = find_interaction_chunks(symbolic_interactions, max_constraint_degree);
    let num_chunks = interaction_partitions.len();
    // One column per chunk plus the running-sum column.
    debug_assert_eq!(num_chunks + 1, perm_local.len());

    // The running sum lives in the last column.
    let phi_local = *perm_local.last().unwrap();
    let phi_next = *perm_next.last().unwrap();

    let alpha = rand_elems[0];
    let betas = generate_betas(rand_elems[1].into(), &all_interactions);

    let phi_lhs = phi_next.into() - phi_local.into();
    let mut phi_rhs = AB::ExprEF::ZERO;
    let mut phi_0 = AB::ExprEF::ZERO;

    for (chunk_idx, part) in interaction_partitions.iter().enumerate() {
        // Denominator of each interaction in this chunk:
        // alpha + sum_k betas[k] * field_k, with the shifted bus index
        // appended as the final field.
        let denoms_per_chunk = part
            .iter()
            .map(|&interaction_idx| {
                let interaction = &all_interactions[interaction_idx];
                assert!(
                    !interaction.message.is_empty(),
                    "fields should not be empty"
                );
                let mut field_hash = AB::ExprEF::ZERO;
                let b = AB::Expr::from_canonical_u32(interaction.bus_index as u32 + 1);
                for (field, beta) in interaction.message.iter().chain([&b]).zip(&betas) {
                    field_hash += beta.clone() * field.clone();
                }
                field_hash + alpha.into()
            })
            .collect_vec();

        // Left side: the chunk column times every denominator in the chunk.
        let mut row_lhs: AB::ExprEF = perm_local[chunk_idx].into();
        for denom in denoms_per_chunk.iter() {
            row_lhs *= denom.clone();
        }

        // Right side: each count times the denominators of the *other*
        // interactions in the chunk (the cleared-denominator form of the sum
        // of fractions).
        let mut row_rhs = AB::ExprEF::ZERO;
        for (i, &interaction_idx) in part.iter().enumerate() {
            let interaction = &all_interactions[interaction_idx];
            let mut term: AB::ExprEF = interaction.count.clone().into();
            for (j, denom) in denoms_per_chunk.iter().enumerate() {
                if i != j {
                    term *= denom.clone();
                }
            }
            row_rhs += term;
        }

        builder.assert_eq_ext(row_lhs, row_rhs);

        phi_0 += perm_local[chunk_idx].into();
        phi_rhs += perm_next[chunk_idx].into();
    }

    // Running-sum constraints on the last column.
    builder.when_transition().assert_eq_ext(phi_lhs, phi_rhs);
    builder
        .when_first_row()
        .assert_eq_ext(*perm_local.last().unwrap(), phi_0);
    builder
        .when_last_row()
        .assert_eq_ext(*perm_local.last().unwrap(), cumulative_sum);
}
544
545pub(crate) fn find_interaction_chunks<F: Field>(
574 interactions: &[SymbolicInteraction<F>],
575 max_constraint_degree: usize,
576) -> FriLogUpProvingKey {
577 if interactions.is_empty() {
578 return FriLogUpProvingKey::default();
579 }
580 let max_field_degree = |i: usize| {
582 interactions[i]
583 .message
584 .iter()
585 .map(|f| f.degree_multiple())
586 .max()
587 .unwrap_or(0)
588 };
589 let mut interaction_idxs = (0..interactions.len()).collect_vec();
590 interaction_idxs.sort_by(|&i, &j| {
591 let field_cmp = max_field_degree(i).cmp(&max_field_degree(j));
592 if field_cmp == std::cmp::Ordering::Equal {
593 interactions[i]
594 .count
595 .degree_multiple()
596 .cmp(&interactions[j].count.degree_multiple())
597 } else {
598 field_cmp
599 }
600 });
601 let mut running_sum_field_degree = 0;
603 let mut numerator_max_degree = 0;
604 let mut interaction_partitions = vec![];
605 let mut cur_chunk = vec![];
606 for interaction_idx in interaction_idxs {
607 let field_degree = max_field_degree(interaction_idx);
608 let count_degree = interactions[interaction_idx].count.degree_multiple();
609 let new_num_max_degree = max(
611 numerator_max_degree + field_degree,
612 count_degree + running_sum_field_degree,
613 );
614 let new_denom_degree = running_sum_field_degree + field_degree;
615 if max(new_num_max_degree, new_denom_degree + 1) <= max_constraint_degree {
616 cur_chunk.push(interaction_idx);
618 numerator_max_degree = new_num_max_degree;
619 running_sum_field_degree += field_degree;
620 } else {
621 if !cur_chunk.is_empty() {
623 interaction_partitions.push(mem::take(&mut cur_chunk));
625 }
626 cur_chunk.push(interaction_idx);
627 numerator_max_degree = count_degree;
628 running_sum_field_degree = field_degree;
629 if max_constraint_degree > 0
630 && max(count_degree, field_degree + 1) > max_constraint_degree
631 {
632 panic!("Interaction with field_degree={field_degree}, count_degree={count_degree} exceeds max_constraint_degree={max_constraint_degree}");
633 }
634 }
635 }
636 assert!(!cur_chunk.is_empty());
638 interaction_partitions.push(cur_chunk);
639
640 FriLogUpProvingKey {
641 interaction_partitions,
642 }
643}