use std::{array, borrow::Borrow, cmp::max, iter::zip, marker::PhantomData, mem};
use itertools::Itertools;
use p3_air::ExtensionBuilder;
use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
use p3_field::{ExtensionField, Field, FieldAlgebra};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_maybe_rayon::prelude::*;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::{LogUpSecurityParameters, PairTraceView, SymbolicInteraction};
use crate::{
air_builders::symbolic::{symbolic_expression::SymbolicEvaluator, SymbolicConstraints},
interaction::{
trace::Evaluator, utils::generate_betas, InteractionBuilder, RapPhaseProverData,
RapPhaseSeq, RapPhaseSeqKind, RapPhaseVerifierData,
},
parizip,
rap::PermutationAirBuilderWithExposedValues,
utils::metrics_span,
};
/// The FRI-based LogUp challenge phase.
///
/// Only the LogUp security parameters are stored; the field, challenge and challenger types are
/// carried purely at the type level through [`PhantomData`].
pub struct FriLogUpPhase<F, Challenge, Challenger> {
    log_up_params: LogUpSecurityParameters,
    _marker: PhantomData<(F, Challenge, Challenger)>,
}

impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger> {
    /// Creates the phase from its LogUp security parameters (which include the proof-of-work
    /// bit count used for grinding before the LogUp challenges are sampled).
    pub fn new(log_up_params: LogUpSecurityParameters) -> Self {
        let _marker = PhantomData;
        Self {
            _marker,
            log_up_params,
        }
    }
}
/// Reasons verification of the FRI LogUp phase can fail.
#[derive(Error, Debug)]
pub enum FriLogUpError {
    /// The cumulative sums exposed by all AIRs do not add up to zero.
    #[error("non-zero cumulative sum")]
    NonZeroCumulativeSum,
    /// Some AIR exposed after-challenge values, but no LogUp partial proof was supplied.
    #[error("missing proof")]
    MissingPartialProof,
    /// The proof-of-work witness in the partial proof was rejected by the challenger.
    #[error("invalid proof of work witness")]
    InvalidPowWitness,
}
/// The part of the proof produced by the FRI LogUp phase.
#[derive(Clone, Serialize, Deserialize)]
pub struct FriLogUpPartialProof<Witness> {
    /// Proof-of-work witness obtained by grinding before the LogUp challenges are sampled; the
    /// verifier checks it against the configured `log_up_pow_bits`.
    pub logup_pow_witness: Witness,
}
/// Per-AIR proving-key data for the FRI LogUp phase.
///
/// `interaction_partitions[c]` lists the indices (into the AIR's interaction list) of the
/// interactions grouped into chunk `c`. Each chunk occupies one column of the after-challenge
/// trace, which additionally has a final running-sum column.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct FriLogUpProvingKey {
    interaction_partitions: Vec<Vec<usize>>,
}

impl FriLogUpProvingKey {
    /// Consumes the key and hands back the interaction partitions.
    pub fn interaction_partitions(self) -> Vec<Vec<usize>> {
        let Self {
            interaction_partitions,
        } = self;
        interaction_partitions
    }

    /// Number of interaction chunks (after-challenge columns, not counting the running sum).
    pub fn num_chunks(&self) -> usize {
        let Self {
            interaction_partitions,
        } = self;
        interaction_partitions.len()
    }
}
impl<F, Challenge, Challenger> RapPhaseSeq<F, Challenge, Challenger>
    for FriLogUpPhase<F, Challenge, Challenger>
where
    F: Field,
    Challenge: ExtensionField<F>,
    Challenger: FieldChallenger<F> + GrindingChallenger<Witness = F>,
{
    type PartialProof = FriLogUpPartialProof<F>;
    type PartialProvingKey = FriLogUpProvingKey;
    type Error = FriLogUpError;

    const ID: RapPhaseSeqKind = RapPhaseSeqKind::FriLogUp;

    fn log_up_security_params(&self) -> &LogUpSecurityParameters {
        &self.log_up_params
    }

    /// Computes, for every AIR, how its interactions are grouped into chunks so that each chunk's
    /// LogUp constraint stays within `max_constraint_degree`.
    fn generate_pk_per_air(
        &self,
        symbolic_constraints_per_air: &[SymbolicConstraints<F>],
        max_constraint_degree: usize,
    ) -> Vec<Self::PartialProvingKey> {
        let mut pk_per_air = Vec::with_capacity(symbolic_constraints_per_air.len());
        for constraints in symbolic_constraints_per_air {
            pk_per_air.push(find_interaction_chunks(
                &constraints.interactions,
                max_constraint_degree,
            ));
        }
        pk_per_air
    }

    /// Prover side of the phase.
    ///
    /// Transcript order: grind for the proof-of-work witness, sample the LogUp challenges, build
    /// the after-challenge traces, then observe every AIR's cumulative sum in AIR order. Returns
    /// `None` (touching the challenger not at all) when no AIR has interactions.
    fn partially_prove(
        &self,
        challenger: &mut Challenger,
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: &[PairTraceView<F>],
    ) -> Option<(Self::PartialProof, RapPhaseProverData<Challenge>)> {
        if constraints_per_air
            .iter()
            .all(|constraints| constraints.interactions.is_empty())
        {
            return None;
        }

        let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits);
        let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
            array::from_fn(|_| challenger.sample_ext_element::<Challenge>());

        let after_challenge_trace_per_air = metrics_span("generate_perm_trace_time_ms", || {
            Self::generate_after_challenge_traces_per_air(
                &challenges,
                constraints_per_air,
                params_per_air,
                trace_view_per_air,
            )
        });

        // Each present cumulative sum is absorbed into the transcript and becomes that AIR's only
        // exposed value; AIRs without interactions expose nothing.
        let cumulative_sum_per_air = Self::extract_cumulative_sums(&after_challenge_trace_per_air);
        let mut exposed_values_per_air = Vec::with_capacity(cumulative_sum_per_air.len());
        for maybe_sum in cumulative_sum_per_air {
            if let Some(sum) = maybe_sum {
                challenger.observe_slice(sum.as_base_slice());
            }
            exposed_values_per_air.push(maybe_sum.map(|sum| vec![sum]));
        }

        let partial_proof = FriLogUpPartialProof { logup_pow_witness };
        let prover_data = RapPhaseProverData {
            challenges: challenges.to_vec(),
            after_challenge_trace_per_air,
            exposed_values_per_air,
        };
        Some((partial_proof, prover_data))
    }

    /// Verifier side of the phase.
    ///
    /// Checks the proof-of-work witness, re-samples the challenges, absorbs every AIR's
    /// first-phase exposed values followed by `commitment_per_phase[0]`, and finally checks that
    /// the exposed cumulative sums add up to zero. The returned verifier data is the default
    /// (empty) value on every early-exit path.
    ///
    /// # Panics
    /// If an AIR exposes values for more than one phase, if a phase does not expose exactly one
    /// value, or if `commitment_per_phase` is empty while some AIR exposes values.
    fn partially_verify<Commitment: Clone>(
        &self,
        challenger: &mut Challenger,
        partial_proof: Option<&Self::PartialProof>,
        exposed_values_per_phase_per_air: &[Vec<Vec<Challenge>>],
        commitment_per_phase: &[Commitment],
        _permutation_opened_values: &[Vec<Vec<Vec<Challenge>>>],
    ) -> (RapPhaseVerifierData<Challenge>, Result<(), Self::Error>)
    where
        Challenger: CanObserve<Commitment>,
    {
        if exposed_values_per_phase_per_air
            .iter()
            .all(|per_phase| per_phase.is_empty())
        {
            return (RapPhaseVerifierData::default(), Ok(()));
        }

        let Some(partial_proof) = partial_proof else {
            return (
                RapPhaseVerifierData::default(),
                Err(FriLogUpError::MissingPartialProof),
            );
        };

        let pow_is_valid = challenger.check_witness(
            self.log_up_params.log_up_pow_bits,
            partial_proof.logup_pow_witness,
        );
        if !pow_is_valid {
            return (
                RapPhaseVerifierData::default(),
                Err(FriLogUpError::InvalidPowWitness),
            );
        }

        let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] =
            array::from_fn(|_| challenger.sample_ext_element::<Challenge>());

        // Mirror the prover's transcript: first-phase exposed values in AIR order, then the
        // after-challenge commitment.
        exposed_values_per_phase_per_air
            .iter()
            .filter_map(|per_phase| per_phase.first())
            .flatten()
            .for_each(|exposed_value| challenger.observe_slice(exposed_value.as_base_slice()));
        challenger.observe(commitment_per_phase[0].clone());

        let mut total = Challenge::ZERO;
        for exposed_values_per_phase in exposed_values_per_phase_per_air {
            assert!(
                exposed_values_per_phase.len() <= 1,
                "Verifier does not support more than 1 challenge phase"
            );
            if let Some(exposed_values) = exposed_values_per_phase.first() {
                assert_eq!(
                    exposed_values.len(),
                    1,
                    "Only exposed value should be cumulative sum"
                );
                total += exposed_values[0];
            }
        }

        let result = if total != Challenge::ZERO {
            Err(Self::Error::NonZeroCumulativeSum)
        } else {
            Ok(())
        };
        let verifier_data = RapPhaseVerifierData {
            challenges_per_phase: vec![challenges.to_vec()],
        };
        (verifier_data, result)
    }
}
/// Number of LogUp challenges sampled per proof: `alpha` (denominator offset) and `beta`
/// (base for the message-combining powers).
pub const STARK_LU_NUM_CHALLENGES: usize = 2;
/// Number of after-challenge values each AIR with interactions exposes: its cumulative sum.
pub const STARK_LU_NUM_EXPOSED_VALUES: usize = 1;
impl<F, Challenge, Challenger> FriLogUpPhase<F, Challenge, Challenger>
where
    F: Field,
    Challenge: ExtensionField<F>,
    Challenger: FieldChallenger<F>,
{
    /// Builds the after-challenge (LogUp) trace of every AIR, in parallel across AIRs.
    ///
    /// `constraints_per_air`, `trace_view_per_air` and `params_per_air` are zipped element-wise,
    /// so they are expected to be indexed by the same AIR order. An entry is `None` for an AIR
    /// without interactions (see [`Self::generate_after_challenge_trace`]).
    fn generate_after_challenge_traces_per_air(
        challenges: &[Challenge; STARK_LU_NUM_CHALLENGES],
        constraints_per_air: &[&SymbolicConstraints<F>],
        params_per_air: &[&FriLogUpProvingKey],
        trace_view_per_air: &[PairTraceView<F>],
    ) -> Vec<Option<RowMajorMatrix<Challenge>>> {
        parizip!(constraints_per_air, trace_view_per_air, params_per_air)
            .map(|(constraints, trace_view, params)| {
                Self::generate_after_challenge_trace(
                    &constraints.interactions,
                    trace_view,
                    challenges,
                    &params.interaction_partitions,
                )
            })
            .collect::<Vec<_>>()
    }

    /// Reads each AIR's cumulative LogUp sum from its after-challenge trace.
    ///
    /// The sum is the last entry of the last row, i.e. the final value of the running-sum
    /// column `phi`. AIRs without a trace (no interactions) map to `None`.
    ///
    /// # Panics
    /// If a present trace has zero rows or zero columns; neither can be produced by
    /// [`Self::generate_after_challenge_trace`] for a non-empty main trace.
    fn extract_cumulative_sums(
        perm_traces: &[Option<RowMajorMatrix<Challenge>>],
    ) -> Vec<Option<Challenge>> {
        perm_traces
            .iter()
            .map(|perm_trace| {
                perm_trace.as_ref().map(|perm_trace| {
                    let last_row = perm_trace
                        .height()
                        .checked_sub(1)
                        .expect("after-challenge trace must have at least one row");
                    *perm_trace
                        .row_slice(last_row)
                        .last()
                        .expect("after-challenge trace must have at least one column")
                })
            })
            .collect()
    }

    /// Builds the after-challenge LogUp trace for a single AIR.
    ///
    /// Returns `None` when the AIR has no interactions. Otherwise the matrix has the main trace's
    /// height and `interaction_partitions.len() + 1` columns:
    ///
    /// `| perm_1 | perm_2 | ... | perm_s | phi |`
    ///
    /// For each row, `perm_c = sum_{i in chunk c} count_i / denom_i`, where
    /// `denom_i = alpha + message_i[0] + sum_{k >= 1} betas[k] * message_i[k]
    ///            + betas[message_i.len()] * (bus_index_i + 1)`,
    /// and `phi` is the prefix sum over rows of `perm_1 + ... + perm_s`.
    ///
    /// NOTE(review): the first message field is added without a `betas[0]` factor, whereas the
    /// constraint side (`eval_fri_log_up_phase`) multiplies it by `betas[0]`; the two agree only if
    /// `generate_betas` yields `betas[0] == 1` — confirm against its definition.
    ///
    /// # Panics
    /// If an interaction has an empty message. NOTE(review): a height-0 main trace would produce a
    /// chunk size of 0, which the parallel chunking appears to reject — assumes non-empty traces.
    pub fn generate_after_challenge_trace(
        all_interactions: &[SymbolicInteraction<F>],
        trace_view: &PairTraceView<F>,
        permutation_randomness: &[Challenge; STARK_LU_NUM_CHALLENGES],
        interaction_partitions: &[Vec<usize>],
    ) -> Option<RowMajorMatrix<Challenge>>
    where
        F: Field,
        Challenge: ExtensionField<F>,
    {
        if all_interactions.is_empty() {
            return None;
        }
        let &[alpha, beta] = permutation_randomness;
        let betas = generate_betas(beta, all_interactions);
        let num_interactions = all_interactions.len();
        let height = trace_view.partitioned_main[0].height();
        // One column per interaction chunk, plus the trailing running-sum column `phi`.
        let perm_width = interaction_partitions.len() + 1;
        let mut perm_values = Challenge::zero_vec(height * perm_width);
        debug_assert!(
            trace_view
                .partitioned_main
                .iter()
                .all(|m| m.height() == height),
            "All main trace parts must have same height"
        );

        #[cfg(feature = "parallel")]
        let num_threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let num_threads = 1;

        let preprocessed = trace_view.preprocessed.as_ref().map(|m| m.as_view());
        let partitioned_main = trace_view
            .partitioned_main
            .iter()
            .map(|m| m.as_view())
            .collect_vec();
        // Evaluates symbolic expressions against the row at `local_index`.
        let evaluator = |local_index: usize| Evaluator {
            preprocessed: &preprocessed,
            partitioned_main: &partitioned_main,
            public_values: &trace_view.public_values,
            height,
            local_index,
        };

        // Rows are split into `num_threads` contiguous ranges; each range computes its own
        // denominators, batch-inverts them, and fills its rows of the trace.
        let height_per_thread = height.div_ceil(num_threads);
        perm_values
            .par_chunks_mut(height_per_thread * perm_width)
            .enumerate()
            .for_each(|(thread_idx, perm_values)| {
                // `perm_values` is a row-major `num_rows x perm_width` slice of the full trace.
                let num_rows = perm_values.len() / perm_width;
                let mut denoms = Challenge::zero_vec(num_rows * num_interactions);
                let row_offset = thread_idx * height_per_thread;
                for (n, denom_row) in denoms.chunks_exact_mut(num_interactions).enumerate() {
                    let evaluator = evaluator(row_offset + n);
                    for (denom, interaction) in denom_row.iter_mut().zip(all_interactions.iter()) {
                        debug_assert!(interaction.message.len() <= betas.len());
                        // The bus index is shifted by one and appended as a final message field.
                        let b = F::from_canonical_u32(interaction.bus_index as u32 + 1);
                        let mut fields = interaction.message.iter();
                        *denom = alpha
                            + evaluator
                                .eval_expr(fields.next().expect("fields should not be empty"));
                        for (expr, &beta) in fields.zip(betas.iter().skip(1)) {
                            *denom += beta * evaluator.eval_expr(expr);
                        }
                        *denom += betas[interaction.message.len()] * b;
                    }
                }
                // A zero denominator should be vanishingly unlikely for properly sampled
                // alpha/beta.
                let reciprocals = p3_field::batch_multiplicative_inverse(&denoms);
                drop(denoms);
                perm_values
                    .par_chunks_exact_mut(perm_width)
                    .zip(reciprocals.par_chunks_exact(num_interactions))
                    .enumerate()
                    .for_each(|(n, (perm_row, reciprocals))| {
                        debug_assert_eq!(perm_row.len(), perm_width);
                        debug_assert_eq!(reciprocals.len(), num_interactions);
                        let evaluator = evaluator(row_offset + n);
                        let mut row_sum = Challenge::ZERO;
                        for (part, perm_val) in zip(interaction_partitions, perm_row.iter_mut()) {
                            for &interaction_idx in part {
                                let interaction = &all_interactions[interaction_idx];
                                let interaction_val = reciprocals[interaction_idx]
                                    * evaluator.eval_expr(&interaction.count);
                                *perm_val += interaction_val;
                            }
                            row_sum += *perm_val;
                        }
                        // Temporarily holds this row's total; turned into a prefix sum below.
                        perm_row[perm_width - 1] = row_sum;
                    });
            });

        // The last column currently holds per-row sums; convert it into the running sum `phi`.
        tracing::trace_span!("compute logup partial sums").in_scope(|| {
            let mut phi = Challenge::ZERO;
            for perm_chunk in perm_values.chunks_exact_mut(perm_width) {
                phi += *perm_chunk.last().unwrap();
                *perm_chunk.last_mut().unwrap() = phi;
            }
        });
        Some(RowMajorMatrix::new(perm_values, perm_width))
    }
}
/// Adds the FRI LogUp constraints for one AIR to `builder`.
///
/// The after-challenge ("permutation") trace has one column per interaction chunk followed by a
/// running-sum column `phi`. The chunking is recomputed here via [`find_interaction_chunks`] from
/// `symbolic_interactions` and `max_constraint_degree`. NOTE(review): this presumably has to match
/// the partitions the prover used, and assumes `builder.all_interactions()` lists interactions in
/// the same order as `symbolic_interactions` — confirm at the call sites.
///
/// Constraints added, in this order:
/// - for each chunk `c`, on every row:
///   `perm[c] * prod_i denom_i == sum_i count_i * prod_{j != i} denom_j`, where
///   `denom_i = alpha + sum_k betas[k] * message_i[k] + betas[len] * (bus_index_i + 1)`;
/// - on transition rows: `phi_next - phi_local == sum_c perm_next[c]`;
/// - on the first row: `phi == sum_c perm_local[c]`;
/// - on the last row: `phi == cumulative_sum`, the single exposed value.
///
/// # Panics
/// If the builder does not expose exactly one permutation value, if any interaction has an empty
/// message, or if `find_interaction_chunks` rejects an interaction's degree.
pub fn eval_fri_log_up_phase<AB>(
    builder: &mut AB,
    symbolic_interactions: &[SymbolicInteraction<AB::F>],
    max_constraint_degree: usize,
) where
    AB: InteractionBuilder + PermutationAirBuilderWithExposedValues,
{
    let exposed_values = builder.permutation_exposed_values();
    assert_eq!(
        exposed_values.len(),
        1,
        "Should have one exposed value for cumulative_sum"
    );
    let cumulative_sum = exposed_values[0];
    // rand_elems[0] is alpha; rand_elems[1] seeds the beta powers.
    let rand_elems = builder.permutation_randomness();
    let perm = builder.permutation();
    let (perm_local, perm_next) = (perm.row_slice(0), perm.row_slice(1));
    let perm_local: &[AB::VarEF] = (*perm_local).borrow();
    let perm_next: &[AB::VarEF] = (*perm_next).borrow();
    let all_interactions = builder.all_interactions().to_vec();
    let FriLogUpProvingKey {
        interaction_partitions,
    } = find_interaction_chunks(symbolic_interactions, max_constraint_degree);
    let num_chunks = interaction_partitions.len();
    // One column per chunk plus the running-sum column.
    debug_assert_eq!(num_chunks + 1, perm_local.len());
    let phi_local = *perm_local.last().unwrap();
    let phi_next = *perm_next.last().unwrap();
    let alpha = rand_elems[0];
    let betas = generate_betas(rand_elems[1].into(), &all_interactions);
    let phi_lhs = phi_next.into() - phi_local.into();
    let mut phi_rhs = AB::ExprEF::ZERO;
    let mut phi_0 = AB::ExprEF::ZERO;
    for (chunk_idx, part) in interaction_partitions.iter().enumerate() {
        // Symbolic denominators alpha + <betas, message || bus_index + 1> for this chunk.
        let denoms_per_chunk = part
            .iter()
            .map(|&interaction_idx| {
                let interaction = &all_interactions[interaction_idx];
                assert!(
                    !interaction.message.is_empty(),
                    "fields should not be empty"
                );
                let mut field_hash = AB::ExprEF::ZERO;
                let b = AB::Expr::from_canonical_u32(interaction.bus_index as u32 + 1);
                for (field, beta) in interaction.message.iter().chain([&b]).zip(&betas) {
                    field_hash += beta.clone() * field.clone();
                }
                field_hash + alpha.into()
            })
            .collect_vec();
        // Fraction-free form of perm[c] == sum_i count_i / denom_i: multiply through by every
        // denominator in the chunk.
        let mut row_lhs: AB::ExprEF = perm_local[chunk_idx].into();
        for denom in denoms_per_chunk.iter() {
            row_lhs *= denom.clone();
        }
        let mut row_rhs = AB::ExprEF::ZERO;
        for (i, &interaction_idx) in part.iter().enumerate() {
            let interaction = &all_interactions[interaction_idx];
            let mut term: AB::ExprEF = interaction.count.clone().into();
            for (j, denom) in denoms_per_chunk.iter().enumerate() {
                if i != j {
                    term *= denom.clone();
                }
            }
            row_rhs += term;
        }
        builder.assert_eq_ext(row_lhs, row_rhs);
        phi_0 += perm_local[chunk_idx].into();
        phi_rhs += perm_next[chunk_idx].into();
    }
    // Running-sum constraints. NOTE(review): keep the emission order unchanged; it defines the
    // order of this AIR's constraints.
    builder.when_transition().assert_eq_ext(phi_lhs, phi_rhs);
    builder
        .when_first_row()
        .assert_eq_ext(*perm_local.last().unwrap(), phi_0);
    builder
        .when_last_row()
        .assert_eq_ext(*perm_local.last().unwrap(), cumulative_sum);
}
/// Greedily groups interactions into chunks whose LogUp constraint degree fits within
/// `max_constraint_degree`.
///
/// Interactions are first ordered (stably) by ascending maximum message-field degree, then by
/// ascending count degree. They are then appended to the current chunk while the chunk's
/// constraint `perm * prod(denoms) == sum_i count_i * prod_{j != i} denom_j` stays within the
/// bound. The left side has degree `1 + sum(field degrees)`; the right side is tracked
/// incrementally.
///
/// Returns an empty key when `interactions` is empty.
///
/// # Panics
/// If `max_constraint_degree > 0` and a single interaction alone exceeds it.
pub(crate) fn find_interaction_chunks<F: Field>(
    interactions: &[SymbolicInteraction<F>],
    max_constraint_degree: usize,
) -> FriLogUpProvingKey {
    if interactions.is_empty() {
        return FriLogUpProvingKey::default();
    }

    // Highest degree among the interaction's message fields (0 for an empty message).
    let field_degree_of = |idx: usize| {
        interactions[idx]
            .message
            .iter()
            .map(|field| field.degree_multiple())
            .max()
            .unwrap_or(0)
    };
    let count_degree_of = |idx: usize| interactions[idx].count.degree_multiple();

    // Stable sort on the (field degree, count degree) pair, matching a lexicographic comparator.
    let mut order: Vec<usize> = (0..interactions.len()).collect();
    order.sort_by_cached_key(|&idx| (field_degree_of(idx), count_degree_of(idx)));

    let mut partitions: Vec<Vec<usize>> = Vec::new();
    let mut cur_chunk: Vec<usize> = Vec::new();
    // Degree of the product of the current chunk's denominators.
    let mut denom_degree = 0;
    // Degree of sum_i count_i * prod_{j != i} denom_j over the current chunk.
    let mut numer_degree = 0;

    for idx in order {
        let field_degree = field_degree_of(idx);
        let count_degree = count_degree_of(idx);
        let candidate_numer = max(numer_degree + field_degree, count_degree + denom_degree);
        let candidate_denom = denom_degree + field_degree;

        if max(candidate_numer, candidate_denom + 1) <= max_constraint_degree {
            cur_chunk.push(idx);
            numer_degree = candidate_numer;
            denom_degree = candidate_denom;
            continue;
        }

        // Does not fit: close the current chunk and start a new one with this interaction.
        if !cur_chunk.is_empty() {
            partitions.push(mem::take(&mut cur_chunk));
        }
        cur_chunk.push(idx);
        numer_degree = count_degree;
        denom_degree = field_degree;
        if max_constraint_degree > 0
            && max(count_degree, field_degree + 1) > max_constraint_degree
        {
            panic!("Interaction with field_degree={field_degree}, count_degree={count_degree} exceeds max_constraint_degree={max_constraint_degree}");
        }
    }

    // The final chunk is never closed inside the loop.
    assert!(!cur_chunk.is_empty());
    partitions.push(cur_chunk);
    FriLogUpProvingKey {
        interaction_partitions: partitions,
    }
}