1use std::{array, borrow::Borrow, cmp::max, iter::zip, marker::PhantomData, mem};
2
3use itertools::Itertools;
4use p3_air::ExtensionBuilder;
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_field::{ExtensionField, Field, PrimeCharacteristicRing};
7use p3_matrix::{dense::RowMajorMatrix, Matrix};
8use p3_maybe_rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11use tracing::info_span;
12
13use super::{LogUpSecurityParameters, PairTraceView, SymbolicInteraction};
14use crate::{
15 air_builders::symbolic::{symbolic_expression::SymbolicEvaluator, SymbolicConstraints},
16 interaction::{
17 trace::Evaluator, utils::generate_betas, InteractionBuilder, RapPhaseProverData,
18 RapPhaseSeq, RapPhaseSeqKind, RapPhaseVerifierData,
19 },
20 parizip,
21 rap::PermutationAirBuilderWithExposedValues,
22 utils::parallelize_chunks,
23};
24
25pub struct FriLogUpPhase<F, Challenge, Challenger> {
26 log_up_params: LogUpSecurityParameters,
27 extra_capacity_bits: usize,
30 _marker: PhantomData<(F, Challenge, Challenger)>,
31}
32
33impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger> {
34 pub fn new(log_up_params: LogUpSecurityParameters, extra_capacity_bits: usize) -> Self {
35 Self {
36 log_up_params,
37 extra_capacity_bits,
38 _marker: PhantomData,
39 }
40 }
41}
42
43#[derive(Error, Debug)]
44pub enum FriLogUpError {
45 #[error("non-zero cumulative sum")]
46 NonZeroCumulativeSum,
47 #[error("missing proof")]
48 MissingPartialProof,
49 #[error("invalid proof of work witness")]
50 InvalidPowWitness,
51}
52
53#[derive(Clone, Serialize, Deserialize)]
54pub struct FriLogUpPartialProof<Witness> {
55 pub logup_pow_witness: Witness,
56}
57
58#[derive(Clone, Default, Serialize, Deserialize)]
59pub struct FriLogUpProvingKey {
60 interaction_partitions: Vec<Vec<usize>>,
61}
62
63impl FriLogUpProvingKey {
64 pub fn interaction_partitions(self) -> Vec<Vec<usize>> {
65 self.interaction_partitions
66 }
67 pub fn num_chunks(&self) -> usize {
68 self.interaction_partitions.len()
69 }
70}
71
72impl<F: Field, Challenge, Challenger> RapPhaseSeq<F, Challenge, Challenger>
73 for FriLogUpPhase<F, Challenge, Challenger>
74where
75 F: Field,
76 Challenge: ExtensionField<F>,
77 Challenger: FieldChallenger<F> + GrindingChallenger<Witness = F>,
78{
79 type PartialProof = FriLogUpPartialProof<F>;
80 type PartialProvingKey = FriLogUpProvingKey;
81 type Error = FriLogUpError;
82 const ID: RapPhaseSeqKind = RapPhaseSeqKind::FriLogUp;
83
84 fn log_up_security_params(&self) -> &LogUpSecurityParameters {
85 &self.log_up_params
86 }
87
88 fn generate_pk_per_air(
89 &self,
90 symbolic_constraints_per_air: &[SymbolicConstraints<F>],
91 max_constraint_degree: usize,
92 ) -> Vec<Self::PartialProvingKey> {
93 symbolic_constraints_per_air
94 .iter()
95 .map(|constraints| {
96 find_interaction_chunks(&constraints.interactions, max_constraint_degree)
97 })
98 .collect()
99 }
100
101 fn partially_prove(
103 &self,
104 challenger: &mut Challenger,
105 constraints_per_air: &[&SymbolicConstraints<F>],
106 params_per_air: &[&FriLogUpProvingKey],
107 trace_view_per_air: Vec<PairTraceView<F>>,
108 ) -> Option<(Self::PartialProof, RapPhaseProverData<Challenge>)> {
109 let has_any_interactions = constraints_per_air
110 .iter()
111 .any(|constraints| !constraints.interactions.is_empty());
112
113 if !has_any_interactions {
114 return None;
115 }
116
117 let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
119 let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
120 array::from_fn(|_| challenger.sample_algebra_element::<Challenge>());
121
122 let after_challenge_trace_per_air = info_span!("generate_perm_trace").in_scope(|| {
123 Self::generate_after_challenge_traces_per_air(
124 &challenges,
125 constraints_per_air,
126 params_per_air,
127 trace_view_per_air,
128 self.extra_capacity_bits,
129 )
130 });
131 let cumulative_sum_per_air = Self::extract_cumulative_sums(&after_challenge_trace_per_air);
132
133 for cumulative_sum in cumulative_sum_per_air.iter().flatten() {
135 challenger.observe_slice(cumulative_sum.as_basis_coefficients_slice());
136 }
137
138 let exposed_values_per_air = cumulative_sum_per_air
139 .iter()
140 .map(|csum| csum.map(|csum| vec![csum]))
141 .collect_vec();
142
143 Some((
144 FriLogUpPartialProof { logup_pow_witness },
145 RapPhaseProverData {
146 challenges: challenges.to_vec(),
147 after_challenge_trace_per_air,
148 exposed_values_per_air,
149 },
150 ))
151 }
152
153 fn partially_verify<Commitment: Clone>(
154 &self,
155 challenger: &mut Challenger,
156 partial_proof: Option<&Self::PartialProof>,
157 exposed_values_per_phase_per_air: &[Vec<Vec<Challenge>>],
158 commitment_per_phase: &[Commitment],
159 _permutation_opened_values: &[Vec<Vec<Vec<Challenge>>>],
160 ) -> (RapPhaseVerifierData<Challenge>, Result<(), Self::Error>)
161 where
162 Challenger: CanObserve<Commitment>,
163 {
164 if exposed_values_per_phase_per_air
165 .iter()
166 .all(|exposed_values_per_phase_per_air| exposed_values_per_phase_per_air.is_empty())
167 {
168 return (RapPhaseVerifierData::default(), Ok(()));
169 }
170
171 let partial_proof = match partial_proof {
172 Some(proof) => proof,
173 None => {
174 return (
175 RapPhaseVerifierData::default(),
176 Err(FriLogUpError::MissingPartialProof),
177 );
178 }
179 };
180
181 if !challenger.check_witness(
182 self.log_up_params.log_up_pow_bits,
183 partial_proof.logup_pow_witness,
184 ) {
185 return (
186 RapPhaseVerifierData::default(),
187 Err(FriLogUpError::InvalidPowWitness),
188 );
189 }
190
191 let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
192 array::from_fn(|_| challenger.sample_algebra_element::<Challenge>());
193
194 for exposed_values_per_phase in exposed_values_per_phase_per_air.iter() {
195 if let Some(exposed_values) = exposed_values_per_phase.first() {
196 for exposed_value in exposed_values {
197 challenger.observe_slice(exposed_value.as_basis_coefficients_slice());
198 }
199 }
200 }
201
202 challenger.observe(commitment_per_phase[0].clone());
203
204 let cumulative_sums = exposed_values_per_phase_per_air
205 .iter()
206 .map(|exposed_values_per_phase| {
207 assert!(
208 exposed_values_per_phase.len() <= 1,
209 "Verifier does not support more than 1 challenge phase"
210 );
211 exposed_values_per_phase.first().map(|exposed_values| {
212 assert_eq!(
213 exposed_values.len(),
214 1,
215 "Only exposed value should be cumulative sum"
216 );
217 exposed_values[0]
218 })
219 })
220 .collect_vec();
221
222 let sum: Challenge = cumulative_sums
224 .into_iter()
225 .map(|c| c.unwrap_or(Challenge::ZERO))
226 .sum();
227
228 let result = if sum == Challenge::ZERO {
229 Ok(())
230 } else {
231 Err(Self::Error::NonZeroCumulativeSum)
232 };
233 let verifier_data = RapPhaseVerifierData {
234 challenges_per_phase: vec![challenges.to_vec()],
235 };
236 (verifier_data, result)
237 }
238}
239
/// Number of extension-field challenges (`alpha`, `beta`) sampled by the FRI LogUp phase.
pub const STARK_LU_NUM_CHALLENGES: usize = 2;
/// Number of values each AIR with interactions exposes after the LogUp phase: its
/// cumulative sum.
pub const STARK_LU_NUM_EXPOSED_VALUES: usize = 1;
242
243impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger>
244where
245 F: Field,
246 Challenge: ExtensionField<F>,
247 Challenger: FieldChallenger<F>,
248{
249 fn generate_after_challenge_traces_per_air(
253 challenges: &[Challenge; STARK_LU_NUM_CHALLENGES],
254 constraints_per_air: &[&SymbolicConstraints<F>],
255 params_per_air: &[&FriLogUpProvingKey],
256 trace_view_per_air: Vec<PairTraceView<F>>,
257 extra_capacity_bits: usize,
258 ) -> Vec<Option<RowMajorMatrix<Challenge>>> {
259 parizip!(constraints_per_air, trace_view_per_air, params_per_air)
260 .map(|(constraints, trace_view, params)| {
261 Self::generate_after_challenge_trace(
262 &constraints.interactions,
263 trace_view,
264 challenges,
265 ¶ms.interaction_partitions,
266 extra_capacity_bits,
267 )
268 })
269 .collect::<Vec<_>>()
270 }
271
272 fn extract_cumulative_sums(
273 perm_traces: &[Option<RowMajorMatrix<Challenge>>],
274 ) -> Vec<Option<Challenge>> {
275 perm_traces
276 .iter()
277 .map(|perm_trace| {
278 perm_trace.as_ref().map(|perm_trace| {
279 *perm_trace
280 .row_slice(perm_trace.height() - 1)
281 .unwrap()
282 .last()
283 .unwrap()
284 })
285 })
286 .collect()
287 }
288
289 pub fn generate_after_challenge_trace(
301 all_interactions: &[SymbolicInteraction<F>],
302 trace_view: PairTraceView<F>,
303 permutation_randomness: &[Challenge; STARK_LU_NUM_CHALLENGES],
304 interaction_partitions: &[Vec<usize>],
305 extra_capacity_bits: usize,
306 ) -> Option<RowMajorMatrix<Challenge>>
307 where
308 F: Field,
309 Challenge: ExtensionField<F>,
310 {
311 if all_interactions.is_empty() {
312 return None;
313 }
314 let &[alpha, beta] = permutation_randomness;
315
316 let betas = generate_betas(beta, all_interactions);
317
318 let num_interactions = all_interactions.len();
335 let height = trace_view.partitioned_main[0].height();
336
337 let perm_width = interaction_partitions.len() + 1;
342 let perm_trace_len = height * perm_width;
345 let mut perm_values = Challenge::zero_vec(perm_trace_len << extra_capacity_bits);
346 perm_values.truncate(perm_trace_len);
347 debug_assert!(
348 trace_view
349 .partitioned_main
350 .iter()
351 .all(|m| m.height() == height),
352 "All main trace parts must have same height"
353 );
354
355 let preprocessed = trace_view.preprocessed.as_ref().map(|m| m.as_view());
356 let partitioned_main = trace_view
357 .partitioned_main
358 .iter()
359 .map(|m| m.as_view())
360 .collect_vec();
361 let evaluator = |local_index: usize| Evaluator {
362 preprocessed: &preprocessed,
363 partitioned_main: &partitioned_main,
364 public_values: &trace_view.public_values,
365 height,
366 local_index,
367 };
368 parallelize_chunks(&mut perm_values, perm_width, |perm_values, idx| {
369 debug_assert_eq!(perm_values.len() % perm_width, 0);
370 debug_assert_eq!(idx % perm_width, 0);
371 let num_rows = perm_values.len() / perm_width;
373 let mut denoms = Challenge::zero_vec(num_rows * num_interactions);
376 let row_offset = idx / perm_width;
377 for (n, denom_row) in denoms.chunks_exact_mut(num_interactions).enumerate() {
379 let evaluator = evaluator(row_offset + n);
380 for (denom, interaction) in denom_row.iter_mut().zip(all_interactions.iter()) {
381 debug_assert!(interaction.message.len() <= betas.len());
382 let b = F::from_u32(interaction.bus_index as u32 + 1);
383 let mut fields = interaction.message.iter();
384 *denom = alpha
385 + evaluator.eval_expr(fields.next().expect("fields should not be empty"));
386 for (expr, &beta) in fields.zip(betas.iter().skip(1)) {
387 *denom += beta * evaluator.eval_expr(expr);
388 }
389 *denom += betas[interaction.message.len()] * b;
390 }
391 }
392
393 let reciprocals = p3_field::batch_multiplicative_inverse(&denoms);
397 drop(denoms);
398 perm_values
402 .par_chunks_exact_mut(perm_width)
403 .zip(reciprocals.par_chunks_exact(num_interactions))
404 .enumerate()
405 .for_each(|(n, (perm_row, reciprocals))| {
406 debug_assert_eq!(perm_row.len(), perm_width);
407 debug_assert_eq!(reciprocals.len(), num_interactions);
408
409 let evaluator = evaluator(row_offset + n);
410 let mut row_sum = Challenge::ZERO;
411 for (part, perm_val) in zip(interaction_partitions, perm_row.iter_mut()) {
412 for &interaction_idx in part {
413 let interaction = &all_interactions[interaction_idx];
414 let interaction_val = reciprocals[interaction_idx]
415 * evaluator.eval_expr(&interaction.count);
416 *perm_val += interaction_val;
417 }
418 row_sum += *perm_val;
419 }
420
421 perm_row[perm_width - 1] = row_sum;
422 });
423 });
424 drop(trace_view);
426
427 tracing::trace_span!("compute logup partial sums").in_scope(|| {
430 let mut phi = Challenge::ZERO;
431 for perm_chunk in perm_values.chunks_exact_mut(perm_width) {
432 phi += *perm_chunk.last().unwrap();
433 *perm_chunk.last_mut().unwrap() = phi;
434 }
435 });
436
437 Some(RowMajorMatrix::new(perm_values, perm_width))
438 }
439}
440
441pub fn eval_fri_log_up_phase<AB>(
448 builder: &mut AB,
449 symbolic_interactions: &[SymbolicInteraction<AB::F>],
450 max_constraint_degree: usize,
451) where
452 AB: InteractionBuilder + PermutationAirBuilderWithExposedValues,
453{
454 let exposed_values = builder.permutation_exposed_values();
455 assert_eq!(
457 exposed_values.len(),
458 1,
459 "Should have one exposed value for cumulative_sum"
460 );
461 let cumulative_sum = exposed_values[0];
462
463 let rand_elems = builder.permutation_randomness();
464
465 let perm = builder.permutation();
466 let (perm_local, perm_next) = (
467 perm.row_slice(0).expect("window should have two elements"),
468 perm.row_slice(1).expect("window should have two elements"),
469 );
470 let perm_local: &[AB::VarEF] = (*perm_local).borrow();
471 let perm_next: &[AB::VarEF] = (*perm_next).borrow();
472
473 let all_interactions = builder.all_interactions().to_vec();
474 let FriLogUpProvingKey {
475 interaction_partitions,
476 } = find_interaction_chunks(symbolic_interactions, max_constraint_degree);
477 let num_chunks = interaction_partitions.len();
478 debug_assert_eq!(num_chunks + 1, perm_local.len());
479
480 let phi_local = *perm_local.last().unwrap();
481 let phi_next = *perm_next.last().unwrap();
482
483 let alpha = rand_elems[0];
484 let betas = generate_betas(rand_elems[1].into(), &all_interactions);
485
486 let phi_lhs = phi_next.into() - phi_local.into();
487 let mut phi_rhs = AB::ExprEF::ZERO;
488 let mut phi_0 = AB::ExprEF::ZERO;
489
490 for (chunk_idx, part) in interaction_partitions.iter().enumerate() {
491 let denoms_per_chunk = part
492 .iter()
493 .map(|&interaction_idx| {
494 let interaction = &all_interactions[interaction_idx];
495 assert!(
496 !interaction.message.is_empty(),
497 "fields should not be empty"
498 );
499 let mut field_hash = AB::ExprEF::ZERO;
500 let b = AB::Expr::from_u32(interaction.bus_index as u32 + 1);
501 for (field, beta) in interaction.message.iter().chain([&b]).zip(&betas) {
502 let field_ext: AB::ExprEF = field.clone().into();
503 field_hash += beta.clone() * field_ext;
504 }
505 field_hash + alpha.into()
506 })
507 .collect_vec();
508
509 let mut row_lhs: AB::ExprEF = perm_local[chunk_idx].into();
510 for denom in denoms_per_chunk.iter() {
511 row_lhs *= denom.clone();
512 }
513
514 let mut row_rhs = AB::ExprEF::ZERO;
515 for (i, &interaction_idx) in part.iter().enumerate() {
516 let interaction = &all_interactions[interaction_idx];
517 let mut term: AB::ExprEF = interaction.count.clone().into();
518 for (j, denom) in denoms_per_chunk.iter().enumerate() {
519 if i != j {
520 term *= denom.clone();
521 }
522 }
523 row_rhs += term;
524 }
525
526 builder.assert_eq_ext(row_lhs, row_rhs);
535
536 phi_0 += perm_local[chunk_idx].into();
537 phi_rhs += perm_next[chunk_idx].into();
538 }
539
540 builder.when_transition().assert_eq_ext(phi_lhs, phi_rhs);
542 builder
543 .when_first_row()
544 .assert_eq_ext(*perm_local.last().unwrap(), phi_0);
545 builder
546 .when_last_row()
547 .assert_eq_ext(*perm_local.last().unwrap(), cumulative_sum);
548}
549
550pub(crate) fn find_interaction_chunks<F: Field>(
579 interactions: &[SymbolicInteraction<F>],
580 max_constraint_degree: usize,
581) -> FriLogUpProvingKey {
582 if interactions.is_empty() {
583 return FriLogUpProvingKey::default();
584 }
585 let max_field_degree = |i: usize| {
587 interactions[i]
588 .message
589 .iter()
590 .map(|f| f.degree_multiple())
591 .max()
592 .unwrap_or(0)
593 };
594 let mut interaction_idxs = (0..interactions.len()).collect_vec();
595 interaction_idxs.sort_by(|&i, &j| {
596 let field_cmp = max_field_degree(i).cmp(&max_field_degree(j));
597 if field_cmp == std::cmp::Ordering::Equal {
598 interactions[i]
599 .count
600 .degree_multiple()
601 .cmp(&interactions[j].count.degree_multiple())
602 } else {
603 field_cmp
604 }
605 });
606 let mut running_sum_field_degree = 0;
608 let mut numerator_max_degree = 0;
609 let mut interaction_partitions = vec![];
610 let mut cur_chunk = vec![];
611 for interaction_idx in interaction_idxs {
612 let field_degree = max_field_degree(interaction_idx);
613 let count_degree = interactions[interaction_idx].count.degree_multiple();
614 let new_num_max_degree = max(
616 numerator_max_degree + field_degree,
617 count_degree + running_sum_field_degree,
618 );
619 let new_denom_degree = running_sum_field_degree + field_degree;
620 if max(new_num_max_degree, new_denom_degree + 1) <= max_constraint_degree {
621 cur_chunk.push(interaction_idx);
623 numerator_max_degree = new_num_max_degree;
624 running_sum_field_degree += field_degree;
625 } else {
626 if !cur_chunk.is_empty() {
628 interaction_partitions.push(mem::take(&mut cur_chunk));
630 }
631 cur_chunk.push(interaction_idx);
632 numerator_max_degree = count_degree;
633 running_sum_field_degree = field_degree;
634 if max_constraint_degree > 0
635 && max(count_degree, field_degree + 1) > max_constraint_degree
636 {
637 panic!("Interaction with field_degree={field_degree}, count_degree={count_degree} exceeds max_constraint_degree={max_constraint_degree}");
638 }
639 }
640 }
641 assert!(!cur_chunk.is_empty());
643 interaction_partitions.push(cur_chunk);
644
645 FriLogUpProvingKey {
646 interaction_partitions,
647 }
648}