openvm_native_recursion/fri/
two_adic_pcs.rs

1use openvm_native_compiler::prelude::*;
2use openvm_native_compiler_derive::iter_zip;
3use openvm_stark_backend::p3_field::{
4    coset::TwoAdicMultiplicativeCoset, BasedVectorSpace, PrimeCharacteristicRing, TwoAdicField,
5};
6use p3_symmetric::Hash;
7
8use super::{
9    types::{
10        DimensionsVariable, FriConfigVariable, TwoAdicPcsMatsVariable, TwoAdicPcsRoundVariable,
11    },
12    verify_batch, verify_query, NestedOpenedValues, TwoAdicMultiplicativeCosetVariable,
13};
14use crate::{
15    challenger::ChallengerVariable, commit::PcsVariable, digest::DigestVariable,
16    fri::types::FriProofVariable,
17};
18
19/// The maximum two-adicity of Felt `C::F`. This means `C::F` does not have a multiplicative
20/// subgroup of order 2^{MAX_TWO_ADICITY + 1}. Currently set to 27 for BabyBear.
///
/// `verify_two_adic_pcs` also sizes its per-log-height `ro` and `alpha_pow` buffers with
/// `MAX_TWO_ADICITY + 1` entries.
21pub const MAX_TWO_ADICITY: usize = 27;
22
/// Emits into `builder` the verification of a two-adic FRI PCS opening proof.
///
/// In order, the generated program:
/// 1. asserts `config.log_final_poly_len == 0` and `proof.final_poly.len() == 1`;
/// 2. observes every opened value of every matrix in `rounds` with `challenger` and samples the
///    batching challenge `alpha`;
/// 3. asserts the lengths of `proof.query_proofs`, `proof.commit_phase_commits` and
///    `proof.commit_pow_witnesses`, then for each commit-phase commitment observes the digest,
///    checks the matching proof-of-work witness and samples a folding challenge `beta`;
/// 4. observes `proof.final_poly` and checks `proof.query_pow_witness`;
/// 5. for each query proof, samples the query index bits, accumulates the reduced openings
///    `ro[log_height]`, calls `verify_batch` once per round, and asserts that the result of
///    `verify_query` equals the constant term of `proof.final_poly`.
///
/// # Arguments
/// - `builder`: program builder that receives the emitted operations.
/// - `config`: FRI parameters; this function reads `log_final_poly_len`, `log_blowup`,
///   `num_queries`, `commit_proof_of_work_bits` and `query_proof_of_work_bits`.
/// - `rounds`: one entry per commitment, holding `batch_commit` and the opened `mats`.
/// - `proof`: the FRI proof to check.
/// - `log_max_height`: asserted to equal the number of commit-phase commitments; the query index
///   has `log_max_height + log_blowup` bits.
/// - `challenger`: Fiat-Shamir transcript, mutated by every observe/sample listed above.
///
23/// Notes:
24/// 1. FieldMerkleTreeMMCS sorts traces by height in descending order when committing data.
25/// 2. **Required** that `C::F` has two-adicity <= [MAX_TWO_ADICITY]. In particular this implies
26///    that all LDE matrices have `log2(lde_height) <= MAX_TWO_ADICITY`.
27/// 3. **Required** that the maximum trace height is `2^log_max_height - 1`.
///    NOTE(review): `tests::build_test_fri_with_cols_and_log2_rows` in this file commits a
///    trace of height `2^nb_log2_rows` and passes `log_max_height = nb_log2_rows`; confirm
///    whether the `- 1` above is intended.
28///
29/// Reference for note 1:
30/// <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/merkle_tree.rs#L53>
31/// So traces are sorted in `opening_proof`.
32///
33/// 4. FieldMerkleTreeMMCS::poseidon2 keeps the raw values in the original order. So traces are not
34///    sorted in `opened_values`.
35///
36/// Reference for note 4:
37/// <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/mmcs.rs#L87>
38/// <https://github.com/Plonky3/Plonky3/blob/27b3127dab047e07145c38143379edec2960b3e1/merkle-tree/src/merkle_tree.rs#L100>
39/// <https://github.com/Plonky3/Plonky3/blob/784b7dd1fa87c1202e63350cc8182d7c5327a7af/fri/src/verifier.rs#L22>
40pub fn verify_two_adic_pcs<C: Config>(
41    builder: &mut Builder<C>,
42    config: &FriConfigVariable<C>,
43    rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
44    proof: FriProofVariable<C>,
45    log_max_height: RVar<C::N>,
46    challenger: &mut impl ChallengerVariable<C>,
47) where
48    C::F: TwoAdicField,
49    C::EF: TwoAdicField,
50{
51    // Currently do not support other final poly len
52    builder.assert_var_eq(RVar::from(config.log_final_poly_len), RVar::zero());
53    // The `proof.final_poly` length is in general `2^{log_final_poly_len}`.
54    // We require `log_final_poly_len = 0`, so `proof.final_poly` has length `1`.
55    builder.assert_usize_eq(proof.final_poly.len(), RVar::one());
56    // Constant term of final poly
57    let final_poly_ct = builder.get(&proof.final_poly, 0);
58
59    let g = builder.generator();
60
61    let log_blowup = config.log_blowup;
    // Observe all opened values (as base-field elements) before sampling `alpha`.
62    iter_zip!(builder, rounds).for_each(|ptr_vec, builder| {
63        let round = builder.iter_ptr_get(&rounds, ptr_vec[0]);
64        iter_zip!(builder, round.mats).for_each(|ptr_vec, builder| {
65            let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]);
66            iter_zip!(builder, mat.values).for_each(|ptr_vec, builder| {
67                let value = builder.iter_ptr_get(&mat.values, ptr_vec[0]);
68                iter_zip!(builder, value).for_each(|ptr_vec, builder| {
69                    if builder.flags.static_only {
70                        let ext = builder.iter_ptr_get(&value, ptr_vec[0]);
71                        let arr = builder.ext2felt(ext);
72                        challenger.observe_slice(builder, arr);
73                    } else {
                        // Dynamic mode: load each of the `C::EF::DIMENSION` base-field limbs of
                        // the extension element directly from memory at `ptr` and observe it.
74                        let ptr = ptr_vec[0];
75                        for i in 0..C::EF::DIMENSION {
76                            let f: Felt<_> = builder.uninit();
77                            builder.load(
78                                f,
79                                Ptr {
80                                    address: ptr.variable(),
81                                },
82                                MemIndex {
83                                    index: RVar::from(i),
84                                    offset: 0,
85                                    size: 1,
86                                },
87                            );
88                            challenger.observe(builder, f);
89                        }
90                    }
91                });
92            });
93        });
94    });
95    let alpha = challenger.sample_ext(builder);
96    if builder.flags.static_only {
97        builder.ext_reduce_circuit(alpha);
98    }
99
100    builder.cycle_tracker_start("stage-d-verifier-verify");
101    // **ATTENTION**: always check shape of user inputs.
102    builder.assert_usize_eq(proof.query_proofs.len(), RVar::from(config.num_queries));
103    builder.assert_usize_eq(proof.commit_phase_commits.len(), log_max_height);
104    builder.assert_usize_eq(proof.commit_pow_witnesses.len(), log_max_height);
105    let betas: Array<C, Ext<C::F, C::EF>> = builder.array(log_max_height);
106    let betas_squared: Array<C, Ext<C::F, C::EF>> = builder.array(log_max_height);
107    // `i_plus_one_arr[i] = i + 1`. This is needed to add "enumerate" to `iter_zip!`
108    // There is no risk of overflow because `log_max_height` is much less than `C::N::MODULUS`
109    let i_plus_one_arr: Array<C, Usize<C::N>> = builder.array(log_max_height);
110    let i_var: Usize<C::N> = builder.eval(C::N::ZERO);
    // Commit phase: per commitment, observe it, check its proof-of-work witness, then sample
    // `beta` and store `beta` and `beta^2`.
111    iter_zip!(
112        builder,
113        proof.commit_phase_commits,
114        proof.commit_pow_witnesses,
115        betas,
116        betas_squared,
117        i_plus_one_arr
118    )
119    .for_each(|ptr_vec, builder| {
120        let [comm_ptr, commit_pow_ptr, beta_ptr, beta_sq_ptr, i_plus_one_ptr] =
121            ptr_vec.try_into().unwrap();
122
123        let comm = builder.iter_ptr_get(&proof.commit_phase_commits, comm_ptr);
124        let commit_pow = builder.iter_ptr_get(&proof.commit_pow_witnesses, commit_pow_ptr);
125        challenger.observe_digest(builder, comm);
126        challenger.check_witness(builder, config.commit_proof_of_work_bits, commit_pow);
127        let sample = challenger.sample_ext(builder);
128        builder.iter_ptr_set(&betas, beta_ptr, sample);
129        builder.iter_ptr_set(&betas_squared, beta_sq_ptr, sample * sample);
130        builder.assign(&i_var, i_var.clone() + C::N::ONE);
131        // Note: this is a deep clone when `builder.flags.static == true`
132        let i_plus_one_clone: Usize<C::N> = builder.eval(i_var.clone());
133        builder.iter_ptr_set(&i_plus_one_arr, i_plus_one_ptr, i_plus_one_clone);
134    });
135
136    iter_zip!(builder, proof.final_poly).for_each(|ptr_vec, builder| {
137        let final_poly_elem = builder.iter_ptr_get(&proof.final_poly, ptr_vec[0]);
138        let final_poly_elem_felts = builder.ext2felt(final_poly_elem);
139        challenger.observe_slice(builder, final_poly_elem_felts);
140    });
141
142    challenger.check_witness(
143        builder,
144        config.query_proof_of_work_bits,
145        proof.query_pow_witness,
146    );
147
    // Number of bits of the query index sampled below.
148    let log_max_lde_height = builder.eval_expr(log_max_height + RVar::from(log_blowup));
149    // tag_exp is a shared buffer.
150    let tag_exp: Array<C, Felt<C::F>> = builder.array(log_max_lde_height);
151    let w = config.get_two_adic_generator(builder, log_max_lde_height);
    // Used below as the divisor in the square-and-divide loop that fills `tag_exp`.
152    let max_gen_pow = config.get_two_adic_generator(builder, 1);
153    let one_var: Felt<C::F> = builder.eval(C::F::ONE);
154
155    builder.cycle_tracker_start("pre-compute-rounds-context");
156    let rounds_context = compute_rounds_context(builder, &rounds, log_blowup, alpha);
157    // Only used in static mode.
    // `alpha_pows[i] = alpha^i` for `i` in `0..=max_width`.
158    let alpha_pows = if builder.flags.static_only {
159        let max_width = get_max_matrix_width(builder, &rounds);
160        let mut ret = Vec::with_capacity(max_width + 1);
161        ret.push(C::EF::ONE.cons());
162        for i in 1..=max_width {
163            let curr = builder.eval(ret[i - 1].clone() * alpha);
164            builder.ext_reduce_circuit(curr);
165            ret.push(curr.into());
166        }
167        ret
168    } else {
169        vec![]
170    };
171    builder.cycle_tracker_end("pre-compute-rounds-context");
172
173    // Accumulator of the reduced opening sums, reset per query. The array `ro` is indexed by
174    // log_height.
175    let ro: Array<C, Ext<C::F, C::EF>> = builder.array(MAX_TWO_ADICITY + 1);
176    let alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(MAX_TWO_ADICITY + 1);
177
178    iter_zip!(builder, proof.query_proofs).for_each(|ptr_vec, builder| {
179        let query_proof = builder.iter_ptr_get(&proof.query_proofs, ptr_vec[0]);
180        let index_bits = challenger.sample_bits(builder, log_max_lde_height);
181
182        // We reset the reduced opening accumulator at the start of each query.
183        // We describe what `ro[log_height]` computes per query in pseudo-code, where `log_height`
184        // is log2 of the size of the LDE domain:
        //
        // ```text
        // ro[log_height] = 0
185        // alpha_pow[log_height] = 1
186        // for round in rounds:
187        //   // preserving order of round.mats
        //   for mat in round.mats where (mat.domain.log_n + log_blowup == log_height):
188        //      // g is generator of F
189        //      // w_{log_height} is generator of subgroup of F of order 2^log_height
190        //      x = g * w_{log_height}^{reverse_bits(index >> (log_max_height - log_height),
        //                                           log_height)}
191        //      // reverse_bits(x, bits) takes an unsigned integer x with `bits` bits and returns
        //      // the unsigned integer with the bits of x reversed.
192        //      // x is a rotated evaluation point in a coset of the LDE domain.
193        //      ps_at_x = [claimed evaluation of p at x for each polynomial p corresponding to
194        //                 column of mat]
195        //      // ps_at_x is array of Felt
196        //      for (z, ps_at_z) in zip(mat.points, mat.values):
197        //        // z is an out of domain point in Ext. There may be multiple per round to account
198        //        // for rotations in AIR constraints.
        //        // ps_at_z is array of Ext for [claimed evaluation of p at z for each
199        //        // polynomial p corresponding to column of mat]
200        //        for (p_at_x, p_at_z) in zip(ps_at_x, ps_at_z):
201        //          ro[log_height] += alpha_pow[log_height] * (p_at_x - p_at_z) / (x - z)
202        //          alpha_pow[log_height] *= alpha
        // ```
203        //
204        // The final value of ro[log_height] is the reduced opening value for log_height.
205        if builder.flags.static_only {
206            for j in 0..=MAX_TWO_ADICITY {
207                // ATTENTION: don't use set_value here, Fixed will share the same variable.
208                builder.set(&ro, j, C::EF::ZERO.cons());
209                builder.set(&alpha_pow, j, C::EF::ONE.cons());
210            }
211        } else {
212            let zero_ef = builder.eval(C::EF::ZERO.cons());
213            let one_ef = builder.eval(C::EF::ONE.cons());
214            for j in 0..=MAX_TWO_ADICITY {
215                // Use set_value here to save a copy.
216                builder.set_value(&ro, j, zero_ef);
217                builder.set_value(&alpha_pow, j, one_ef);
218            }
219        }
220        // **ATTENTION**: always check shape of user inputs.
221        builder.assert_usize_eq(query_proof.input_proof.len(), rounds.len());
222
223        // Pre-compute tag_exp
224        builder.cycle_tracker_start("cache-generator-powers");
225        {
226            // truncate index_bits to log_max_lde_height
227            let index_bits_truncated = index_bits.slice(builder, 0, log_max_lde_height);
228
229            // b = index_bits, n = log_max_lde_height
230            // w = generator of order 2^n
231            // we first compute `w ** (b[0] * 2^(n - 1) + ... + b[n - 1])`
232            // using a square-and-multiply algorithm.
233            let res = builder.exp_bits_big_endian(w, &index_bits_truncated);
234
235            // we now compute, for i = n, ..., 1:
236            //   tag_exp[n - i] = g * w ** (b[n - i] * 2^(n - 1) + ... + b[n - 1] * 2^(n - i))
237            // (each step divides out the term of the current bit, then squares `res`)
238            // using a square-and-divide algorithm.
239            // g * res is tag_exp[0]
240            // `tag_exp` is used below as a rotated evaluation point in a coset of the LDE domain.
241            iter_zip!(builder, index_bits_truncated, tag_exp).for_each(|ptr_vec, builder| {
242                builder.iter_ptr_set(&tag_exp, ptr_vec[1], g * res);
243
244                let bit = builder.iter_ptr_get(&index_bits_truncated, ptr_vec[0]);
245                let div = builder.select_f(bit, max_gen_pow, one_var);
246                builder.assign(&res, res / div);
247                builder.assign(&res, res * res);
248            });
249        };
250        builder.cycle_tracker_end("cache-generator-powers");
251
252        iter_zip!(builder, query_proof.input_proof, rounds, rounds_context).for_each(
253            |ptr_vec, builder| {
254                let batch_opening = builder.iter_ptr_get(&query_proof.input_proof, ptr_vec[0]);
255                let round = builder.iter_ptr_get(&rounds, ptr_vec[1]);
256                let round_context = builder.iter_ptr_get(&rounds_context, ptr_vec[2]);
257
258                let batch_commit = round.batch_commit;
259                let mats = round.mats;
260                let RoundContext {
261                    ov_ptrs,
262                    perm_ov_ptrs,
263                    batch_dims,
264                    mat_alpha_pows,
265                    log_batch_max_height,
266                } = round_context;
267
268                // **ATTENTION**: always check shape of user inputs.
269                builder.assert_usize_eq(ov_ptrs.len(), mats.len());
270
271                let hint_id = batch_opening.opened_values.id.clone();
272                // For static to track the offset in the hint space.
273                let mut hint_offset = 0;
274                builder.cycle_tracker_start("compute-reduced-opening");
275                iter_zip!(builder, ov_ptrs, mats, mat_alpha_pows).for_each(|ptr_vec, builder| {
276                    let mat_opening = builder.iter_ptr_get(&ov_ptrs, ptr_vec[0]);
277                    let mat = builder.iter_ptr_get(&mats, ptr_vec[1]);
                    // Only read in dynamic mode; static mode uses the pre-computed `alpha_pows`.
278                    let mat_alpha_pow = if builder.flags.static_only {
279                        builder.uninit()
280                    } else {
281                        builder.iter_ptr_get(&mat_alpha_pows, ptr_vec[2])
282                    };
283                    let mat_points = mat.points;
284                    let mat_values = mat.values;
285                    let log2_domain_size = mat.domain.log_n;
286                    let log_height = builder.eval_expr(log2_domain_size + RVar::from(log_blowup));
287
288                    let cur_ro = builder.get(&ro, log_height);
289                    let cur_alpha_pow = builder.get(&alpha_pow, log_height);
290
291                    builder.cycle_tracker_start("exp-reverse-bits-len");
292                    let height_idx = builder.eval_expr(log_max_lde_height - log_height);
293                    let x = builder.get(&tag_exp, height_idx);
294                    builder.cycle_tracker_end("exp-reverse-bits-len");
295
                    // 0 until `mat_opening` has been filled for this matrix, 1 afterwards.
296                    let is_init: Usize<C::N> = builder.eval(C::N::ZERO);
297                    iter_zip!(builder, mat_points, mat_values).for_each(|ptr_vec, builder| {
298                        let z: Ext<C::F, C::EF> = builder.iter_ptr_get(&mat_points, ptr_vec[0]);
299                        let ps_at_z = builder.iter_ptr_get(&mat_values, ptr_vec[1]);
300
301                        builder.cycle_tracker_start("single-reduced-opening-eval");
302
303                        if builder.flags.static_only {
304                            let n: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
305                            let width = ps_at_z.len().value();
                            // First point of this matrix: copy `width` witness refs of the hint,
                            // starting at `hint_offset`, into `mat_opening`.
306                            if is_init.value() == 0 {
307                                let mat_opening_vals = {
308                                    let witness_refs = builder.get_witness_refs(hint_id.clone());
309                                    let start = hint_offset;
310                                    witness_refs[start..start + width].to_vec()
311                                };
312                                for (i, v) in mat_opening_vals.into_iter().enumerate() {
313                                    builder.set_value(&mat_opening, i, v.into());
314                                }
315                            }
                            // `(p_at_z - p_at_x) / (z - x)` equals the pseudo-code's
                            // `(p_at_x - p_at_z) / (x - z)`.
316                            for (t, alpha_pow) in alpha_pows.iter().take(width).enumerate() {
317                                let p_at_x = builder.get(&mat_opening, t);
318                                let p_at_z = builder.get(&ps_at_z, t);
319                                builder.assign(&n, n + (p_at_z - p_at_x) * alpha_pow.clone());
320                            }
321                            builder.assign(&cur_ro, cur_ro + n / (z - x) * cur_alpha_pow);
322                            builder
323                                .assign(&cur_alpha_pow, cur_alpha_pow * alpha_pows[width].clone());
324                        } else {
325                            let mat_ro = builder.fri_single_reduced_opening_eval(
326                                alpha,
327                                hint_id.get_var(),
328                                is_init.get_var(),
329                                &mat_opening,
330                                &ps_at_z,
331                            );
332                            builder.assign(&cur_ro, cur_ro + (mat_ro * cur_alpha_pow / (z - x)));
333                            builder.assign(&cur_alpha_pow, cur_alpha_pow * mat_alpha_pow);
334                        }
335                        // The buffer `mat_opening` has now been written to, so we set `is_init` to
336                        // 1.
337                        builder.assign(&is_init, C::N::ONE);
338                        builder.cycle_tracker_end("single-reduced-opening-eval");
339                    });
340                    if builder.flags.static_only {
341                        hint_offset += mat_opening.len().value();
342                    }
343                    builder.set_value(&ro, log_height, cur_ro);
344                    builder.set_value(&alpha_pow, log_height, cur_alpha_pow);
345                });
346                builder.cycle_tracker_end("compute-reduced-opening");
347
                // Drop the low `log_max_lde_height - log_batch_max_height` bits of the query
                // index before passing it to `verify_batch` for this round.
348                let bits_reduced: Usize<_> =
349                    builder.eval(log_max_lde_height - log_batch_max_height);
350                let index_bits_shifted_v1 = index_bits.shift(builder, bits_reduced);
351
352                builder.cycle_tracker_start("verify-batch");
353                verify_batch::<C>(
354                    builder,
355                    &batch_commit,
356                    batch_dims,
357                    index_bits_shifted_v1,
358                    &NestedOpenedValues::Felt(perm_ov_ptrs),
359                    &batch_opening.opening_proof,
360                );
361                builder.cycle_tracker_end("verify-batch");
362            },
363        );
364
365        // Note[jpw]: we do not need to assert `ro[log_blowup] = 0` here because we include
366        // `ro[log_blowup]` in the `verify_query` low-degree test. See comments therein.
367
368        let folded_eval = verify_query(
369            builder,
370            config,
371            &proof.commit_phase_commits,
372            &index_bits,
373            &query_proof,
374            &betas,
375            &betas_squared,
376            &ro,
377            log_max_lde_height,
378            &i_plus_one_arr,
379        );
380
381        builder.assert_ext_eq(folded_eval, final_poly_ct);
382    });
383    builder.cycle_tracker_end("stage-d-verifier-verify");
384}
385
386impl<C: Config> FromConstant<C> for TwoAdicPcsRoundVariable<C>
387where
388    C::F: TwoAdicField,
389{
390    type Constant = (
391        Hash<C::F, C::F, DIGEST_SIZE>,
392        Vec<(TwoAdicMultiplicativeCoset<C::F>, Vec<(C::EF, Vec<C::EF>)>)>,
393    );
394
395    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
396        let (commit_val, domains_and_openings_val) = value;
397
398        // Allocate the commitment.
399        let commit = builder.dyn_array::<Felt<_>>(DIGEST_SIZE);
400        let commit_val: [C::F; DIGEST_SIZE] = commit_val.into();
401        for (i, f) in commit_val.into_iter().enumerate() {
402            builder.set(&commit, i, f);
403        }
404
405        let mats = builder
406            .dyn_array::<TwoAdicPcsMatsVariable<C>>(RVar::from(domains_and_openings_val.len()));
407
408        for (i, (domain, opening)) in domains_and_openings_val.into_iter().enumerate() {
409            let domain = builder.constant::<TwoAdicMultiplicativeCosetVariable<_>>(domain);
410
411            let points_val = opening.iter().map(|(p, _)| *p).collect::<Vec<_>>();
412            let values_val = opening.iter().map(|(_, v)| v.clone()).collect::<Vec<_>>();
413            let points: Array<_, Ext<_, _>> = builder.dyn_array(points_val.len());
414            for (j, point) in points_val.into_iter().enumerate() {
415                let el: Ext<_, _> = builder.eval(point.cons());
416                builder.set_value(&points, j, el);
417            }
418            let values: Array<_, Array<_, Ext<_, _>>> = builder.dyn_array(values_val.len());
419            for (j, val) in values_val.into_iter().enumerate() {
420                let tmp = builder.dyn_array(val.len());
421                for (k, v) in val.into_iter().enumerate() {
422                    let el: Ext<_, _> = builder.eval(v.cons());
423                    builder.set_value(&tmp, k, el);
424                }
425                builder.set_value(&values, j, tmp);
426            }
427            let mat = TwoAdicPcsMatsVariable {
428                domain,
429                points,
430                values,
431            };
432            builder.set_value(&mats, i, mat);
433        }
434
435        Self {
436            batch_commit: DigestVariable::Felt(commit),
437            mats,
438            permutation: builder.dyn_array(0),
439        }
440    }
441}
442
/// Two-adic FRI polynomial commitment scheme expressed in the DSL.
///
/// It carries only the FRI configuration; its `PcsVariable::verify` implementation forwards to
/// `verify_two_adic_pcs` with this configuration.
443#[derive(Clone)]
444pub struct TwoAdicFriPcsVariable<C: Config> {
    /// FRI parameters used for verification and for building subgroup domains.
445    pub config: FriConfigVariable<C>,
446}
447
448impl<C: Config> PcsVariable<C> for TwoAdicFriPcsVariable<C>
449where
450    C::F: TwoAdicField,
451    C::EF: TwoAdicField,
452{
453    type Domain = TwoAdicMultiplicativeCosetVariable<C>;
454
455    type Commitment = DigestVariable<C>;
456
457    type Proof = FriProofVariable<C>;
458
459    fn natural_domain_for_log_degree(
460        &self,
461        builder: &mut Builder<C>,
462        log_degree: RVar<C::N>,
463    ) -> Self::Domain {
464        self.config.get_subgroup(builder, log_degree)
465    }
466
467    fn verify(
468        &self,
469        builder: &mut Builder<C>,
470        rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
471        proof: Self::Proof,
472        log_max_height: RVar<C::N>,
473        challenger: &mut impl ChallengerVariable<C>,
474    ) {
475        verify_two_adic_pcs(
476            builder,
477            &self.config,
478            rounds,
479            proof,
480            log_max_height,
481            challenger,
482        )
483    }
484}
485
486fn get_max_matrix_width<C: Config>(
487    builder: &mut Builder<C>,
488    rounds: &Array<C, TwoAdicPcsRoundVariable<C>>,
489) -> usize {
490    let mut ret = 0;
491    for i in 0..rounds.len().value() {
492        let round = builder.get(rounds, i);
493        for j in 0..round.mats.len().value() {
494            let mat = builder.get(&round.mats, j);
495            let local = builder.get(&mat.values, 0);
496            ret = ret.max(local.len().value());
497        }
498    }
499    ret
500}
501
/// Per-round buffers and metadata pre-computed by `compute_rounds_context` and consumed by the
/// query loop of `verify_two_adic_pcs`. All arrays have one entry per matrix of the round.
502#[derive(DslVariable, Clone)]
503struct RoundContext<C: Config> {
504    /// Opened values buffer.
    /// Entry `i` is allocated with as many elements as the first value row of matrix `i`.
505    ov_ptrs: Array<C, Array<C, Felt<C::F>>>,
506    /// Permuted opened values buffer.
    /// Entry `i` is `ov_ptrs[perm(i)]`, so both arrays refer to the same buffers.
507    perm_ov_ptrs: Array<C, Array<C, Felt<C::F>>>,
508    /// Permuted matrix dimensions.
    /// Entry `i` has `log_height = mats[perm(i)].domain.log_n + log_blowup`.
509    batch_dims: Array<C, DimensionsVariable<C>>,
510    /// Alpha pows for each matrix.
    /// Only written in dynamic mode.
511    mat_alpha_pows: Array<C, Ext<C::F, C::EF>>,
512    /// Max height in the matrices.
    /// Computed as `log_n + log_blowup` of the matrix at permuted index 0.
513    log_batch_max_height: Usize<C::N>,
514}
515
/// Pre-computes one [`RoundContext`] per entry of `rounds` for use by the query loop of
/// `verify_two_adic_pcs`.
///
/// Let `perm(k)` be `round.permutation[k]`, or `k` itself when the permutation is empty (static
/// mode always uses `k`). For each round the context holds:
/// - `ov_ptrs[i]`: a freshly allocated buffer with as many entries as `round.mats[i].values[0]`;
/// - `perm_ov_ptrs[i] = ov_ptrs[perm(i)]`;
/// - `batch_dims[i].log_height = round.mats[perm(i)].domain.log_n + log_blowup`;
/// - `mat_alpha_pows[i]`: product of `alpha^(2^k)` over the set bits `k` of the length of
///   `ov_ptrs[i]` (dynamic mode only; not written in static mode);
/// - `log_batch_max_height = round.mats[perm(0)].domain.log_n + log_blowup`.
516fn compute_rounds_context<C: Config>(
517    builder: &mut Builder<C>,
518    rounds: &Array<C, TwoAdicPcsRoundVariable<C>>,
519    log_blowup: usize,
520    alpha: Ext<C::F, C::EF>,
521) -> Array<C, RoundContext<C>> {
522    let ret: Array<C, RoundContext<C>> = builder.array(rounds.len());
523
524    // This maximum is safe because the log width of any matrix in an AIR must fit
525    // within a single field element.
526    const MAX_LOG_WIDTH: usize = 31;
    // `pow_of_alpha[i] = alpha^(2^i)`; only filled in dynamic mode.
527    let pow_of_alpha: Array<C, Ext<_, _>> = builder.array(MAX_LOG_WIDTH);
528    if !builder.flags.static_only {
529        let current: Ext<_, _> = builder.eval(alpha);
530        for i in 0..MAX_LOG_WIDTH {
531            builder.set(&pow_of_alpha, i, current);
532            builder.assign(&current, current * current);
533        }
534    }
535
536    iter_zip!(builder, rounds, ret).for_each(|ptr_vec, builder| {
537        let round = builder.iter_ptr_get(rounds, ptr_vec[0]);
538        let permutation = round.permutation;
        // Maps a position `k` to the matrix index it refers to: `permutation[k]`, or `k` when
        // the permutation is empty.
539        let to_perm_index = |builder: &mut Builder<_>, k: RVar<_>| {
540            // Always no permutation in static mode
541            if builder.flags.static_only {
542                builder.eval(k)
543            } else {
544                let ret: Usize<_> = builder.uninit();
545                builder.if_eq(permutation.len(), RVar::zero()).then_or_else(
546                    |builder| {
547                        builder.assign(&ret, k);
548                    },
549                    |builder| {
550                        let value = builder.get(&permutation, k);
551                        builder.assign(&ret, value);
552                    },
553                );
554                ret
555            }
556        };
557
558        let ov_ptrs: Array<C, Array<C, Felt<C::F>>> = builder.array(round.mats.len());
559        let perm_ov_ptrs: Array<C, Array<C, Felt<C::F>>> = builder.array(round.mats.len());
560        let batch_dims: Array<C, DimensionsVariable<C>> = builder.array(round.mats.len());
561        let mat_alpha_pows: Array<C, Ext<_, _>> = builder.array(round.mats.len());
        // LDE log-height of the matrix at permuted position 0.
562        let log_batch_max_height: Usize<_> = {
563            let log_batch_max_index = to_perm_index(builder, RVar::zero());
564            let mat = builder.get(&round.mats, log_batch_max_index);
565            let domain = mat.domain;
566            builder.eval(domain.log_n + RVar::from(log_blowup))
567        };
568
569        iter_zip!(builder, round.mats, ov_ptrs, mat_alpha_pows).for_each(|ptr_vec, builder| {
570            let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]);
571            let local = builder.get(&mat.values, 0);
572            // We allocate the underlying buffer for the current `ov_ptr` here. On allocation, it is
573            // uninit, and will be written to on the first call of
574            // `fri_single_reduced_opening_eval` for this `ov_ptr`.
575            let buf = builder.array(local.len());
576            let width = buf.len();
577            builder.iter_ptr_set(&ov_ptrs, ptr_vec[1], buf);
578
579            if !builder.flags.static_only {
580                let width = width.get_var();
581                // This is dynamic only so safe to cast.
582                let width_f = builder.unsafe_cast_var_to_felt(width);
583                let bits = builder.num2bits_f(width_f, MAX_LOG_WIDTH as u32);
                // Multiply together `alpha^(2^i)` for every set bit `i` of `width`. This gives
                // `alpha^width`, assuming `num2bits_f` returns bits least-significant first.
584                let mat_alpha_pow: Ext<_, _> = builder.eval(C::EF::ONE.cons());
585                for i in 0..MAX_LOG_WIDTH {
586                    let bit = builder.get(&bits, i);
587                    builder.if_eq(bit, RVar::one()).then(|builder| {
588                        let to_mul = builder.get(&pow_of_alpha, i);
589                        builder.assign(&mat_alpha_pow, mat_alpha_pow * to_mul);
590                    });
591                }
592                builder.iter_ptr_set(&mat_alpha_pows, ptr_vec[2], mat_alpha_pow);
593            }
594        });
        // Second pass, after every `ov_ptrs` entry exists: fill the permuted views.
595        builder
596            .range(0, round.mats.len())
597            .for_each(|i_vec, builder| {
598                let i = i_vec[0];
599                let perm_i = to_perm_index(builder, i);
600                let mat = builder.get(&round.mats, perm_i.clone());
601
602                let domain = mat.domain;
603                let dim = DimensionsVariable::<C> {
604                    log_height: builder.eval(domain.log_n + RVar::from(log_blowup)),
605                };
606                builder.set_value(&batch_dims, i, dim);
607                let perm_ov_ptr = builder.get(&ov_ptrs, perm_i);
608                // Note both `ov_ptrs` and `perm_ov_ptrs` point to the same memory.
609                builder.set_value(&perm_ov_ptrs, i, perm_ov_ptr);
610            });
611        builder.iter_ptr_set(
612            &ret,
613            ptr_vec[1],
614            RoundContext {
615                ov_ptrs,
616                perm_ov_ptrs,
617                batch_dims,
618                mat_alpha_pows,
619                log_batch_max_height,
620            },
621        );
622    });
623    ret
624}
625
/// Program builder and test for the recursive two-adic FRI PCS verifier.
///
/// The module and `build_test_fri_with_cols_and_log2_rows` are `pub`, so the builder can be
/// called from outside this module.
626pub mod tests {
627    use std::cmp::Reverse;
628
629    use itertools::Itertools;
630    use openvm_circuit::arch::instructions::program::Program;
631    use openvm_native_compiler::{
632        asm::AsmBuilder,
633        conversion::CompilerOptions,
634        ir::{Array, RVar, DIGEST_SIZE},
635    };
636    use openvm_stark_backend::{
637        config::{StarkGenericConfig, Val},
638        engine::StarkEngine,
639        p3_challenger::{CanObserve, FieldChallenger},
640        p3_commit::Pcs,
641        p3_field::coset::TwoAdicMultiplicativeCoset,
642        p3_matrix::dense::RowMajorMatrix,
643    };
644    use openvm_stark_sdk::{
645        config::baby_bear_poseidon2::{default_engine, BabyBearPoseidon2Config},
646        p3_baby_bear::BabyBear,
647    };
648    use rand::{rngs::StdRng, SeedableRng};
649
650    use crate::{
651        challenger::{duplex::DuplexChallengerVariable, CanObserveDigest, FeltChallenger},
652        commit::PcsVariable,
653        digest::DigestVariable,
654        fri::{
655            types::TwoAdicPcsRoundVariable, TwoAdicFriPcsVariable,
656            TwoAdicMultiplicativeCosetVariable,
657        },
658        hints::{Hintable, InnerFriProof, InnerVal},
659        utils::const_fri_config,
660    };
661
    /// Builds a program that checks a FRI opening proof with `TwoAdicFriPcsVariable::verify`,
    /// and returns it together with the witness stream holding the serialized proof.
    ///
    /// The committed matrix is generated with
    /// `RowMajorMatrix::rand(&mut rng, 1 << nb_log2_rows, nb_cols)` from an RNG seeded with 0.
    /// The proof is produced and verified natively with the PCS of `default_engine()` before
    /// the recursive program is built. The program also asserts that
    /// `natural_domain_for_log_degree` matches the native natural domain.
662    pub fn build_test_fri_with_cols_and_log2_rows(
663        nb_cols: usize,
664        nb_log2_rows: usize,
665    ) -> (Program<BabyBear>, Vec<Vec<BabyBear>>) {
666        type SC = BabyBearPoseidon2Config;
667        type F = Val<SC>;
668        type EF = <SC as StarkGenericConfig>::Challenge;
669        type Challenger = <SC as StarkGenericConfig>::Challenger;
670        type ScPcs = <SC as StarkGenericConfig>::Pcs;
671
672        let mut rng = StdRng::seed_from_u64(0);
673        let log_degrees = &[nb_log2_rows];
674        let engine = default_engine();
675        let pcs = engine.config().pcs();
676        let perm = engine.perm.clone();
677
678        // Generate proof.
679        let domains_and_polys = log_degrees
680            .iter()
681            .map(|&d| {
682                (
683                    <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, 1 << d),
684                    RowMajorMatrix::<F>::rand(&mut rng, 1 << d, nb_cols),
685                )
686            })
687            .sorted_by_key(|(dom, _)| Reverse(dom.log_size()))
688            .collect::<Vec<_>>();
689        let (commit, data) = <ScPcs as Pcs<EF, Challenger>>::commit(pcs, domains_and_polys.clone());
690        let mut challenger = Challenger::new(perm.clone());
691        challenger.observe(commit);
692        let zeta = challenger.sample_algebra_element::<EF>();
693        let points = domains_and_polys
694            .iter()
695            .map(|_| vec![zeta])
696            .collect::<Vec<_>>();
697        let (opening, proof) = pcs.open(vec![(&data, points)], &mut challenger);
698
699        // Verify proof.
        // A fresh challenger replays the prover's transcript: observe the commitment, then
        // sample (and discard) the opening point.
700        let mut challenger = Challenger::new(perm.clone());
701        challenger.observe(commit);
702        let _ = challenger.sample_algebra_element::<EF>();
703        let os: Vec<(TwoAdicMultiplicativeCoset<F>, Vec<_>)> = domains_and_polys
704            .iter()
705            .zip(&opening[0])
706            .map(|((domain, _), mat_openings)| (*domain, vec![(zeta, mat_openings[0].clone())]))
707            .collect();
708        pcs.verify(vec![(commit, os.clone())], &proof, &mut challenger)
709            .unwrap();
710
711        // Test the recursive Pcs.
712        let mut builder = AsmBuilder::<F, EF>::default();
713        let config = const_fri_config(&mut builder, &engine.fri_params);
714        let pcs_var = TwoAdicFriPcsVariable { config };
715        let rounds =
716            builder.constant::<Array<_, TwoAdicPcsRoundVariable<_>>>(vec![(commit, os.clone())]);
717
718        // Test natural domain for degree.
719        for log_d_val in log_degrees.iter() {
720            let log_d = *log_d_val;
721            let domain = pcs_var.natural_domain_for_log_degree(&mut builder, RVar::from(log_d));
722
723            let domain_val =
724                <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, 1 << log_d_val);
725
726            let expected_domain: TwoAdicMultiplicativeCosetVariable<_> =
727                builder.constant(domain_val);
728
729            builder.assert_eq::<TwoAdicMultiplicativeCosetVariable<_>>(domain, expected_domain);
730        }
731
732        // Test proof verification.
        // The in-program challenger replays the same prefix (observe commitment, sample once)
        // before `verify` is called.
733        let proofvar = InnerFriProof::read(&mut builder);
734        let mut challenger = DuplexChallengerVariable::new(&mut builder);
735        let commit = <[InnerVal; DIGEST_SIZE]>::from(commit).to_vec();
736        let commit = DigestVariable::Felt(builder.constant::<Array<_, _>>(commit));
737        challenger.observe_digest(&mut builder, commit);
738        challenger.sample_ext(&mut builder);
739        pcs_var.verify(
740            &mut builder,
741            rounds,
742            proofvar,
743            RVar::from(nb_log2_rows),
744            &mut challenger,
745        );
746        builder.halt();
747
748        let program =
749            builder.compile_isa_with_options(CompilerOptions::default().with_cycle_tracker());
750        let mut witness_stream = Vec::new();
751        witness_stream.extend(proof.write());
752        (program, witness_stream)
753    }
754
    /// Builds the verifier program for `nb_cols = 10`, `nb_log2_rows = 10` and executes it with
    /// the proof as witness.
755    #[test]
756    fn test_two_adic_fri_pcs_single_batch() {
757        let (program, witness) = build_test_fri_with_cols_and_log2_rows(10, 10);
758        openvm_native_circuit::execute_program(program, witness);
759    }
760}