openvm_native_recursion/fri/domain.rs

1use openvm_native_compiler::prelude::*;
2use openvm_stark_backend::{
3    p3_commit::{LagrangeSelectors, TwoAdicMultiplicativeCoset},
4    p3_field::{Field, FieldAlgebra, TwoAdicField},
5};
6
7use super::types::FriConfigVariable;
8use crate::commit::PolynomialSpaceVariable;
9
/// In-circuit (DSL) representation of a two-adic multiplicative coset `shift * <g>`, where `<g>`
/// is the multiplicative subgroup of order `2^log_n`.
///
/// Reference: [`openvm_stark_backend::p3_commit::TwoAdicMultiplicativeCoset`]
#[derive(DslVariable, Clone)]
pub struct TwoAdicMultiplicativeCosetVariable<C: Config> {
    /// Base-2 logarithm of the number of points in the coset.
    pub log_n: Usize<C::N>,
    /// Coset shift; this is also the first point of the coset.
    pub shift: Felt<C::F>,
    /// Generator of the subgroup of order `2^log_n` underlying the coset.
    pub g: Felt<C::F>,
}
17
18impl<C: Config> TwoAdicMultiplicativeCosetVariable<C> {
19    pub fn first_point(&self) -> Felt<C::F> {
20        self.shift
21    }
22
23    pub fn gen(&self) -> Felt<C::F> {
24        self.g
25    }
26}
27
28impl<C: Config> FromConstant<C> for TwoAdicMultiplicativeCosetVariable<C>
29where
30    C::F: TwoAdicField,
31{
32    type Constant = TwoAdicMultiplicativeCoset<C::F>;
33
34    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
35        let g_val = C::F::two_adic_generator(value.log_n);
36        TwoAdicMultiplicativeCosetVariable::<C> {
37            // builder.eval is necessary to assign a variable in the dynamic mode.
38            log_n: builder.eval(RVar::from(value.log_n)),
39            shift: builder.eval(value.shift),
40            g: builder.eval(g_val),
41        }
42    }
43}
44
45impl<C: Config> PolynomialSpaceVariable<C> for TwoAdicMultiplicativeCosetVariable<C>
46where
47    C::F: TwoAdicField,
48{
49    type Constant = TwoAdicMultiplicativeCoset<C::F>;
50
51    fn next_point(
52        &self,
53        builder: &mut Builder<C>,
54        point: Ext<<C as Config>::F, <C as Config>::EF>,
55    ) -> Ext<<C as Config>::F, <C as Config>::EF> {
56        builder.eval(point * self.gen())
57    }
58
59    fn selectors_at_point(
60        &self,
61        builder: &mut Builder<C>,
62        point: Ext<<C as Config>::F, <C as Config>::EF>,
63    ) -> LagrangeSelectors<Ext<<C as Config>::F, <C as Config>::EF>> {
64        let unshifted_point: Ext<_, _> = builder.eval(point * self.shift.inverse());
65        let z_h_expr =
66            builder.exp_power_of_2_v::<Ext<_, _>>(unshifted_point, self.log_n.clone()) - C::EF::ONE;
67        let z_h: Ext<_, _> = builder.eval(z_h_expr);
68
69        LagrangeSelectors {
70            is_first_row: builder.eval(z_h / (unshifted_point - C::EF::ONE)),
71            is_last_row: builder.eval(z_h / (unshifted_point - self.gen().inverse())),
72            is_transition: builder.eval(unshifted_point - self.gen().inverse()),
73            inv_zeroifier: builder.eval(z_h.inverse()),
74        }
75    }
76
77    fn zp_at_point(
78        &self,
79        builder: &mut Builder<C>,
80        point: Ext<<C as Config>::F, <C as Config>::EF>,
81    ) -> Ext<<C as Config>::F, <C as Config>::EF> {
82        let unshifted_power =
83            builder.exp_power_of_2_v::<Ext<_, _>>(point * self.shift.inverse(), self.log_n.clone());
84        builder.eval(unshifted_power - C::EF::ONE)
85    }
86
87    fn split_domains(
88        &self,
89        builder: &mut Builder<C>,
90        log_num_chunks: impl Into<RVar<C::N>>,
91        num_chunks: impl Into<RVar<C::N>>,
92    ) -> Array<C, Self> {
93        let log_num_chunks = log_num_chunks.into();
94        let num_chunks = num_chunks.into();
95        let log_n = builder.eval_expr(self.log_n.clone() - log_num_chunks);
96
97        let g_dom = self.gen();
98        let g = builder.exp_power_of_2_v::<Felt<C::F>>(g_dom, log_num_chunks);
99
100        let domain_power: Felt<_> = builder.eval(C::F::ONE);
101
102        let domains = builder.array(num_chunks);
103
104        builder.range(0, num_chunks).for_each(|i_vec, builder| {
105            let log_n = builder.eval(log_n);
106            let domain = TwoAdicMultiplicativeCosetVariable {
107                log_n,
108                shift: builder.eval(self.shift * domain_power),
109                g,
110            };
111            // ATTENTION: here must use `builder.set_value`. `builder.set` will convert `Usize::Const`
112            // to `Usize::Var` because it calls `builder.eval`.
113            builder.set_value(&domains, i_vec[0], domain);
114            builder.assign(&domain_power, domain_power * g_dom);
115        });
116
117        domains
118    }
119
120    fn split_domains_const(&self, builder: &mut Builder<C>, log_num_chunks: usize) -> Vec<Self> {
121        let num_chunks = 1 << log_num_chunks;
122        let log_n: Usize<_> =
123            builder.eval(self.log_n.clone() - C::N::from_canonical_usize(log_num_chunks));
124
125        let g_dom = self.gen();
126        let g = builder.exp_power_of_2_v::<Felt<C::F>>(g_dom, log_num_chunks);
127
128        let domain_power: Felt<_> = builder.eval(C::F::ONE);
129        let mut domains = vec![];
130
131        for _ in 0..num_chunks {
132            domains.push(TwoAdicMultiplicativeCosetVariable {
133                log_n: log_n.clone(),
134                shift: builder.eval(self.shift * domain_power),
135                g,
136            });
137            builder.assign(&domain_power, domain_power * g_dom);
138        }
139        domains
140    }
141
142    fn create_disjoint_domain(
143        &self,
144        builder: &mut Builder<C>,
145        log_degree: RVar<<C as Config>::N>,
146        config: Option<FriConfigVariable<C>>,
147    ) -> Self {
148        let domain = config.unwrap().get_subgroup(builder, log_degree);
149        TwoAdicMultiplicativeCosetVariable {
150            log_n: domain.log_n,
151            shift: builder.eval(self.shift * C::F::GENERATOR),
152            g: domain.g,
153        }
154    }
155}
156
#[cfg(test)]
pub(crate) mod tests {
    use openvm_native_circuit::execute_program;
    use openvm_native_compiler::asm::AsmBuilder;
    use openvm_stark_backend::{
        config::{Domain, StarkGenericConfig, Val},
        p3_commit::{Pcs, PolynomialSpace},
        p3_field::PrimeField,
    };
    use openvm_stark_sdk::config::{
        baby_bear_poseidon2::{config_from_perm, default_perm, BabyBearPoseidon2Config},
        fri_params::SecurityParameters,
    };
    use rand::{thread_rng, Rng};

    use super::*;
    use crate::utils::const_fri_config;

    /// Emits circuit assertions that `domain` matches the native reference `domain_val`.
    ///
    /// The checks cover:
    /// - equal `log_n` and `shift`;
    /// - all four Lagrange selectors at `zeta_val`;
    /// - the vanishing polynomial at `zeta_val`.
    pub(crate) fn domain_assertions<F: TwoAdicField + PrimeField, C: Config<N = F, F = F>>(
        builder: &mut Builder<C>,
        domain: &TwoAdicMultiplicativeCosetVariable<C>,
        domain_val: &TwoAdicMultiplicativeCoset<F>,
        zeta_val: C::EF,
    ) {
        // Assert the domain parameters are the same.
        builder.assert_var_eq(
            domain.log_n.clone(),
            F::from_canonical_usize(domain_val.log_n),
        );
        builder.assert_felt_eq(domain.shift, domain_val.shift);

        // Lift the evaluation point into the circuit.
        let zeta: Ext<_, _> = builder.eval(zeta_val.cons());

        // Compare every selector value of the reference and the builder.
        let sels_expected = domain_val.selectors_at_point(zeta_val);
        let sels = domain.selectors_at_point(builder, zeta);
        builder.assert_ext_eq(sels.is_first_row, sels_expected.is_first_row.cons());
        builder.assert_ext_eq(sels.is_last_row, sels_expected.is_last_row.cons());
        builder.assert_ext_eq(sels.is_transition, sels_expected.is_transition.cons());
        builder.assert_ext_eq(sels.inv_zeroifier, sels_expected.inv_zeroifier.cons());

        let zp_val = domain_val.zp_at_point(zeta_val);
        let zp = domain.zp_at_point(builder, zeta);
        builder.assert_ext_eq(zp, zp_val.cons());
    }

    /// Builds a program that checks circuit domain operations against their native
    /// counterparts for several domain sizes, then executes it.
    fn test_domain_impl(static_only: bool) {
        type SC = BabyBearPoseidon2Config;
        type F = Val<SC>;
        type EF = <SC as StarkGenericConfig>::Challenge;
        type Challenger = <SC as StarkGenericConfig>::Challenger;
        type ScPcs = <SC as StarkGenericConfig>::Pcs;

        let mut rng = thread_rng();
        let security_params = SecurityParameters::standard_fast();
        let config = config_from_perm(&default_perm(), security_params.clone());
        let pcs = config.pcs();
        let natural_domain_for_degree = |degree: usize| -> Domain<SC> {
            <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, degree)
        };

        // Initialize a builder.
        let mut builder = AsmBuilder::<F, EF>::default();
        builder.flags.static_only = static_only;

        let config_var = const_fri_config(&mut builder, &security_params.fri_params);
        for i in 0..5 {
            let log_d_val = 10 + i;

            let log_quotient_degree = 2;

            // Initialize a reference domain.
            let domain_val = natural_domain_for_degree(1 << log_d_val);
            let domain = builder.constant(domain_val);

            let zeta_val = rng.gen::<EF>();
            domain_assertions(&mut builder, &domain, &domain_val, zeta_val);

            // Try a shifted domain.
            let disjoint_domain_val =
                domain_val.create_disjoint_domain(1 << (log_d_val + log_quotient_degree));
            let disjoint_domain = builder.constant(disjoint_domain_val);
            domain_assertions(
                &mut builder,
                &disjoint_domain,
                &disjoint_domain_val,
                zeta_val,
            );

            let log_degree = log_d_val + log_quotient_degree;
            let disjoint_domain_gen = domain.create_disjoint_domain(
                &mut builder,
                log_degree.into(),
                Some(config_var.clone()),
            );
            domain_assertions(
                &mut builder,
                &disjoint_domain_gen,
                &disjoint_domain_val,
                zeta_val,
            );

            // Now try split domains
            let qc_domains_val = disjoint_domain_val.split_domains(1 << log_quotient_degree);
            for dom_val in qc_domains_val.iter() {
                let dom = builder.constant(*dom_val);
                domain_assertions(&mut builder, &dom, dom_val, zeta_val);
            }

            // Test the splitting of domains by the builder.
            let quotient_size = 1 << log_quotient_degree;
            let qc_domains =
                disjoint_domain.split_domains(&mut builder, log_quotient_degree, quotient_size);
            for (i, dom_val) in qc_domains_val.iter().enumerate() {
                let dom = builder.get(&qc_domains, i);
                domain_assertions(&mut builder, &dom, dom_val, zeta_val);
            }
        }
        builder.halt();

        let program = builder.compile_isa();
        execute_program(program, vec![]);
    }

    #[test]
    fn test_domain_static() {
        test_domain_impl(true);
    }

    #[test]
    fn test_domain_dynamic() {
        test_domain_impl(false);
    }
}