p3_commit/
domain.rs
1use alloc::vec::Vec;
2
3use itertools::Itertools;
4use p3_field::{
5 batch_multiplicative_inverse, cyclic_subgroup_coset_known_order, ExtensionField, Field,
6 TwoAdicField,
7};
8use p3_matrix::dense::RowMajorMatrix;
9use p3_matrix::Matrix;
10use p3_util::{log2_ceil_usize, log2_strict_usize};
11
/// Selector polynomials of a domain, together with the inverse of its vanishing polynomial.
///
/// `T` is either a single value (evaluations at one point, as produced by
/// `PolynomialSpace::selectors_at_point`) or a `Vec` holding one value per point of a
/// coset (as produced by `PolynomialSpace::selectors_on_coset`).
#[derive(Clone, Debug)]
pub struct LagrangeSelectors<T> {
    /// Selector that, on the domain, is nonzero only at the first point.
    pub is_first_row: T,
    /// Selector that, on the domain, is nonzero only at the last point.
    pub is_last_row: T,
    /// Selector that vanishes at the last point of the domain.
    pub is_transition: T,
    /// Inverse of the domain's vanishing polynomial.
    pub inv_zeroifier: T,
}
19
20pub trait PolynomialSpace: Copy {
21 type Val: Field;
22
23 fn size(&self) -> usize;
24
25 fn first_point(&self) -> Self::Val;
26
27 fn next_point<Ext: ExtensionField<Self::Val>>(&self, x: Ext) -> Option<Ext>;
29
30 fn create_disjoint_domain(&self, min_size: usize) -> Self;
33
34 fn split_domains(&self, num_chunks: usize) -> Vec<Self>;
37 fn split_evals(
40 &self,
41 num_chunks: usize,
42 evals: RowMajorMatrix<Self::Val>,
43 ) -> Vec<RowMajorMatrix<Self::Val>>;
44
45 fn zp_at_point<Ext: ExtensionField<Self::Val>>(&self, point: Ext) -> Ext;
46
47 fn selectors_at_point<Ext: ExtensionField<Self::Val>>(
49 &self,
50 point: Ext,
51 ) -> LagrangeSelectors<Ext>;
52
53 fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors<Vec<Self::Val>>;
55}
56
57#[derive(Copy, Clone, Debug)]
58pub struct TwoAdicMultiplicativeCoset<Val: TwoAdicField> {
59 pub log_n: usize,
60 pub shift: Val,
61}
62
63impl<Val: TwoAdicField> TwoAdicMultiplicativeCoset<Val> {
64 fn gen(&self) -> Val {
65 Val::two_adic_generator(self.log_n)
66 }
67}
68
69impl<Val: TwoAdicField> PolynomialSpace for TwoAdicMultiplicativeCoset<Val> {
70 type Val = Val;
71
72 fn size(&self) -> usize {
73 1 << self.log_n
74 }
75
76 fn first_point(&self) -> Self::Val {
77 self.shift
78 }
79 fn next_point<Ext: ExtensionField<Val>>(&self, x: Ext) -> Option<Ext> {
80 Some(x * self.gen())
81 }
82
83 fn create_disjoint_domain(&self, min_size: usize) -> Self {
84 Self {
85 log_n: log2_ceil_usize(min_size),
86 shift: self.shift * Val::GENERATOR,
87 }
88 }
89 fn split_domains(&self, num_chunks: usize) -> Vec<Self> {
90 let log_chunks = log2_strict_usize(num_chunks);
91 (0..num_chunks)
92 .map(|i| Self {
93 log_n: self.log_n - log_chunks,
94 shift: self.shift * self.gen().exp_u64(i as u64),
95 })
96 .collect()
97 }
98
99 fn split_evals(
100 &self,
101 num_chunks: usize,
102 evals: RowMajorMatrix<Self::Val>,
103 ) -> Vec<RowMajorMatrix<Self::Val>> {
104 (0..num_chunks)
106 .map(|i| {
107 evals
108 .as_view()
109 .vertically_strided(num_chunks, i)
110 .to_row_major_matrix()
111 })
112 .collect()
113 }
114 fn zp_at_point<Ext: ExtensionField<Val>>(&self, point: Ext) -> Ext {
115 (point * self.shift.inverse()).exp_power_of_2(self.log_n) - Ext::ONE
116 }
117
118 fn selectors_at_point<Ext: ExtensionField<Val>>(&self, point: Ext) -> LagrangeSelectors<Ext> {
119 let unshifted_point = point * self.shift.inverse();
120 let z_h = unshifted_point.exp_power_of_2(self.log_n) - Ext::ONE;
121 LagrangeSelectors {
122 is_first_row: z_h / (unshifted_point - Ext::ONE),
123 is_last_row: z_h / (unshifted_point - self.gen().inverse()),
124 is_transition: unshifted_point - self.gen().inverse(),
125 inv_zeroifier: z_h.inverse(),
126 }
127 }
128
129 fn selectors_on_coset(&self, coset: Self) -> LagrangeSelectors<Vec<Val>> {
130 assert_eq!(self.shift, Val::ONE);
131 assert_ne!(coset.shift, Val::ONE);
132 assert!(coset.log_n >= self.log_n);
133 let rate_bits = coset.log_n - self.log_n;
134
135 let s_pow_n = coset.shift.exp_power_of_2(self.log_n);
136 let evals = Val::two_adic_generator(rate_bits)
138 .powers()
139 .take(1 << rate_bits)
140 .map(|x| s_pow_n * x - Val::ONE)
141 .collect_vec();
142
143 let xs = cyclic_subgroup_coset_known_order(coset.gen(), coset.shift, 1 << coset.log_n)
144 .collect_vec();
145
146 let single_point_selector = |i: u64| {
147 let coset_i = self.gen().exp_u64(i);
148 let denoms = xs.iter().map(|&x| x - coset_i).collect_vec();
149 let invs = batch_multiplicative_inverse(&denoms);
150 evals
151 .iter()
152 .cycle()
153 .zip(invs)
154 .map(|(&z_h, inv)| z_h * inv)
155 .collect_vec()
156 };
157
158 let subgroup_last = self.gen().inverse();
159
160 LagrangeSelectors {
161 is_first_row: single_point_selector(0),
162 is_last_row: single_point_selector((1 << self.log_n) - 1),
163 is_transition: xs.into_iter().map(|x| x - subgroup_last).collect(),
164 inv_zeroifier: batch_multiplicative_inverse(&evals)
165 .into_iter()
166 .cycle()
167 .take(1 << coset.log_n)
168 .collect(),
169 }
170 }
171}