openvm_stark_backend/prover/cpu/quotient/single.rs

1use std::cmp::min;
2
3use itertools::Itertools;
4use p3_commit::PolynomialSpace;
5use p3_field::{FieldAlgebra, FieldExtensionAlgebra, PackedValue};
6use p3_matrix::Matrix;
7use p3_maybe_rayon::prelude::*;
8use p3_util::log2_strict_usize;
9use tracing::instrument;
10
11use super::evaluator::{ProverConstraintEvaluator, ViewPair};
12use crate::{
13    air_builders::symbolic::{
14        symbolic_variable::Entry, SymbolicExpressionDag, SymbolicExpressionNode,
15    },
16    config::{Domain, PackedChallenge, PackedVal, StarkGenericConfig, Val},
17};
18
19// Starting reference: p3_uni_stark::prover::quotient_values
20// (many changes have been made since then)
21/// Computes evaluation of DEEP quotient polynomial on the quotient domain for a single RAP (single trace matrix).
22///
23/// Designed to be general enough to support RAP with multiple rounds of challenges.
24#[allow(clippy::too_many_arguments)]
25#[instrument(
26    name = "compute single RAP quotient polynomial",
27    level = "trace",
28    skip_all
29)]
30pub fn compute_single_rap_quotient_values<'a, SC, M>(
31    constraints: &SymbolicExpressionDag<Val<SC>>,
32    trace_domain: Domain<SC>,
33    quotient_domain: Domain<SC>,
34    preprocessed_trace_on_quotient_domain: Option<M>,
35    partitioned_main_lde_on_quotient_domain: Vec<M>,
36    after_challenge_lde_on_quotient_domain: Vec<M>,
37    // For each challenge round, the challenges drawn
38    challenges: &'a [Vec<PackedChallenge<SC>>],
39    alpha: SC::Challenge,
40    public_values: &'a [Val<SC>],
41    // Values exposed to verifier after challenge round i
42    exposed_values_after_challenge: &'a [Vec<PackedChallenge<SC>>],
43) -> Vec<SC::Challenge>
44where
45    SC: StarkGenericConfig,
46    M: Matrix<Val<SC>>,
47{
48    let quotient_size = quotient_domain.size();
49    assert!(partitioned_main_lde_on_quotient_domain
50        .iter()
51        .all(|m| m.height() >= quotient_size));
52    assert!(after_challenge_lde_on_quotient_domain
53        .iter()
54        .all(|m| m.height() >= quotient_size));
55    let preprocessed_width = preprocessed_trace_on_quotient_domain
56        .as_ref()
57        .map(|m| m.width())
58        .unwrap_or(0);
59    let mut sels = trace_domain.selectors_on_coset(quotient_domain);
60
61    let qdb = log2_strict_usize(quotient_size) - log2_strict_usize(trace_domain.size());
62    let next_step = 1 << qdb;
63
64    let ext_degree = SC::Challenge::D;
65
66    let mut alpha_powers = alpha
67        .powers()
68        .take(constraints.constraint_idx.len())
69        .map(PackedChallenge::<SC>::from_f)
70        .collect_vec();
71    // We want alpha powers to have highest power first, because of how accumulator "folding" works
72    // So this will be alpha^{num_constraints - 1}, ..., alpha^0
73    alpha_powers.reverse();
74
75    // assert!(quotient_size >= PackedVal::<SC>::WIDTH);
76    // We take PackedVal::<SC>::WIDTH worth of values at a time from a quotient_size slice, so we need to
77    // pad with default values in the case where quotient_size is smaller than PackedVal::<SC>::WIDTH.
78    for _ in quotient_size..PackedVal::<SC>::WIDTH {
79        sels.is_first_row.push(Val::<SC>::default());
80        sels.is_last_row.push(Val::<SC>::default());
81        sels.is_transition.push(Val::<SC>::default());
82        sels.inv_zeroifier.push(Val::<SC>::default());
83    }
84
85    // Scan constraints to see if we need `next` row and also check index bounds
86    // so we don't need to check them per row.
87    let mut rotation = 0;
88    for node in &constraints.nodes {
89        if let SymbolicExpressionNode::Variable(var) = node {
90            match var.entry {
91                Entry::Preprocessed { offset } => {
92                    rotation = rotation.max(offset);
93                    assert!(var.index < preprocessed_width);
94                    assert!(
95                        preprocessed_trace_on_quotient_domain
96                            .as_ref()
97                            .unwrap()
98                            .height()
99                            >= quotient_size
100                    );
101                }
102                Entry::Main { part_index, offset } => {
103                    rotation = rotation.max(offset);
104                    assert!(
105                        var.index < partitioned_main_lde_on_quotient_domain[part_index].width()
106                    );
107                }
108                Entry::Public => {
109                    assert!(var.index < public_values.len());
110                }
111                Entry::Permutation { offset } => {
112                    rotation = rotation.max(offset);
113                    let ext_width = after_challenge_lde_on_quotient_domain
114                        .first()
115                        .expect("Challenge phase not supported")
116                        .width()
117                        / ext_degree;
118                    assert!(var.index < ext_width);
119                }
120                Entry::Challenge => {
121                    assert!(
122                        var.index
123                            < challenges
124                                .first()
125                                .expect("Challenge phase not supported")
126                                .len()
127                    );
128                }
129                Entry::Exposed => {
130                    assert!(
131                        var.index
132                            < exposed_values_after_challenge
133                                .first()
134                                .expect("Challenge phase not supported")
135                                .len()
136                    );
137                }
138            }
139        }
140    }
141    let needs_next = rotation > 0;
142
143    (0..quotient_size)
144        .into_par_iter()
145        .step_by(PackedVal::<SC>::WIDTH)
146        .flat_map_iter(|i_start| {
147            let wrap = |i| i % quotient_size;
148            let i_range = i_start..i_start + PackedVal::<SC>::WIDTH;
149
150            let [row_idx_local, row_idx_next] = [0, next_step].map(|shift| {
151                (0..PackedVal::<SC>::WIDTH)
152                    .map(|offset| wrap(i_start + offset + shift))
153                    .collect::<Vec<_>>()
154            });
155            let row_idx_local = Some(row_idx_local);
156            let row_idx_next = needs_next.then_some(row_idx_next);
157
158            let is_first_row = *PackedVal::<SC>::from_slice(&sels.is_first_row[i_range.clone()]);
159            let is_last_row = *PackedVal::<SC>::from_slice(&sels.is_last_row[i_range.clone()]);
160            let is_transition = *PackedVal::<SC>::from_slice(&sels.is_transition[i_range.clone()]);
161            let inv_zeroifier = *PackedVal::<SC>::from_slice(&sels.inv_zeroifier[i_range.clone()]);
162
163            // Vertically pack rows of each matrix,
164            // skipping `next` if above scan showed no constraints need it:
165
166            let [preprocessed_local, preprocessed_next] =
167                [&row_idx_local, &row_idx_next].map(|wrapped_idx| {
168                    wrapped_idx.as_ref().map(|wrapped_idx| {
169                        (0..preprocessed_width)
170                            .map(|col| {
171                                PackedVal::<SC>::from_fn(|offset| {
172                                    preprocessed_trace_on_quotient_domain
173                                        .as_ref()
174                                        .unwrap()
175                                        .get(wrapped_idx[offset], col)
176                                })
177                            })
178                            .collect_vec()
179                    })
180                });
181            let preprocessed_pair = ViewPair::new(preprocessed_local.unwrap(), preprocessed_next);
182
183            let partitioned_main_pairs = partitioned_main_lde_on_quotient_domain
184                .iter()
185                .map(|lde| {
186                    let width = lde.width();
187                    let [local, next] = [&row_idx_local, &row_idx_next].map(|wrapped_idx| {
188                        wrapped_idx.as_ref().map(|wrapped_idx| {
189                            (0..width)
190                                .map(|col| {
191                                    PackedVal::<SC>::from_fn(|offset| {
192                                        lde.get(wrapped_idx[offset], col)
193                                    })
194                                })
195                                .collect_vec()
196                        })
197                    });
198                    ViewPair::new(local.unwrap(), next)
199                })
200                .collect_vec();
201
202            let after_challenge_pairs = after_challenge_lde_on_quotient_domain
203                .iter()
204                .map(|lde| {
205                    // Width in base field with extension field elements flattened
206                    let base_width = lde.width();
207                    let [local, next] = [&row_idx_local, &row_idx_next].map(|wrapped_idx| {
208                        wrapped_idx.as_ref().map(|wrapped_idx| {
209                            (0..base_width)
210                                .step_by(ext_degree)
211                                .map(|col| {
212                                    PackedChallenge::<SC>::from_base_fn(|i| {
213                                        PackedVal::<SC>::from_fn(|offset| {
214                                            lde.get(wrapped_idx[offset], col + i)
215                                        })
216                                    })
217                                })
218                                .collect_vec()
219                        })
220                    });
221                    ViewPair::new(local.unwrap(), next)
222                })
223                .collect_vec();
224
225            let evaluator: ProverConstraintEvaluator<SC> = ProverConstraintEvaluator {
226                preprocessed: preprocessed_pair,
227                partitioned_main: partitioned_main_pairs,
228                after_challenge: after_challenge_pairs,
229                challenges,
230                is_first_row,
231                is_last_row,
232                is_transition,
233                public_values,
234                exposed_values_after_challenge,
235            };
236            let accumulator = evaluator.accumulate(constraints, &alpha_powers);
237            // quotient(x) = constraints(x) / Z_H(x)
238            let quotient: PackedChallenge<SC> = accumulator * inv_zeroifier;
239
240            // "Transpose" D packed base coefficients into WIDTH scalar extension coefficients.
241            let width = min(PackedVal::<SC>::WIDTH, quotient_size);
242            (0..width).map(move |idx_in_packing| {
243                let quotient_value = (0..<SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D)
244                    .map(|coeff_idx| quotient.as_base_slice()[coeff_idx].as_slice()[idx_in_packing])
245                    .collect::<Vec<_>>();
246                SC::Challenge::from_base_slice(&quotient_value)
247            })
248        })
249        .collect()
250}