// openvm_native_recursion/fri/two_adic_pcs.rs

1use openvm_native_compiler::prelude::*;
2use openvm_native_compiler_derive::iter_zip;
3use openvm_stark_backend::{
4    p3_commit::TwoAdicMultiplicativeCoset,
5    p3_field::{FieldAlgebra, FieldExtensionAlgebra, TwoAdicField},
6};
7use p3_symmetric::Hash;
8
9use super::{
10    types::{
11        DimensionsVariable, FriConfigVariable, TwoAdicPcsMatsVariable, TwoAdicPcsRoundVariable,
12    },
13    verify_batch, verify_query, NestedOpenedValues, TwoAdicMultiplicativeCosetVariable,
14};
15use crate::{
16    challenger::ChallengerVariable, commit::PcsVariable, digest::DigestVariable,
17    fri::types::FriProofVariable,
18};
19
/// The maximum two-adicity of Felt `C::F`: `C::F` is assumed to have no multiplicative subgroup
/// of order `2^{MAX_TWO_ADICITY + 1}`. Currently set to 27 for BabyBear.
///
/// `verify_two_adic_pcs` sizes its per-`log_height` accumulators (`ro`, `alpha_pow`) with
/// `MAX_TWO_ADICITY + 1` entries, so every LDE matrix must satisfy
/// `log2(lde_height) <= MAX_TWO_ADICITY`.
pub const MAX_TWO_ADICITY: usize = 27;
23
/// Verifies a two-adic FRI PCS opening proof inside the recursion DSL.
///
/// For every matrix in every round of `rounds`, the claimed evaluations (`mat.values`) at the
/// opening points (`mat.points`) are observed into `challenger`. They are then batched with
/// powers of a sampled challenge `alpha` into one reduced opening per LDE log-height. Finally they
/// are checked against `proof` via `verify_batch` (Merkle openings of each round commitment) and
/// `verify_query` (FRI folding, compared against the constant term of `proof.final_poly`).
///
/// Arguments:
/// - `config`: FRI parameters (`log_blowup`, `num_queries`, `proof_of_work_bits`, ...). Only
///   `config.log_final_poly_len == 0` is supported (asserted below).
/// - `rounds`: one entry per batch commitment, each with its matrices' domains, opening points
///   and claimed values, plus an optional matrix permutation (empty = identity).
/// - `proof`: the FRI proof. Its shape (number of queries, number of commit-phase commits,
///   final polynomial length) is asserted against `config` and `log_max_height`.
/// - `log_max_height`: log2 of the tallest trace height before blowup. The LDE domains used here
///   have up to `log_max_height + config.log_blowup` bits.
/// - `challenger`: Fiat-Shamir transcript. It is observed into and sampled from throughout, so
///   the order of operations in this function is significant.
///
/// Notes:
/// 1. FieldMerkleTreeMMCS sorts traces by height in descending order when committing data, so
///    traces are sorted in `opening_proof`. Reference:
///    <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/merkle_tree.rs#L53>
/// 2. FieldMerkleTreeMMCS::poseidon2 keeps the raw values in the original order, so traces are
///    not sorted in `opened_values`. References:
///    <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/mmcs.rs#L87>
///    <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/merkle_tree.rs#L100>
/// 3. **Required** that `C::F` has two-adicity <= [MAX_TWO_ADICITY]. In particular this implies
///    that all LDE matrices have `log2(lde_height) <= MAX_TWO_ADICITY`.
/// 4. **Required** that the maximum trace height is `2^log_max_height - 1`.
///    NOTE(review): the code treats `log_max_height` as log2 of the tallest trace height. It
///    requires `log_max_height` commit-phase commits and uses
///    `log_max_lde_height = log_max_height + log_blowup`. Confirm whether the `- 1` is intended.
///
/// Corresponding native verifier:
/// <https://github.com/Plonky3/Plonky3/blob/784b7dd1fa87c1202e63350cc8182d7c5327a7af/fri/src/verifier.rs#L22>
pub fn verify_two_adic_pcs<C: Config>(
    builder: &mut Builder<C>,
    config: &FriConfigVariable<C>,
    rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
    proof: FriProofVariable<C>,
    log_max_height: RVar<C::N>,
    challenger: &mut impl ChallengerVariable<C>,
) where
    C::F: TwoAdicField,
    C::EF: TwoAdicField,
{
    // Currently do not support other final poly len
    builder.assert_var_eq(RVar::from(config.log_final_poly_len), RVar::zero());
    // The `proof.final_poly` length is in general `2^{log_final_poly_len}`.
    // We require `log_final_poly_len = 0`, so `proof.final_poly` has length `1`.
    builder.assert_usize_eq(proof.final_poly.len(), RVar::one());
    // Constant term of final poly
    let final_poly_ct = builder.get(&proof.final_poly, 0);

    let g = builder.generator();

    let log_blowup = config.log_blowup;
    // Observe every claimed evaluation (each `Ext` as its `C::EF::D` base-field limbs) before
    // sampling `alpha`.
    iter_zip!(builder, rounds).for_each(|ptr_vec, builder| {
        let round = builder.iter_ptr_get(&rounds, ptr_vec[0]);
        iter_zip!(builder, round.mats).for_each(|ptr_vec, builder| {
            let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]);
            iter_zip!(builder, mat.values).for_each(|ptr_vec, builder| {
                let value = builder.iter_ptr_get(&mat.values, ptr_vec[0]);
                iter_zip!(builder, value).for_each(|ptr_vec, builder| {
                    if builder.flags.static_only {
                        let ext = builder.iter_ptr_get(&value, ptr_vec[0]);
                        let arr = builder.ext2felt(ext);
                        challenger.observe_slice(builder, arr);
                    } else {
                        // Dynamic mode: load the `D` limbs straight from the element's address
                        // instead of first materializing the `Ext`.
                        let ptr = ptr_vec[0];
                        for i in 0..C::EF::D {
                            let f: Felt<_> = builder.uninit();
                            builder.load(
                                f,
                                Ptr {
                                    address: ptr.variable(),
                                },
                                MemIndex {
                                    index: RVar::from(i),
                                    offset: 0,
                                    size: 1,
                                },
                            );
                            challenger.observe(builder, f);
                        }
                    }
                });
            });
        });
    });
    let alpha = challenger.sample_ext(builder);
    if builder.flags.static_only {
        builder.ext_reduce_circuit(alpha);
    }

    builder.cycle_tracker_start("stage-d-verifier-verify");
    // **ATTENTION**: always check shape of user inputs.
    builder.assert_usize_eq(proof.query_proofs.len(), RVar::from(config.num_queries));
    builder.assert_usize_eq(proof.commit_phase_commits.len(), log_max_height);
    let betas: Array<C, Ext<C::F, C::EF>> = builder.array(log_max_height);
    let betas_squared: Array<C, Ext<C::F, C::EF>> = builder.array(log_max_height);
    // `i_plus_one_arr[i] = i + 1`. This is needed to add "enumerate" to `iter_zip!`
    // There is no risk of overflow because `log_max_height` is much less than `C::N::MODULUS`
    let i_plus_one_arr: Array<C, Usize<C::N>> = builder.array(log_max_height);
    let i_var: Usize<C::N> = builder.eval(C::N::ZERO);
    // For each commit-phase commitment: observe it, then sample the folding challenge beta.
    iter_zip!(
        builder,
        proof.commit_phase_commits,
        betas,
        betas_squared,
        i_plus_one_arr
    )
    .for_each(|ptr_vec, builder| {
        let [comm_ptr, beta_ptr, beta_sq_ptr, i_plus_one_ptr] = ptr_vec.try_into().unwrap();

        let comm = builder.iter_ptr_get(&proof.commit_phase_commits, comm_ptr);
        challenger.observe_digest(builder, comm);
        let sample = challenger.sample_ext(builder);
        builder.iter_ptr_set(&betas, beta_ptr, sample);
        builder.iter_ptr_set(&betas_squared, beta_sq_ptr, sample * sample);
        builder.assign(&i_var, i_var.clone() + C::N::ONE);
        // Note: this is a deep clone when `builder.flags.static_only == true`
        let i_plus_one_clone: Usize<C::N> = builder.eval(i_var.clone());
        builder.iter_ptr_set(&i_plus_one_arr, i_plus_one_ptr, i_plus_one_clone);
    });

    iter_zip!(builder, proof.final_poly).for_each(|ptr_vec, builder| {
        let final_poly_elem = builder.iter_ptr_get(&proof.final_poly, ptr_vec[0]);
        let final_poly_elem_felts = builder.ext2felt(final_poly_elem);
        challenger.observe_slice(builder, final_poly_elem_felts);
    });

    challenger.check_witness(builder, config.proof_of_work_bits, proof.pow_witness);

    let log_max_lde_height = builder.eval_expr(log_max_height + RVar::from(log_blowup));
    // tag_exp is a shared buffer.
    let tag_exp: Array<C, Felt<C::F>> = builder.array(log_max_lde_height);
    // w generates the subgroup of order 2^log_max_lde_height; max_gen_pow generates the
    // subgroup of order 2.
    let w = config.get_two_adic_generator(builder, log_max_lde_height);
    let max_gen_pow = config.get_two_adic_generator(builder, 1);
    let one_var: Felt<C::F> = builder.eval(C::F::ONE);

    builder.cycle_tracker_start("pre-compute-rounds-context");
    let rounds_context = compute_rounds_context(builder, &rounds, log_blowup, alpha);
    // Only used in static mode. `alpha_pows[i] = alpha^i` for `i` in `0..=max_width`.
    let alpha_pows = if builder.flags.static_only {
        let max_width = get_max_matrix_width(builder, &rounds);
        let mut ret = Vec::with_capacity(max_width + 1);
        ret.push(C::EF::ONE.cons());
        for i in 1..=max_width {
            let curr = builder.eval(ret[i - 1].clone() * alpha);
            builder.ext_reduce_circuit(curr);
            ret.push(curr.into());
        }
        ret
    } else {
        vec![]
    };
    builder.cycle_tracker_end("pre-compute-rounds-context");

    // Accumulator of the reduced opening sums, reset per query. The array `ro` is indexed by
    // log_height.
    let ro: Array<C, Ext<C::F, C::EF>> = builder.array(MAX_TWO_ADICITY + 1);
    let alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(MAX_TWO_ADICITY + 1);

    iter_zip!(builder, proof.query_proofs).for_each(|ptr_vec, builder| {
        let query_proof = builder.iter_ptr_get(&proof.query_proofs, ptr_vec[0]);
        let index_bits = challenger.sample_bits(builder, log_max_lde_height);

        // We reset the reduced opening accumulator at the start of each query.
        // Per query, `ro[log_height]` (where `log_height` is log2 of the size of the LDE domain)
        // is computed as follows, in pseudo-code:
        //
        //   ro[log_height] = 0
        //   alpha_pow[log_height] = 1
        //   for round in rounds:
        //     // preserving the order of round.mats:
        //     for mat in round.mats where mat.domain.log_n + log_blowup == log_height:
        //       // g is the generator of F; w_{log_height} generates the subgroup of F of
        //       // order 2^log_height; reverse_bits(x, bits) takes an unsigned integer x with
        //       // `bits` bits and returns the unsigned integer with the bits of x reversed.
        //       // x is a rotated evaluation point in a coset of the LDE domain.
        //       x = g * w_{log_height}^{reverse_bits(
        //             index >> (log_max_lde_height - log_height), log_height)}
        //       // ps_at_x is an array of Felt: the claimed evaluation of p at x for each
        //       // polynomial p corresponding to a column of mat.
        //       ps_at_x = [p(x) for each column p of mat]
        //       for (z, ps_at_z) in zip(mat.points, mat.values):
        //         // z is an out-of-domain point in Ext. There may be several per round to
        //         // account for rotations in AIR constraints. ps_at_z is an array of Ext: the
        //         // claimed evaluation of p at z for each polynomial p corresponding to a
        //         // column of mat.
        //         for (p_at_x, p_at_z) in zip(ps_at_x, ps_at_z):
        //           ro[log_height] += alpha_pow[log_height] * (p_at_x - p_at_z) / (x - z)
        //           alpha_pow[log_height] *= alpha
        //
        // The final value of ro[log_height] is the reduced opening value for log_height.
        if builder.flags.static_only {
            for j in 0..=MAX_TWO_ADICITY {
                // ATTENTION: don't use set_value here, Fixed will share the same variable.
                builder.set(&ro, j, C::EF::ZERO.cons());
                builder.set(&alpha_pow, j, C::EF::ONE.cons());
            }
        } else {
            let zero_ef = builder.eval(C::EF::ZERO.cons());
            let one_ef = builder.eval(C::EF::ONE.cons());
            for j in 0..=MAX_TWO_ADICITY {
                // Use set_value here to save a copy.
                builder.set_value(&ro, j, zero_ef);
                builder.set_value(&alpha_pow, j, one_ef);
            }
        }
        // **ATTENTION**: always check shape of user inputs.
        builder.assert_usize_eq(query_proof.input_proof.len(), rounds.len());

        // Pre-compute tag_exp
        builder.cycle_tracker_start("cache-generator-powers");
        {
            // Keep the first `log_max_lde_height` entries of index_bits.
            let index_bits_truncated = index_bits.slice(builder, 0, log_max_lde_height);

            // b = index_bits, L = log_max_lde_height
            // w = generator of order 2^L
            // We first compute `w ** (b[0] * 2^(L - 1) + ... + b[L - 1])`
            // using a square-and-multiply algorithm.
            let res = builder.exp_bits_big_endian(w, &index_bits_truncated);

            // We now compute, for i in 0..L:
            //   tag_exp[i] = g * w ** (b[i] * 2^(L - 1) + b[i + 1] * 2^(L - 2) + ...
            //                          + b[L - 1] * 2^i)
            // using a square-and-divide algorithm: at step i, dividing `res` by `max_gen_pow`
            // (the order-2 element, i.e. w ** 2^(L - 1)) when b[i] = 1 removes the leading
            // term, and squaring doubles the remaining exponent.
            // g * res is tag_exp[0]
            // `tag_exp` is used below as a rotated evaluation point in a coset of the LDE domain.
            iter_zip!(builder, index_bits_truncated, tag_exp).for_each(|ptr_vec, builder| {
                builder.iter_ptr_set(&tag_exp, ptr_vec[1], g * res);

                let bit = builder.iter_ptr_get(&index_bits_truncated, ptr_vec[0]);
                let div = builder.select_f(bit, max_gen_pow, one_var);
                builder.assign(&res, res / div);
                builder.assign(&res, res * res);
            });
        };
        builder.cycle_tracker_end("cache-generator-powers");

        iter_zip!(builder, query_proof.input_proof, rounds, rounds_context).for_each(
            |ptr_vec, builder| {
                let batch_opening = builder.iter_ptr_get(&query_proof.input_proof, ptr_vec[0]);
                let round = builder.iter_ptr_get(&rounds, ptr_vec[1]);
                let round_context = builder.iter_ptr_get(&rounds_context, ptr_vec[2]);

                let batch_commit = round.batch_commit;
                let mats = round.mats;
                let RoundContext {
                    ov_ptrs,
                    perm_ov_ptrs,
                    batch_dims,
                    mat_alpha_pows,
                    log_batch_max_height,
                } = round_context;

                // **ATTENTION**: always check shape of user inputs.
                builder.assert_usize_eq(ov_ptrs.len(), mats.len());

                let hint_id = batch_opening.opened_values.id.clone();
                // Static mode only: tracks the current offset into the hint space.
                let mut hint_offset = 0;
                builder.cycle_tracker_start("compute-reduced-opening");
                iter_zip!(builder, ov_ptrs, mats, mat_alpha_pows).for_each(|ptr_vec, builder| {
                    let mat_opening = builder.iter_ptr_get(&ov_ptrs, ptr_vec[0]);
                    let mat = builder.iter_ptr_get(&mats, ptr_vec[1]);
                    let mat_alpha_pow = if builder.flags.static_only {
                        builder.uninit()
                    } else {
                        builder.iter_ptr_get(&mat_alpha_pows, ptr_vec[2])
                    };
                    let mat_points = mat.points;
                    let mat_values = mat.values;
                    let log2_domain_size = mat.domain.log_n;
                    let log_height = builder.eval_expr(log2_domain_size + RVar::from(log_blowup));

                    let cur_ro = builder.get(&ro, log_height);
                    let cur_alpha_pow = builder.get(&alpha_pow, log_height);

                    builder.cycle_tracker_start("exp-reverse-bits-len");
                    let height_idx = builder.eval_expr(log_max_lde_height - log_height);
                    let x = builder.get(&tag_exp, height_idx);
                    builder.cycle_tracker_end("exp-reverse-bits-len");

                    let is_init: Usize<C::N> = builder.eval(C::N::ZERO);
                    iter_zip!(builder, mat_points, mat_values).for_each(|ptr_vec, builder| {
                        let z: Ext<C::F, C::EF> = builder.iter_ptr_get(&mat_points, ptr_vec[0]);
                        let ps_at_z = builder.iter_ptr_get(&mat_values, ptr_vec[1]);

                        builder.cycle_tracker_start("single-reduced-opening-eval");

                        if builder.flags.static_only {
                            // n = sum_t alpha^t * (p_at_z[t] - p_at_x[t])
                            let n: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
                            let width = ps_at_z.len().value();
                            if is_init.value() == 0 {
                                let mat_opening_vals = {
                                    let witness_refs = builder.get_witness_refs(hint_id.clone());
                                    let start = hint_offset;
                                    witness_refs[start..start + width].to_vec()
                                };
                                for (i, v) in mat_opening_vals.into_iter().enumerate() {
                                    builder.set_value(&mat_opening, i, v.into());
                                }
                            }
                            for (t, alpha_pow) in alpha_pows.iter().take(width).enumerate() {
                                let p_at_x = builder.get(&mat_opening, t);
                                let p_at_z = builder.get(&ps_at_z, t);
                                builder.assign(&n, n + (p_at_z - p_at_x) * alpha_pow.clone());
                            }
                            builder.assign(&cur_ro, cur_ro + n / (z - x) * cur_alpha_pow);
                            builder
                                .assign(&cur_alpha_pow, cur_alpha_pow * alpha_pows[width].clone());
                        } else {
                            let mat_ro = builder.fri_single_reduced_opening_eval(
                                alpha,
                                hint_id.get_var(),
                                is_init.get_var(),
                                &mat_opening,
                                &ps_at_z,
                            );
                            builder.assign(&cur_ro, cur_ro + (mat_ro * cur_alpha_pow / (z - x)));
                            builder.assign(&cur_alpha_pow, cur_alpha_pow * mat_alpha_pow);
                        }
                        // The buffer `mat_opening` has now been written to, so we set `is_init` to
                        // 1.
                        builder.assign(&is_init, C::N::ONE);
                        builder.cycle_tracker_end("single-reduced-opening-eval");
                    });
                    if builder.flags.static_only {
                        hint_offset += mat_opening.len().value();
                    }
                    builder.set_value(&ro, log_height, cur_ro);
                    builder.set_value(&alpha_pow, log_height, cur_alpha_pow);
                });
                builder.cycle_tracker_end("compute-reduced-opening");

                let bits_reduced: Usize<_> =
                    builder.eval(log_max_lde_height - log_batch_max_height);
                let index_bits_shifted_v1 = index_bits.shift(builder, bits_reduced);

                builder.cycle_tracker_start("verify-batch");
                verify_batch::<C>(
                    builder,
                    &batch_commit,
                    batch_dims,
                    index_bits_shifted_v1,
                    &NestedOpenedValues::Felt(perm_ov_ptrs),
                    &batch_opening.opening_proof,
                );
                builder.cycle_tracker_end("verify-batch");
            },
        );

        // Note[jpw]: we do not need to assert `ro[log_blowup] = 0` here because we include
        // `ro[log_blowup]` in the `verify_query` low-degree test. See comments therein.

        let folded_eval = verify_query(
            builder,
            config,
            &proof.commit_phase_commits,
            &index_bits,
            &query_proof,
            &betas,
            &betas_squared,
            &ro,
            log_max_lde_height,
            &i_plus_one_arr,
        );

        builder.assert_ext_eq(folded_eval, final_poly_ct);
    });
    builder.cycle_tracker_end("stage-d-verifier-verify");
}
377
378impl<C: Config> FromConstant<C> for TwoAdicPcsRoundVariable<C>
379where
380    C::F: TwoAdicField,
381{
382    type Constant = (
383        Hash<C::F, C::F, DIGEST_SIZE>,
384        Vec<(TwoAdicMultiplicativeCoset<C::F>, Vec<(C::EF, Vec<C::EF>)>)>,
385    );
386
387    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
388        let (commit_val, domains_and_openings_val) = value;
389
390        // Allocate the commitment.
391        let commit = builder.dyn_array::<Felt<_>>(DIGEST_SIZE);
392        let commit_val: [C::F; DIGEST_SIZE] = commit_val.into();
393        for (i, f) in commit_val.into_iter().enumerate() {
394            builder.set(&commit, i, f);
395        }
396
397        let mats = builder
398            .dyn_array::<TwoAdicPcsMatsVariable<C>>(RVar::from(domains_and_openings_val.len()));
399
400        for (i, (domain, opening)) in domains_and_openings_val.into_iter().enumerate() {
401            let domain = builder.constant::<TwoAdicMultiplicativeCosetVariable<_>>(domain);
402
403            let points_val = opening.iter().map(|(p, _)| *p).collect::<Vec<_>>();
404            let values_val = opening.iter().map(|(_, v)| v.clone()).collect::<Vec<_>>();
405            let points: Array<_, Ext<_, _>> = builder.dyn_array(points_val.len());
406            for (j, point) in points_val.into_iter().enumerate() {
407                let el: Ext<_, _> = builder.eval(point.cons());
408                builder.set_value(&points, j, el);
409            }
410            let values: Array<_, Array<_, Ext<_, _>>> = builder.dyn_array(values_val.len());
411            for (j, val) in values_val.into_iter().enumerate() {
412                let tmp = builder.dyn_array(val.len());
413                for (k, v) in val.into_iter().enumerate() {
414                    let el: Ext<_, _> = builder.eval(v.cons());
415                    builder.set_value(&tmp, k, el);
416                }
417                builder.set_value(&values, j, tmp);
418            }
419            let mat = TwoAdicPcsMatsVariable {
420                domain,
421                points,
422                values,
423            };
424            builder.set_value(&mats, i, mat);
425        }
426
427        Self {
428            batch_commit: DigestVariable::Felt(commit),
429            mats,
430            permutation: builder.dyn_array(0),
431        }
432    }
433}
434
/// In-DSL (recursive) two-adic FRI PCS, parameterized by its FRI configuration.
#[derive(Clone)]
pub struct TwoAdicFriPcsVariable<C: Config> {
    /// FRI parameters, used both for domain construction and by [`verify_two_adic_pcs`].
    pub config: FriConfigVariable<C>,
}
439
440impl<C: Config> PcsVariable<C> for TwoAdicFriPcsVariable<C>
441where
442    C::F: TwoAdicField,
443    C::EF: TwoAdicField,
444{
445    type Domain = TwoAdicMultiplicativeCosetVariable<C>;
446
447    type Commitment = DigestVariable<C>;
448
449    type Proof = FriProofVariable<C>;
450
451    fn natural_domain_for_log_degree(
452        &self,
453        builder: &mut Builder<C>,
454        log_degree: RVar<C::N>,
455    ) -> Self::Domain {
456        self.config.get_subgroup(builder, log_degree)
457    }
458
459    fn verify(
460        &self,
461        builder: &mut Builder<C>,
462        rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
463        proof: Self::Proof,
464        log_max_height: RVar<C::N>,
465        challenger: &mut impl ChallengerVariable<C>,
466    ) {
467        verify_two_adic_pcs(
468            builder,
469            &self.config,
470            rounds,
471            proof,
472            log_max_height,
473            challenger,
474        )
475    }
476}
477
478fn get_max_matrix_width<C: Config>(
479    builder: &mut Builder<C>,
480    rounds: &Array<C, TwoAdicPcsRoundVariable<C>>,
481) -> usize {
482    let mut ret = 0;
483    for i in 0..rounds.len().value() {
484        let round = builder.get(rounds, i);
485        for j in 0..round.mats.len().value() {
486            let mat = builder.get(&round.mats, j);
487            let local = builder.get(&mat.values, 0);
488            ret = ret.max(local.len().value());
489        }
490    }
491    ret
492}
493
/// Per-round data precomputed once by `compute_rounds_context` and reused by every query in
/// `verify_two_adic_pcs`.
#[derive(DslVariable, Clone)]
struct RoundContext<C: Config> {
    /// Opened values buffer: one buffer per matrix (in `round.mats` order), sized to the width
    /// of the matrix's first opening. It is uninitialized until the first reduced-opening
    /// evaluation of each query writes it.
    ov_ptrs: Array<C, Array<C, Felt<C::F>>>,
    /// Permuted opened values buffer: `perm_ov_ptrs[i] = ov_ptrs[permutation[i]]` (identity when
    /// the permutation is empty). Both arrays point to the same underlying buffers.
    perm_ov_ptrs: Array<C, Array<C, Felt<C::F>>>,
    /// Permuted matrix dimensions: LDE log-heights (`domain.log_n + log_blowup`), in permuted
    /// order.
    batch_dims: Array<C, DimensionsVariable<C>>,
    /// `alpha^width` for each matrix. Only filled in dynamic mode.
    mat_alpha_pows: Array<C, Ext<C::F, C::EF>>,
    /// Max LDE log-height in the batch, taken from the first matrix in permuted order.
    log_batch_max_height: Usize<C::N>,
}
507
/// Precomputes, once per round, the buffers and constants that every FRI query reuses.
///
/// For each round this function:
/// - allocates one opened-values buffer per matrix, sized to the width of the matrix's first
///   opening;
/// - builds the permuted views (`perm_ov_ptrs`, `batch_dims`) passed to `verify_batch`;
/// - records the LDE log-height of the first matrix in permuted order;
/// - in dynamic mode only, computes `alpha^width` for each matrix.
///
/// `log_blowup` is added to every `domain.log_n` to obtain LDE log-heights.
fn compute_rounds_context<C: Config>(
    builder: &mut Builder<C>,
    rounds: &Array<C, TwoAdicPcsRoundVariable<C>>,
    log_blowup: usize,
    alpha: Ext<C::F, C::EF>,
) -> Array<C, RoundContext<C>> {
    let ret: Array<C, RoundContext<C>> = builder.array(rounds.len());

    // `num2bits_f` below decomposes each matrix width into `MAX_LOG_WIDTH` bits, so every width
    // must be `< 2^MAX_LOG_WIDTH`. This is safe because the width of any matrix in an AIR must
    // fit within a single field element (NOTE(review): 31 bits matches BabyBear; confirm for
    // other fields).
    const MAX_LOG_WIDTH: usize = 31;
    // pow_of_alpha[i] = alpha^(2^i). Only populated in dynamic mode.
    let pow_of_alpha: Array<C, Ext<_, _>> = builder.array(MAX_LOG_WIDTH);
    if !builder.flags.static_only {
        let current: Ext<_, _> = builder.eval(alpha);
        for i in 0..MAX_LOG_WIDTH {
            builder.set(&pow_of_alpha, i, current);
            builder.assign(&current, current * current);
        }
    }

    iter_zip!(builder, rounds, ret).for_each(|ptr_vec, builder| {
        let round = builder.iter_ptr_get(rounds, ptr_vec[0]);
        let permutation = round.permutation;
        // Maps a position `k` in permuted order to an index into `round.mats`: `permutation[k]`,
        // or `k` itself when `permutation` is empty.
        let to_perm_index = |builder: &mut Builder<_>, k: RVar<_>| {
            // Always no permutation in static mode
            if builder.flags.static_only {
                builder.eval(k)
            } else {
                let ret: Usize<_> = builder.uninit();
                builder.if_eq(permutation.len(), RVar::zero()).then_or_else(
                    |builder| {
                        builder.assign(&ret, k);
                    },
                    |builder| {
                        let value = builder.get(&permutation, k);
                        builder.assign(&ret, value);
                    },
                );
                ret
            }
        };

        let ov_ptrs: Array<C, Array<C, Felt<C::F>>> = builder.array(round.mats.len());
        let perm_ov_ptrs: Array<C, Array<C, Felt<C::F>>> = builder.array(round.mats.len());
        let batch_dims: Array<C, DimensionsVariable<C>> = builder.array(round.mats.len());
        let mat_alpha_pows: Array<C, Ext<_, _>> = builder.array(round.mats.len());
        // The first matrix in permuted order is taken as the tallest one in the batch.
        let log_batch_max_height: Usize<_> = {
            let log_batch_max_index = to_perm_index(builder, RVar::zero());
            let mat = builder.get(&round.mats, log_batch_max_index);
            let domain = mat.domain;
            builder.eval(domain.log_n + RVar::from(log_blowup))
        };

        iter_zip!(builder, round.mats, ov_ptrs, mat_alpha_pows).for_each(|ptr_vec, builder| {
            let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]);
            let local = builder.get(&mat.values, 0);
            // We allocate the underlying buffer for the current `ov_ptr` here. On allocation, it is
            // uninit, and will be written to on the first call of
            // `fri_single_reduced_opening_eval` for this `ov_ptr`.
            let buf = builder.array(local.len());
            let width = buf.len();
            builder.iter_ptr_set(&ov_ptrs, ptr_vec[1], buf);

            if !builder.flags.static_only {
                // mat_alpha_pow = alpha^width, by multiplying together pow_of_alpha[i] for every
                // set bit i of `width`.
                let width = width.get_var();
                // This is dynamic only so safe to cast.
                let width_f = builder.unsafe_cast_var_to_felt(width);
                let bits = builder.num2bits_f(width_f, MAX_LOG_WIDTH as u32);
                let mat_alpha_pow: Ext<_, _> = builder.eval(C::EF::ONE.cons());
                for i in 0..MAX_LOG_WIDTH {
                    let bit = builder.get(&bits, i);
                    builder.if_eq(bit, RVar::one()).then(|builder| {
                        let to_mul = builder.get(&pow_of_alpha, i);
                        builder.assign(&mat_alpha_pow, mat_alpha_pow * to_mul);
                    });
                }
                builder.iter_ptr_set(&mat_alpha_pows, ptr_vec[2], mat_alpha_pow);
            }
        });
        // Build the permuted views: entry i describes matrix `permutation[i]`.
        builder
            .range(0, round.mats.len())
            .for_each(|i_vec, builder| {
                let i = i_vec[0];
                let perm_i = to_perm_index(builder, i);
                let mat = builder.get(&round.mats, perm_i.clone());

                let domain = mat.domain;
                let dim = DimensionsVariable::<C> {
                    log_height: builder.eval(domain.log_n + RVar::from(log_blowup)),
                };
                builder.set_value(&batch_dims, i, dim);
                let perm_ov_ptr = builder.get(&ov_ptrs, perm_i);
                // Note both `ov_ptrs` and `perm_ov_ptrs` point to the same memory.
                builder.set_value(&perm_ov_ptrs, i, perm_ov_ptr);
            });
        builder.iter_ptr_set(
            &ret,
            ptr_vec[1],
            RoundContext {
                ov_ptrs,
                perm_ov_ptrs,
                batch_dims,
                mat_alpha_pows,
                log_batch_max_height,
            },
        );
    });
    ret
}
617
/// Test helpers and tests for the recursive two-adic FRI PCS verifier.
pub mod tests {
    use std::cmp::Reverse;

    use itertools::Itertools;
    use openvm_circuit::arch::instructions::program::Program;
    use openvm_native_compiler::{
        asm::AsmBuilder,
        conversion::CompilerOptions,
        ir::{Array, RVar, DIGEST_SIZE},
    };
    use openvm_stark_backend::{
        config::{StarkGenericConfig, Val},
        engine::StarkEngine,
        p3_challenger::{CanObserve, FieldChallenger},
        p3_commit::{Pcs, TwoAdicMultiplicativeCoset},
        p3_matrix::dense::RowMajorMatrix,
    };
    use openvm_stark_sdk::{
        config::baby_bear_poseidon2::{default_engine, BabyBearPoseidon2Config},
        p3_baby_bear::BabyBear,
    };
    use rand::rngs::OsRng;

    use crate::{
        challenger::{duplex::DuplexChallengerVariable, CanObserveDigest, FeltChallenger},
        commit::PcsVariable,
        digest::DigestVariable,
        fri::{
            types::TwoAdicPcsRoundVariable, TwoAdicFriPcsVariable,
            TwoAdicMultiplicativeCosetVariable,
        },
        hints::{Hintable, InnerFriProof, InnerVal},
        utils::const_fri_config,
    };

    /// Builds a program that recursively verifies a FRI PCS opening proof.
    ///
    /// The function:
    /// 1. Commits a random BabyBear matrix with `nb_cols` columns and `2^nb_log2_rows` rows.
    /// 2. Opens it at one sampled point `zeta`.
    /// 3. Checks the proof with the native PCS verifier, panicking on failure.
    /// 4. Compiles the same check into a program using `TwoAdicFriPcsVariable`.
    ///
    /// Returns the compiled program and the witness stream (the serialized FRI proof) that the
    /// program reads via `InnerFriProof::read`.
    pub fn build_test_fri_with_cols_and_log2_rows(
        nb_cols: usize,
        nb_log2_rows: usize,
    ) -> (Program<BabyBear>, Vec<Vec<BabyBear>>) {
        type SC = BabyBearPoseidon2Config;
        type F = Val<SC>;
        type EF = <SC as StarkGenericConfig>::Challenge;
        type Challenger = <SC as StarkGenericConfig>::Challenger;
        type ScPcs = <SC as StarkGenericConfig>::Pcs;

        let mut rng = &mut OsRng;
        let log_degrees = &[nb_log2_rows];
        let engine = default_engine();
        let pcs = engine.config().pcs();
        let perm = engine.perm.clone();

        // Generate proof.
        let domains_and_polys = log_degrees
            .iter()
            .map(|&d| {
                (
                    <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, 1 << d),
                    RowMajorMatrix::<F>::rand(&mut rng, 1 << d, nb_cols),
                )
            })
            .sorted_by_key(|(dom, _)| Reverse(dom.log_n))
            .collect::<Vec<_>>();
        let (commit, data) = <ScPcs as Pcs<EF, Challenger>>::commit(pcs, domains_and_polys.clone());
        let mut challenger = Challenger::new(perm.clone());
        challenger.observe(commit);
        let zeta = challenger.sample_ext_element::<EF>();
        let points = domains_and_polys
            .iter()
            .map(|_| vec![zeta])
            .collect::<Vec<_>>();
        let (opening, proof) = pcs.open(vec![(&data, points)], &mut challenger);

        // Verify proof.
        let mut challenger = Challenger::new(perm.clone());
        challenger.observe(commit);
        challenger.sample_ext_element::<EF>();
        let os: Vec<(TwoAdicMultiplicativeCoset<F>, Vec<_>)> = domains_and_polys
            .iter()
            .zip(&opening[0])
            .map(|((domain, _), mat_openings)| (*domain, vec![(zeta, mat_openings[0].clone())]))
            .collect();
        pcs.verify(vec![(commit, os.clone())], &proof, &mut challenger)
            .unwrap();

        // Test the recursive Pcs.
        let mut builder = AsmBuilder::<F, EF>::default();
        let config = const_fri_config(&mut builder, &engine.fri_params);
        let pcs_var = TwoAdicFriPcsVariable { config };
        let rounds =
            builder.constant::<Array<_, TwoAdicPcsRoundVariable<_>>>(vec![(commit, os.clone())]);

        // Test natural domain for degree.
        for log_d_val in log_degrees.iter() {
            let log_d = *log_d_val;
            let domain = pcs_var.natural_domain_for_log_degree(&mut builder, RVar::from(log_d));

            let domain_val =
                <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, 1 << log_d_val);

            let expected_domain: TwoAdicMultiplicativeCosetVariable<_> =
                builder.constant(domain_val);

            builder.assert_eq::<TwoAdicMultiplicativeCosetVariable<_>>(domain, expected_domain);
        }

        // Test proof verification.
        // The in-program challenger replays the native transcript (observe commit, sample zeta)
        // before `verify`.
        let proofvar = InnerFriProof::read(&mut builder);
        let mut challenger = DuplexChallengerVariable::new(&mut builder);
        let commit = <[InnerVal; DIGEST_SIZE]>::from(commit).to_vec();
        let commit = DigestVariable::Felt(builder.constant::<Array<_, _>>(commit));
        challenger.observe_digest(&mut builder, commit);
        challenger.sample_ext(&mut builder);
        pcs_var.verify(
            &mut builder,
            rounds,
            proofvar,
            RVar::from(nb_log2_rows),
            &mut challenger,
        );
        builder.halt();

        let program =
            builder.compile_isa_with_options(CompilerOptions::default().with_cycle_tracker());
        let mut witness_stream = Vec::new();
        witness_stream.extend(proof.write());
        (program, witness_stream)
    }

    /// Runs the recursive verifier program for a 10-column, `2^10`-row matrix.
    #[test]
    fn test_two_adic_fri_pcs_single_batch() {
        let (program, witness) = build_test_fri_with_cols_and_log2_rows(10, 10);
        openvm_native_circuit::execute_program(program, witness);
    }
}