p3_uni_stark/
prover.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use itertools::{izip, Itertools};
5use p3_air::Air;
6use p3_challenger::{CanObserve, CanSample, FieldChallenger};
7use p3_commit::{Pcs, PolynomialSpace};
8use p3_field::{FieldAlgebra, FieldExtensionAlgebra, PackedValue};
9use p3_matrix::dense::RowMajorMatrix;
10use p3_matrix::Matrix;
11use p3_maybe_rayon::prelude::*;
12use p3_util::{log2_ceil_usize, log2_strict_usize};
13use tracing::{info_span, instrument};
14
15use crate::{
16    get_symbolic_constraints, Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, Proof,
17    ProverConstraintFolder, StarkGenericConfig, SymbolicAirBuilder, SymbolicExpression, Val,
18};
19
20#[instrument(skip_all)]
21#[allow(clippy::multiple_bound_locations)] // cfg not supported in where clauses?
22pub fn prove<
23    SC,
24    #[cfg(debug_assertions)] A: for<'a> Air<crate::check_constraints::DebugConstraintBuilder<'a, Val<SC>>>,
25    #[cfg(not(debug_assertions))] A,
26>(
27    config: &SC,
28    air: &A,
29    challenger: &mut SC::Challenger,
30    trace: RowMajorMatrix<Val<SC>>,
31    public_values: &Vec<Val<SC>>,
32) -> Proof<SC>
33where
34    SC: StarkGenericConfig,
35    A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<ProverConstraintFolder<'a, SC>>,
36{
37    #[cfg(debug_assertions)]
38    crate::check_constraints::check_constraints(air, &trace, public_values);
39
40    let degree = trace.height();
41    let log_degree = log2_strict_usize(degree);
42
43    let symbolic_constraints = get_symbolic_constraints::<Val<SC>, A>(air, 0, public_values.len());
44    let constraint_count = symbolic_constraints.len();
45    let constraint_degree = symbolic_constraints
46        .iter()
47        .map(SymbolicExpression::degree_multiple)
48        .max()
49        .unwrap_or(0);
50    let log_quotient_degree = log2_ceil_usize(constraint_degree - 1);
51    let quotient_degree = 1 << log_quotient_degree;
52
53    let pcs = config.pcs();
54    let trace_domain = pcs.natural_domain_for_degree(degree);
55
56    let (trace_commit, trace_data) =
57        info_span!("commit to trace data").in_scope(|| pcs.commit(vec![(trace_domain, trace)]));
58
59    // Observe the instance.
60    challenger.observe(Val::<SC>::from_canonical_usize(log_degree));
61    // TODO: Might be best practice to include other instance data here; see verifier comment.
62
63    challenger.observe(trace_commit.clone());
64    challenger.observe_slice(public_values);
65    let alpha: SC::Challenge = challenger.sample_ext_element();
66
67    let quotient_domain =
68        trace_domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree));
69
70    let trace_on_quotient_domain = pcs.get_evaluations_on_domain(&trace_data, 0, quotient_domain);
71
72    let quotient_values = quotient_values(
73        air,
74        public_values,
75        trace_domain,
76        quotient_domain,
77        trace_on_quotient_domain,
78        alpha,
79        constraint_count,
80    );
81    let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base();
82    let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat);
83    let qc_domains = quotient_domain.split_domains(quotient_degree);
84
85    let (quotient_commit, quotient_data) = info_span!("commit to quotient poly chunks")
86        .in_scope(|| pcs.commit(izip!(qc_domains, quotient_chunks).collect_vec()));
87    challenger.observe(quotient_commit.clone());
88
89    let commitments = Commitments {
90        trace: trace_commit,
91        quotient_chunks: quotient_commit,
92    };
93
94    let zeta: SC::Challenge = challenger.sample();
95    let zeta_next = trace_domain.next_point(zeta).unwrap();
96
97    let (opened_values, opening_proof) = info_span!("open").in_scope(|| {
98        pcs.open(
99            vec![
100                (&trace_data, vec![vec![zeta, zeta_next]]),
101                (
102                    &quotient_data,
103                    // open every chunk at zeta
104                    (0..quotient_degree).map(|_| vec![zeta]).collect_vec(),
105                ),
106            ],
107            challenger,
108        )
109    });
110    let trace_local = opened_values[0][0][0].clone();
111    let trace_next = opened_values[0][0][1].clone();
112    let quotient_chunks = opened_values[1].iter().map(|v| v[0].clone()).collect_vec();
113    let opened_values = OpenedValues {
114        trace_local,
115        trace_next,
116        quotient_chunks,
117    };
118    Proof {
119        commitments,
120        opened_values,
121        opening_proof,
122        degree_bits: log_degree,
123    }
124}
125
126#[instrument(name = "compute quotient polynomial", skip_all)]
127fn quotient_values<SC, A, Mat>(
128    air: &A,
129    public_values: &Vec<Val<SC>>,
130    trace_domain: Domain<SC>,
131    quotient_domain: Domain<SC>,
132    trace_on_quotient_domain: Mat,
133    alpha: SC::Challenge,
134    constraint_count: usize,
135) -> Vec<SC::Challenge>
136where
137    SC: StarkGenericConfig,
138    A: for<'a> Air<ProverConstraintFolder<'a, SC>>,
139    Mat: Matrix<Val<SC>> + Sync,
140{
141    let quotient_size = quotient_domain.size();
142    let width = trace_on_quotient_domain.width();
143    let mut sels = trace_domain.selectors_on_coset(quotient_domain);
144
145    let qdb = log2_strict_usize(quotient_domain.size()) - log2_strict_usize(trace_domain.size());
146    let next_step = 1 << qdb;
147
148    // We take PackedVal::<SC>::WIDTH worth of values at a time from a quotient_size slice, so we need to
149    // pad with default values in the case where quotient_size is smaller than PackedVal::<SC>::WIDTH.
150    for _ in quotient_size..PackedVal::<SC>::WIDTH {
151        sels.is_first_row.push(Val::<SC>::default());
152        sels.is_last_row.push(Val::<SC>::default());
153        sels.is_transition.push(Val::<SC>::default());
154        sels.inv_zeroifier.push(Val::<SC>::default());
155    }
156
157    let mut alpha_powers = alpha.powers().take(constraint_count).collect_vec();
158    alpha_powers.reverse();
159
160    (0..quotient_size)
161        .into_par_iter()
162        .step_by(PackedVal::<SC>::WIDTH)
163        .flat_map_iter(|i_start| {
164            let i_range = i_start..i_start + PackedVal::<SC>::WIDTH;
165
166            let is_first_row = *PackedVal::<SC>::from_slice(&sels.is_first_row[i_range.clone()]);
167            let is_last_row = *PackedVal::<SC>::from_slice(&sels.is_last_row[i_range.clone()]);
168            let is_transition = *PackedVal::<SC>::from_slice(&sels.is_transition[i_range.clone()]);
169            let inv_zeroifier = *PackedVal::<SC>::from_slice(&sels.inv_zeroifier[i_range.clone()]);
170
171            let main = RowMajorMatrix::new(
172                trace_on_quotient_domain.vertically_packed_row_pair(i_start, next_step),
173                width,
174            );
175
176            let accumulator = PackedChallenge::<SC>::ZERO;
177            let mut folder = ProverConstraintFolder {
178                main: main.as_view(),
179                public_values,
180                is_first_row,
181                is_last_row,
182                is_transition,
183                alpha_powers: &alpha_powers,
184                accumulator,
185                constraint_index: 0,
186            };
187            air.eval(&mut folder);
188
189            // quotient(x) = constraints(x) / Z_H(x)
190            let quotient = folder.accumulator * inv_zeroifier;
191
192            // "Transpose" D packed base coefficients into WIDTH scalar extension coefficients.
193            (0..core::cmp::min(quotient_size, PackedVal::<SC>::WIDTH)).map(move |idx_in_packing| {
194                SC::Challenge::from_base_fn(|coeff_idx| {
195                    quotient.as_base_slice()[coeff_idx].as_slice()[idx_in_packing]
196                })
197            })
198        })
199        .collect()
200}