openvm_native_recursion/fri/
two_adic_pcs.rs

1use openvm_native_compiler::prelude::*;
2use openvm_native_compiler_derive::iter_zip;
3use openvm_stark_backend::{
4    p3_commit::TwoAdicMultiplicativeCoset,
5    p3_field::{FieldAlgebra, FieldExtensionAlgebra, TwoAdicField},
6};
7use p3_symmetric::Hash;
8
9use super::{
10    types::{
11        DimensionsVariable, FriConfigVariable, TwoAdicPcsMatsVariable, TwoAdicPcsRoundVariable,
12    },
13    verify_batch, verify_query, NestedOpenedValues, TwoAdicMultiplicativeCosetVariable,
14};
15use crate::{
16    challenger::ChallengerVariable, commit::PcsVariable, digest::DigestVariable,
17    fri::types::FriProofVariable,
18};
19
/// The maximum two-adicity of Felt `C::F`. This means `C::F` does not have a multiplicative
/// subgroup of order `2^{MAX_TWO_ADICITY + 1}`. Currently set to 27 for BabyBear.
///
/// `verify_two_adic_pcs` allocates its per-`log_height` accumulator arrays with
/// `MAX_TWO_ADICITY + 1` entries.
pub const MAX_TWO_ADICITY: usize = 27;
23
24/// Notes:
25/// 1. FieldMerkleTreeMMCS sorts traces by height in descending order when committing data.
26/// 2. **Required** that `C::F` has two-adicity <= [MAX_TWO_ADICITY]. In particular this implies that all LDE matrices have
27///    `log2(lde_height) <= MAX_TWO_ADICITY`.
28/// 3. **Required** that the maximum trace height is `2^log_max_height - 1`.
29///
30/// Reference:
31/// <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/merkle_tree.rs#L53>
32/// So traces are sorted in `opening_proof`.
33///
34/// 2. FieldMerkleTreeMMCS::poseidon2 keeps the raw values in the original order. So traces are not sorted in `opened_values`.
35///
36/// Reference:
37/// <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/mmcs.rs#L87>
38/// <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/merkle_tree.rs#L100>
39/// <https://github.com/Plonky3/Plonky3/blob/784b7dd1fa87c1202e63350cc8182d7c5327a7af/fri/src/verifier.rs#L22>
40pub fn verify_two_adic_pcs<C: Config>(
41    builder: &mut Builder<C>,
42    config: &FriConfigVariable<C>,
43    rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
44    proof: FriProofVariable<C>,
45    log_max_height: RVar<C::N>,
46    challenger: &mut impl ChallengerVariable<C>,
47) where
48    C::F: TwoAdicField,
49    C::EF: TwoAdicField,
50{
51    // Currently do not support other final poly len
52    builder.assert_var_eq(RVar::from(config.log_final_poly_len), RVar::zero());
53    // The `proof.final_poly` length is in general `2^{log_final_poly_len + log_blowup}`.
54    // We require `log_final_poly_len = 0`, so `proof.final_poly` has length `2^log_blowup`.
55    // In fact, the FRI low-degree test requires that `proof.final_poly = [constant, 0, ..., 0]`.
56    builder.assert_usize_eq(proof.final_poly.len(), RVar::from(config.blowup));
57    // Constant term of final poly
58    let final_poly_ct = builder.get(&proof.final_poly, 0);
59    for i in 1..config.blowup {
60        let term = builder.get(&proof.final_poly, i);
61        builder.assert_ext_eq(term, C::EF::ZERO.cons());
62    }
63
64    let g = builder.generator();
65
66    let log_blowup = config.log_blowup;
67    iter_zip!(builder, rounds).for_each(|ptr_vec, builder| {
68        let round = builder.iter_ptr_get(&rounds, ptr_vec[0]);
69        iter_zip!(builder, round.mats).for_each(|ptr_vec, builder| {
70            let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]);
71            iter_zip!(builder, mat.values).for_each(|ptr_vec, builder| {
72                let value = builder.iter_ptr_get(&mat.values, ptr_vec[0]);
73                iter_zip!(builder, value).for_each(|ptr_vec, builder| {
74                    if builder.flags.static_only {
75                        let ext = builder.iter_ptr_get(&value, ptr_vec[0]);
76                        let arr = builder.ext2felt(ext);
77                        challenger.observe_slice(builder, arr);
78                    } else {
79                        let ptr = ptr_vec[0];
80                        for i in 0..C::EF::D {
81                            let f: Felt<_> = builder.uninit();
82                            builder.load(
83                                f,
84                                Ptr {
85                                    address: ptr.variable(),
86                                },
87                                MemIndex {
88                                    index: RVar::from(i),
89                                    offset: 0,
90                                    size: 1,
91                                },
92                            );
93                            challenger.observe(builder, f);
94                        }
95                    }
96                });
97            });
98        });
99    });
100    let alpha = challenger.sample_ext(builder);
101    if builder.flags.static_only {
102        builder.ext_reduce_circuit(alpha);
103    }
104
105    builder.cycle_tracker_start("stage-d-verifier-verify");
106    // **ATTENTION**: always check shape of user inputs.
107    builder.assert_usize_eq(proof.query_proofs.len(), RVar::from(config.num_queries));
108    builder.assert_usize_eq(proof.commit_phase_commits.len(), log_max_height);
109    let betas: Array<C, Ext<C::F, C::EF>> = builder.array(log_max_height);
110    iter_zip!(builder, proof.commit_phase_commits, betas).for_each(|ptr_vec, builder| {
111        let comm_ptr = ptr_vec[0];
112        let beta_ptr = ptr_vec[1];
113        let comm = builder.iter_ptr_get(&proof.commit_phase_commits, comm_ptr);
114        challenger.observe_digest(builder, comm);
115        let sample = challenger.sample_ext(builder);
116        builder.iter_ptr_set(&betas, beta_ptr, sample);
117    });
118
119    iter_zip!(builder, proof.final_poly).for_each(|ptr_vec, builder| {
120        let final_poly_elem = builder.iter_ptr_get(&proof.final_poly, ptr_vec[0]);
121        let final_poly_elem_felts = builder.ext2felt(final_poly_elem);
122        challenger.observe_slice(builder, final_poly_elem_felts);
123    });
124
125    challenger.check_witness(builder, config.proof_of_work_bits, proof.pow_witness);
126
127    let log_max_lde_height = builder.eval_expr(log_max_height + RVar::from(log_blowup));
128    // tag_exp is a shared buffer.
129    let tag_exp: Array<C, Felt<C::F>> = builder.array(log_max_lde_height);
130    let w = config.get_two_adic_generator(builder, log_max_lde_height);
131    let max_gen_pow = config.get_two_adic_generator(builder, 1);
132    let one_var: Felt<C::F> = builder.eval(C::F::ONE);
133
134    builder.cycle_tracker_start("pre-compute-rounds-context");
135    let rounds_context = compute_rounds_context(builder, &rounds, log_blowup, alpha);
136    // Only used in static mode.
137    let alpha_pows = if builder.flags.static_only {
138        let max_width = get_max_matrix_width(builder, &rounds);
139        let mut ret = Vec::with_capacity(max_width + 1);
140        ret.push(C::EF::ONE.cons());
141        for i in 1..=max_width {
142            let curr = builder.eval(ret[i - 1].clone() * alpha);
143            builder.ext_reduce_circuit(curr);
144            ret.push(curr.into());
145        }
146        ret
147    } else {
148        vec![]
149    };
150    builder.cycle_tracker_end("pre-compute-rounds-context");
151
152    // Accumulators of the reduced opening sums, reset per query. The array `ro` is indexed by log_height.
153    let ro: Array<C, Ext<C::F, C::EF>> = builder.array(MAX_TWO_ADICITY + 1);
154    let alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(MAX_TWO_ADICITY + 1);
155
156    iter_zip!(builder, proof.query_proofs).for_each(|ptr_vec, builder| {
157        let query_proof = builder.iter_ptr_get(&proof.query_proofs, ptr_vec[0]);
158        let index_bits = challenger.sample_bits(builder, log_max_lde_height);
159
160        // We reset the reduced opening accumulators at the start of each query.
161        // We describe what `ro[log_height]` computes per query in pseduo-code, where `log_height` is log2 of the size of the LDE domain:
162        // ro[log_height] = 0
163        // alpha_pow[log_height] = 1
164        // for round in rounds:
165        //   for mat in round.mats where (mat.domain.log_n + log_blowup == log_height): // preserving order of round.mats
166        //      // g is generator of F
167        //      // w_{log_height} is generator of subgroup of F of order 2^log_height
168        //      x = g * w_{log_height}^{reverse_bits(index >> (log_max_height - log_height), log_height)}
169        //      // reverse_bits(x, bits) takes an unsigned integer x with `bits` bits and returns the unsigned integer with the bits of x reversed.
170        //      // x is a rotated evaluation point in a coset of the LDE domain.
171        //      ps_at_x = [claimed evaluation of p at x for each polynomial p corresponding to column of mat]
172        //      // ps_at_x is array of Felt
173        //      for (z, ps_at_z) in zip(mat.points, mat.values):
174        //        // z is an out of domain point in Ext. There may be multiple per round to account for rotations in AIR constraints.
175        //        // ps_at_z is array of Ext for [claimed evaluation of p at z for each polyomial p corresponding to column of mat]
176        //        for (p_at_x, p_at_z) in zip(ps_at_x, ps_at_z):
177        //          ro[log_height] += alpha_pow[log_height] * (p_at_x - p_at_z) / (x - z)
178        //          alpha_pow[log_height] *= alpha
179        //
180        // The final value of ro[log_height] is the reduced opening value for log_height.
181        if builder.flags.static_only {
182            for j in 0..=MAX_TWO_ADICITY {
183                // ATTENTION: don't use set_value here, Fixed will share the same variable.
184                builder.set(&ro, j, C::EF::ZERO.cons());
185                builder.set(&alpha_pow, j, C::EF::ONE.cons());
186            }
187        } else {
188            let zero_ef = builder.eval(C::EF::ZERO.cons());
189            let one_ef = builder.eval(C::EF::ONE.cons());
190            for j in 0..=MAX_TWO_ADICITY {
191                // Use set_value here to save a copy.
192                builder.set_value(&ro, j, zero_ef);
193                builder.set_value(&alpha_pow, j, one_ef);
194            }
195        }
196        // **ATTENTION**: always check shape of user inputs.
197        builder.assert_usize_eq(query_proof.input_proof.len(), rounds.len());
198
199        // Pre-compute tag_exp
200        builder.cycle_tracker_start("cache-generator-powers");
201        {
202            // truncate index_bits to log_max_height
203            let index_bits_truncated = index_bits.slice(builder, 0, log_max_lde_height);
204
205            // b = index_bits
206            // w = generator of order 2^log_max_height
207            // we first compute `w ** (b[0] * 2^(log_max_height - 1) + ... + b[log_max_height - 1])` using a square-and-multiply algorithm.
208            let res = builder.exp_bits_big_endian(w, &index_bits_truncated);
209
210            // we now compute:
211            // tag_exp[log_max_height - i] = g * w ** (b[log_max_height - i] * 2^(log_max_height - 1) + ... + b[log_max_height - 1] * 2^(log_max_height - i))
212            // using a square-and-divide algorithm.
213            // g * res is tag_exp[0]
214            // `tag_exp` is used below as a rotated evaluation point in a coset of the LDE domain.
215            iter_zip!(builder, index_bits_truncated, tag_exp).for_each(|ptr_vec, builder| {
216                builder.iter_ptr_set(&tag_exp, ptr_vec[1], g * res);
217
218                let bit = builder.iter_ptr_get(&index_bits_truncated, ptr_vec[0]);
219                let div = builder.select_f(bit, max_gen_pow, one_var);
220                builder.assign(&res, res / div);
221                builder.assign(&res, res * res);
222            });
223        };
224        builder.cycle_tracker_end("cache-generator-powers");
225
226        iter_zip!(builder, query_proof.input_proof, rounds, rounds_context).for_each(
227            |ptr_vec, builder| {
228                let batch_opening = builder.iter_ptr_get(&query_proof.input_proof, ptr_vec[0]);
229                let round = builder.iter_ptr_get(&rounds, ptr_vec[1]);
230                let round_context = builder.iter_ptr_get(&rounds_context, ptr_vec[2]);
231
232                let batch_commit = round.batch_commit;
233                let mats = round.mats;
234                let RoundContext {
235                    ov_ptrs,
236                    perm_ov_ptrs,
237                    batch_dims,
238                    mat_alpha_pows,
239                    log_batch_max_height,
240                } = round_context;
241
242                // **ATTENTION**: always check shape of user inputs.
243                builder.assert_usize_eq(ov_ptrs.len(), mats.len());
244
245                let hint_id = batch_opening.opened_values.id.clone();
246                // For static to track the offset in the hint space.
247                let mut hint_offset = 0;
248                builder.cycle_tracker_start("compute-reduced-opening");
249                iter_zip!(builder, ov_ptrs, mats, mat_alpha_pows).for_each(|ptr_vec, builder| {
250                    let mat_opening = builder.iter_ptr_get(&ov_ptrs, ptr_vec[0]);
251                    let mat = builder.iter_ptr_get(&mats, ptr_vec[1]);
252                    let mat_alpha_pow = if builder.flags.static_only {
253                        builder.uninit()
254                    } else {
255                        builder.iter_ptr_get(&mat_alpha_pows, ptr_vec[2])
256                    };
257                    let mat_points = mat.points;
258                    let mat_values = mat.values;
259                    let log2_domain_size = mat.domain.log_n;
260                    let log_height = builder.eval_expr(log2_domain_size + RVar::from(log_blowup));
261
262                    let cur_ro = builder.get(&ro, log_height);
263                    let cur_alpha_pow = builder.get(&alpha_pow, log_height);
264
265                    builder.cycle_tracker_start("exp-reverse-bits-len");
266                    let height_idx = builder.eval_expr(log_max_lde_height - log_height);
267                    let x = builder.get(&tag_exp, height_idx);
268                    builder.cycle_tracker_end("exp-reverse-bits-len");
269
270                    let is_init: Usize<C::N> = builder.eval(C::N::ZERO);
271                    iter_zip!(builder, mat_points, mat_values).for_each(|ptr_vec, builder| {
272                        let z: Ext<C::F, C::EF> = builder.iter_ptr_get(&mat_points, ptr_vec[0]);
273                        let ps_at_z = builder.iter_ptr_get(&mat_values, ptr_vec[1]);
274
275                        builder.cycle_tracker_start("single-reduced-opening-eval");
276
277                        if builder.flags.static_only {
278                            let n: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
279                            let width = ps_at_z.len().value();
280                            if is_init.value() == 0 {
281                                let mat_opening_vals = {
282                                    let witness_refs = builder.get_witness_refs(hint_id.clone());
283                                    let start = hint_offset;
284                                    witness_refs[start..start + width].to_vec()
285                                };
286                                for (i, v) in mat_opening_vals.into_iter().enumerate() {
287                                    builder.set_value(&mat_opening, i, v.into());
288                                }
289                            }
290                            for (t, alpha_pow) in alpha_pows.iter().take(width).enumerate() {
291                                let p_at_x = builder.get(&mat_opening, t);
292                                let p_at_z = builder.get(&ps_at_z, t);
293                                builder.assign(&n, n + (p_at_z - p_at_x) * alpha_pow.clone());
294                            }
295                            builder.assign(&cur_ro, cur_ro + n / (z - x) * cur_alpha_pow);
296                            builder
297                                .assign(&cur_alpha_pow, cur_alpha_pow * alpha_pows[width].clone());
298                        } else {
299                            let mat_ro = builder.fri_single_reduced_opening_eval(
300                                alpha,
301                                hint_id.get_var(),
302                                is_init.get_var(),
303                                &mat_opening,
304                                &ps_at_z,
305                            );
306                            builder.assign(&cur_ro, cur_ro + (mat_ro * cur_alpha_pow / (z - x)));
307                            builder.assign(&cur_alpha_pow, cur_alpha_pow * mat_alpha_pow);
308                        }
309                        // The buffer `mat_opening` has now been written to, so we set `is_init` to 1.
310                        builder.assign(&is_init, C::N::ONE);
311                        builder.cycle_tracker_end("single-reduced-opening-eval");
312                    });
313                    if builder.flags.static_only {
314                        hint_offset += mat_opening.len().value();
315                    }
316                    builder.set_value(&ro, log_height, cur_ro);
317                    builder.set_value(&alpha_pow, log_height, cur_alpha_pow);
318                });
319                builder.cycle_tracker_end("compute-reduced-opening");
320
321                let bits_reduced: Usize<_> =
322                    builder.eval(log_max_lde_height - log_batch_max_height);
323                let index_bits_shifted_v1 = index_bits.shift(builder, bits_reduced);
324
325                builder.cycle_tracker_start("verify-batch");
326                verify_batch::<C>(
327                    builder,
328                    &batch_commit,
329                    batch_dims,
330                    index_bits_shifted_v1,
331                    &NestedOpenedValues::Felt(perm_ov_ptrs),
332                    &batch_opening.opening_proof,
333                );
334                builder.cycle_tracker_end("verify-batch");
335            },
336        );
337
338        let folded_eval = verify_query(
339            builder,
340            config,
341            &proof.commit_phase_commits,
342            &index_bits,
343            &query_proof,
344            &betas,
345            &ro,
346            log_max_lde_height,
347        );
348
349        builder.assert_ext_eq(folded_eval, final_poly_ct);
350    });
351    builder.cycle_tracker_end("stage-d-verifier-verify");
352}
353
354impl<C: Config> FromConstant<C> for TwoAdicPcsRoundVariable<C>
355where
356    C::F: TwoAdicField,
357{
358    type Constant = (
359        Hash<C::F, C::F, DIGEST_SIZE>,
360        Vec<(TwoAdicMultiplicativeCoset<C::F>, Vec<(C::EF, Vec<C::EF>)>)>,
361    );
362
363    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
364        let (commit_val, domains_and_openings_val) = value;
365
366        // Allocate the commitment.
367        let commit = builder.dyn_array::<Felt<_>>(DIGEST_SIZE);
368        let commit_val: [C::F; DIGEST_SIZE] = commit_val.into();
369        for (i, f) in commit_val.into_iter().enumerate() {
370            builder.set(&commit, i, f);
371        }
372
373        let mats = builder
374            .dyn_array::<TwoAdicPcsMatsVariable<C>>(RVar::from(domains_and_openings_val.len()));
375
376        for (i, (domain, opening)) in domains_and_openings_val.into_iter().enumerate() {
377            let domain = builder.constant::<TwoAdicMultiplicativeCosetVariable<_>>(domain);
378
379            let points_val = opening.iter().map(|(p, _)| *p).collect::<Vec<_>>();
380            let values_val = opening.iter().map(|(_, v)| v.clone()).collect::<Vec<_>>();
381            let points: Array<_, Ext<_, _>> = builder.dyn_array(points_val.len());
382            for (j, point) in points_val.into_iter().enumerate() {
383                let el: Ext<_, _> = builder.eval(point.cons());
384                builder.set_value(&points, j, el);
385            }
386            let values: Array<_, Array<_, Ext<_, _>>> = builder.dyn_array(values_val.len());
387            for (j, val) in values_val.into_iter().enumerate() {
388                let tmp = builder.dyn_array(val.len());
389                for (k, v) in val.into_iter().enumerate() {
390                    let el: Ext<_, _> = builder.eval(v.cons());
391                    builder.set_value(&tmp, k, el);
392                }
393                builder.set_value(&values, j, tmp);
394            }
395            let mat = TwoAdicPcsMatsVariable {
396                domain,
397                points,
398                values,
399            };
400            builder.set_value(&mats, i, mat);
401        }
402
403        Self {
404            batch_commit: DigestVariable::Felt(commit),
405            mats,
406            permutation: builder.dyn_array(0),
407        }
408    }
409}
410
411#[derive(Clone)]
412pub struct TwoAdicFriPcsVariable<C: Config> {
413    pub config: FriConfigVariable<C>,
414}
415
416impl<C: Config> PcsVariable<C> for TwoAdicFriPcsVariable<C>
417where
418    C::F: TwoAdicField,
419    C::EF: TwoAdicField,
420{
421    type Domain = TwoAdicMultiplicativeCosetVariable<C>;
422
423    type Commitment = DigestVariable<C>;
424
425    type Proof = FriProofVariable<C>;
426
427    fn natural_domain_for_log_degree(
428        &self,
429        builder: &mut Builder<C>,
430        log_degree: RVar<C::N>,
431    ) -> Self::Domain {
432        self.config.get_subgroup(builder, log_degree)
433    }
434
435    fn verify(
436        &self,
437        builder: &mut Builder<C>,
438        rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
439        proof: Self::Proof,
440        log_max_height: RVar<C::N>,
441        challenger: &mut impl ChallengerVariable<C>,
442    ) {
443        verify_two_adic_pcs(
444            builder,
445            &self.config,
446            rounds,
447            proof,
448            log_max_height,
449            challenger,
450        )
451    }
452}
453
454fn get_max_matrix_width<C: Config>(
455    builder: &mut Builder<C>,
456    rounds: &Array<C, TwoAdicPcsRoundVariable<C>>,
457) -> usize {
458    let mut ret = 0;
459    for i in 0..rounds.len().value() {
460        let round = builder.get(rounds, i);
461        for j in 0..round.mats.len().value() {
462            let mat = builder.get(&round.mats, j);
463            let local = builder.get(&mat.values, 0);
464            ret = ret.max(local.len().value());
465        }
466    }
467    ret
468}
469
/// Per-round data pre-computed once by `compute_rounds_context` and reused for every query in
/// `verify_two_adic_pcs`.
#[derive(DslVariable, Clone)]
struct RoundContext<C: Config> {
    /// Opened values buffer: one array per matrix of the round, in `round.mats` order, sized to
    /// the length of the matrix's first opened-values row.
    ov_ptrs: Array<C, Array<C, Felt<C::F>>>,
    /// Permuted opened values buffer. Its entries are the same arrays as in `ov_ptrs`, placed in
    /// permuted order.
    perm_ov_ptrs: Array<C, Array<C, Felt<C::F>>>,
    /// Permuted matrix dimensions (`log_height = domain.log_n + log_blowup`).
    batch_dims: Array<C, DimensionsVariable<C>>,
    /// Alpha pows for each matrix (`alpha` raised to the matrix width). Only populated in dynamic
    /// mode.
    mat_alpha_pows: Array<C, Ext<C::F, C::EF>>,
    /// Max height in the matrices: `log_n + log_blowup` of the matrix at permuted index 0.
    log_batch_max_height: Usize<C::N>,
}
483
484fn compute_rounds_context<C: Config>(
485    builder: &mut Builder<C>,
486    rounds: &Array<C, TwoAdicPcsRoundVariable<C>>,
487    log_blowup: usize,
488    alpha: Ext<C::F, C::EF>,
489) -> Array<C, RoundContext<C>> {
490    let ret: Array<C, RoundContext<C>> = builder.array(rounds.len());
491
492    // This maximum is safe because the log width of any matrix in an AIR must fit
493    // within a single field element.
494    const MAX_LOG_WIDTH: usize = 31;
495    let pow_of_alpha: Array<C, Ext<_, _>> = builder.array(MAX_LOG_WIDTH);
496    if !builder.flags.static_only {
497        let current: Ext<_, _> = builder.eval(alpha);
498        for i in 0..MAX_LOG_WIDTH {
499            builder.set(&pow_of_alpha, i, current);
500            builder.assign(&current, current * current);
501        }
502    }
503
504    iter_zip!(builder, rounds, ret).for_each(|ptr_vec, builder| {
505        let round = builder.iter_ptr_get(rounds, ptr_vec[0]);
506        let permutation = round.permutation;
507        let to_perm_index = |builder: &mut Builder<_>, k: RVar<_>| {
508            // Always no permutation in static mode
509            if builder.flags.static_only {
510                builder.eval(k)
511            } else {
512                let ret: Usize<_> = builder.uninit();
513                builder.if_eq(permutation.len(), RVar::zero()).then_or_else(
514                    |builder| {
515                        builder.assign(&ret, k);
516                    },
517                    |builder| {
518                        let value = builder.get(&permutation, k);
519                        builder.assign(&ret, value);
520                    },
521                );
522                ret
523            }
524        };
525
526        let ov_ptrs: Array<C, Array<C, Felt<C::F>>> = builder.array(round.mats.len());
527        let perm_ov_ptrs: Array<C, Array<C, Felt<C::F>>> = builder.array(round.mats.len());
528        let batch_dims: Array<C, DimensionsVariable<C>> = builder.array(round.mats.len());
529        let mat_alpha_pows: Array<C, Ext<_, _>> = builder.array(round.mats.len());
530        let log_batch_max_height: Usize<_> = {
531            let log_batch_max_index = to_perm_index(builder, RVar::zero());
532            let mat = builder.get(&round.mats, log_batch_max_index);
533            let domain = mat.domain;
534            builder.eval(domain.log_n + RVar::from(log_blowup))
535        };
536
537        iter_zip!(builder, round.mats, ov_ptrs, mat_alpha_pows).for_each(|ptr_vec, builder| {
538            let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]);
539            let local = builder.get(&mat.values, 0);
540            // We allocate the underlying buffer for the current `ov_ptr` here. On allocation, it is uninit, and
541            // will be written to on the first call of `fri_single_reduced_opening_eval` for this `ov_ptr`.
542            let buf = builder.array(local.len());
543            let width = buf.len();
544            builder.iter_ptr_set(&ov_ptrs, ptr_vec[1], buf);
545
546            if !builder.flags.static_only {
547                let width = width.get_var();
548                // This is dynamic only so safe to cast.
549                let width_f = builder.unsafe_cast_var_to_felt(width);
550                let bits = builder.num2bits_f(width_f, MAX_LOG_WIDTH as u32);
551                let mat_alpha_pow: Ext<_, _> = builder.eval(C::EF::ONE.cons());
552                for i in 0..MAX_LOG_WIDTH {
553                    let bit = builder.get(&bits, i);
554                    builder.if_eq(bit, RVar::one()).then(|builder| {
555                        let to_mul = builder.get(&pow_of_alpha, i);
556                        builder.assign(&mat_alpha_pow, mat_alpha_pow * to_mul);
557                    });
558                }
559                builder.iter_ptr_set(&mat_alpha_pows, ptr_vec[2], mat_alpha_pow);
560            }
561        });
562        builder
563            .range(0, round.mats.len())
564            .for_each(|i_vec, builder| {
565                let i = i_vec[0];
566                let perm_i = to_perm_index(builder, i);
567                let mat = builder.get(&round.mats, perm_i.clone());
568
569                let domain = mat.domain;
570                let dim = DimensionsVariable::<C> {
571                    log_height: builder.eval(domain.log_n + RVar::from(log_blowup)),
572                };
573                builder.set_value(&batch_dims, i, dim);
574                let perm_ov_ptr = builder.get(&ov_ptrs, perm_i);
575                // Note both `ov_ptrs` and `perm_ov_ptrs` point to the same memory.
576                builder.set_value(&perm_ov_ptrs, i, perm_ov_ptr);
577            });
578        builder.iter_ptr_set(
579            &ret,
580            ptr_vec[1],
581            RoundContext {
582                ov_ptrs,
583                perm_ov_ptrs,
584                batch_dims,
585                mat_alpha_pows,
586                log_batch_max_height,
587            },
588        );
589    });
590    ret
591}
592
593pub mod tests {
594    use std::cmp::Reverse;
595
596    use itertools::Itertools;
597    use openvm_circuit::arch::instructions::program::Program;
598    use openvm_native_compiler::{
599        asm::AsmBuilder,
600        ir::{Array, RVar, DIGEST_SIZE},
601    };
602    use openvm_stark_backend::{
603        config::{StarkGenericConfig, Val},
604        p3_challenger::{CanObserve, FieldChallenger},
605        p3_commit::{Pcs, TwoAdicMultiplicativeCoset},
606        p3_matrix::dense::RowMajorMatrix,
607    };
608    use openvm_stark_sdk::{
609        config::baby_bear_poseidon2::{default_engine, BabyBearPoseidon2Config},
610        p3_baby_bear::BabyBear,
611    };
612    use rand::rngs::OsRng;
613
614    use crate::{
615        challenger::{duplex::DuplexChallengerVariable, CanObserveDigest, FeltChallenger},
616        commit::PcsVariable,
617        digest::DigestVariable,
618        fri::{
619            types::TwoAdicPcsRoundVariable, TwoAdicFriPcsVariable,
620            TwoAdicMultiplicativeCosetVariable,
621        },
622        hints::{Hintable, InnerFriProof, InnerVal},
623        utils::const_fri_config,
624    };
625
626    #[allow(dead_code)]
627    pub fn build_test_fri_with_cols_and_log2_rows(
628        nb_cols: usize,
629        nb_log2_rows: usize,
630    ) -> (Program<BabyBear>, Vec<Vec<BabyBear>>) {
631        type SC = BabyBearPoseidon2Config;
632        type F = Val<SC>;
633        type EF = <SC as StarkGenericConfig>::Challenge;
634        type Challenger = <SC as StarkGenericConfig>::Challenger;
635        type ScPcs = <SC as StarkGenericConfig>::Pcs;
636
637        let mut rng = &mut OsRng;
638        let log_degrees = &[nb_log2_rows];
639        let engine = default_engine();
640        let pcs = engine.config.pcs();
641        let perm = engine.perm;
642
643        // Generate proof.
644        let domains_and_polys = log_degrees
645            .iter()
646            .map(|&d| {
647                (
648                    <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, 1 << d),
649                    RowMajorMatrix::<F>::rand(&mut rng, 1 << d, nb_cols),
650                )
651            })
652            .sorted_by_key(|(dom, _)| Reverse(dom.log_n))
653            .collect::<Vec<_>>();
654        let (commit, data) = <ScPcs as Pcs<EF, Challenger>>::commit(pcs, domains_and_polys.clone());
655        let mut challenger = Challenger::new(perm.clone());
656        challenger.observe(commit);
657        let zeta = challenger.sample_ext_element::<EF>();
658        let points = domains_and_polys
659            .iter()
660            .map(|_| vec![zeta])
661            .collect::<Vec<_>>();
662        let (opening, proof) = pcs.open(vec![(&data, points)], &mut challenger);
663
664        // Verify proof.
665        let mut challenger = Challenger::new(perm.clone());
666        challenger.observe(commit);
667        challenger.sample_ext_element::<EF>();
668        let os: Vec<(TwoAdicMultiplicativeCoset<F>, Vec<_>)> = domains_and_polys
669            .iter()
670            .zip(&opening[0])
671            .map(|((domain, _), mat_openings)| (*domain, vec![(zeta, mat_openings[0].clone())]))
672            .collect();
673        pcs.verify(vec![(commit, os.clone())], &proof, &mut challenger)
674            .unwrap();
675
676        // Test the recursive Pcs.
677        let mut builder = AsmBuilder::<F, EF>::default();
678        let config = const_fri_config(&mut builder, &engine.fri_params);
679        let pcs_var = TwoAdicFriPcsVariable { config };
680        let rounds =
681            builder.constant::<Array<_, TwoAdicPcsRoundVariable<_>>>(vec![(commit, os.clone())]);
682
683        // Test natural domain for degree.
684        for log_d_val in log_degrees.iter() {
685            let log_d = *log_d_val;
686            let domain = pcs_var.natural_domain_for_log_degree(&mut builder, RVar::from(log_d));
687
688            let domain_val =
689                <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, 1 << log_d_val);
690
691            let expected_domain: TwoAdicMultiplicativeCosetVariable<_> =
692                builder.constant(domain_val);
693
694            builder.assert_eq::<TwoAdicMultiplicativeCosetVariable<_>>(domain, expected_domain);
695        }
696
697        // Test proof verification.
698        let proofvar = InnerFriProof::read(&mut builder);
699        let mut challenger = DuplexChallengerVariable::new(&mut builder);
700        let commit = <[InnerVal; DIGEST_SIZE]>::from(commit).to_vec();
701        let commit = DigestVariable::Felt(builder.constant::<Array<_, _>>(commit));
702        challenger.observe_digest(&mut builder, commit);
703        challenger.sample_ext(&mut builder);
704        pcs_var.verify(
705            &mut builder,
706            rounds,
707            proofvar,
708            RVar::from(nb_log2_rows),
709            &mut challenger,
710        );
711        builder.halt();
712
713        let program = builder.compile_isa();
714        let mut witness_stream = Vec::new();
715        witness_stream.extend(proof.write());
716        (program, witness_stream)
717    }
718
719    #[test]
720    fn test_two_adic_fri_pcs_single_batch() {
721        let (program, witness) = build_test_fri_with_cols_and_log2_rows(10, 10);
722        openvm_native_circuit::execute_program(program, witness);
723    }
724}