//! This module provides common utilities, traits and structures for group,
//! field and polynomial arithmetic.

use super::multicore;
pub use ff::Field;
use group::{
    ff::{BatchInvert, PrimeField},
    prime::PrimeCurveAffine,
    Curve, GroupOpsOwned, ScalarMulOwned,
};
use rayon::prelude::*;

use crate::fft::{self, recursive::FFTData};
use halo2curves::msm::msm_best;
pub use halo2curves::{CurveAffine, CurveExt};

/// This represents an element of a group with basic operations that can be
/// performed. This allows an FFT implementation (for example) to operate
/// generically over either a field or elliptic curve group.
pub trait FftGroup<Scalar: Field>:
    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
{
}

impl<T, Scalar> FftGroup<Scalar> for T
where
    Scalar: Field,
    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
{
}

/// Performs a multi-exponentiation operation, dispatching to the best
/// available MSM implementation.
///
/// This function will panic if coeffs and bases have a different length.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    msm_best(coeffs, bases)
}

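// An added sanity check (not from the original test suite): `best_multiexp`
// should agree with a naive fold of `coeff * base` terms. This sketch
// assumes the pasta `Eq` curve, whose scalar field is `Fp`.
#[test]
fn test_best_multiexp_matches_naive() {
    use crate::halo2curves::pasta::{Eq, EqAffine};
    use group::Group;

    let coeffs: Vec<Fp> = (0..8).map(|_| Fp::random(OsRng)).collect();
    let bases: Vec<EqAffine> = (0..8).map(|_| Eq::random(OsRng).to_affine()).collect();

    let naive = coeffs
        .iter()
        .zip(bases.iter())
        .fold(Eq::identity(), |acc, (coeff, base)| acc + *base * *coeff);

    assert_eq!(best_multiexp(&coeffs, &bases), naive);
}
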
/// Performs a radix-2 Fast-Fourier Transformation (FFT) on a vector of size
/// $n = 2^{\log n}$, where `omega` is a primitive $n$-th root of unity and
/// `data` holds the precomputed tables for the domain. When `inverse` is set,
/// `omega` should be the inverse root of unity, and the caller is responsible
/// for the final scaling by $1/n$ (see `g_to_lagrange` below).
pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(
    a: &mut [G],
    omega: Scalar,
    log_n: u32,
    data: &FFTData<Scalar>,
    inverse: bool,
) {
    fft::fft(a, omega, log_n, data, inverse);
}

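// An added round-trip sketch for `best_fft`, assuming (as `g_to_lagrange`
// below does) that the forward and inverse transforms compose to the
// identity up to a factor of `n`.
#[test]
fn test_best_fft_roundtrip() {
    let k = 4u32;
    let n = 1usize << k;

    // Square the full-order roots of unity down to order 2^k.
    let mut omega = Fp::ROOT_OF_UNITY;
    let mut omega_inv = Fp::ROOT_OF_UNITY_INV;
    for _ in k..Fp::S {
        omega = omega.square();
        omega_inv = omega_inv.square();
    }

    let a: Vec<Fp> = (0..n).map(|_| Fp::random(OsRng)).collect();
    let mut b = a.clone();

    let fft_data = FFTData::new(n, omega, omega_inv);
    best_fft(&mut b, omega, k, &fft_data, false);
    best_fft(&mut b, omega_inv, k, &fft_data, true);

    // Undo the factor of n picked up by the round trip.
    let n_inv = Fp::from(n as u64).invert().unwrap();
    for b in b.iter_mut() {
        *b *= n_inv;
    }
    assert_eq!(a, b);
}
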
/// Converts a slice of group elements from the coefficient basis to the
/// Lagrange basis via an inverse FFT of size $2^k$.
pub fn g_to_lagrange<C: PrimeCurveAffine>(g_projective: Vec<C::Curve>, k: u32) -> Vec<C> {
    let n_inv = C::Scalar::TWO_INV.pow_vartime([k as u64, 0, 0, 0]);
    // Square the $2^S$-order roots of unity down to order $2^k$, so that both
    // generate the evaluation domain of this FFT.
    let mut omega = C::Scalar::ROOT_OF_UNITY;
    let mut omega_inv = C::Scalar::ROOT_OF_UNITY_INV;
    for _ in k..C::Scalar::S {
        omega = omega.square();
        omega_inv = omega_inv.square();
    }

    let mut g_lagrange_projective = g_projective;
    let n = g_lagrange_projective.len();
    let fft_data = FFTData::new(n, omega, omega_inv);

    // Inverse FFT, followed by scaling every element by $1/n$.
    best_fft(&mut g_lagrange_projective, omega_inv, k, &fft_data, true);
    parallelize(&mut g_lagrange_projective, |g, _| {
        for g in g.iter_mut() {
            *g *= n_inv;
        }
    });

    // Batch-convert the projective results into affine form.
    let mut g_lagrange = vec![C::identity(); 1 << k];
    parallelize(&mut g_lagrange, |g_lagrange, starts| {
        C::Curve::batch_normalize(
            &g_lagrange_projective[starts..(starts + g_lagrange.len())],
            g_lagrange,
        );
    });

    g_lagrange
}

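// An added consistency sketch for `g_to_lagrange` (hedged: it assumes the
// Lagrange basis point at index j corresponds to evaluation at omega^j). If
// `g_i = [tau^i] G`, then committing to a polynomial in coefficient form
// against `g` should equal committing to its evaluations against the
// Lagrange basis returned by `g_to_lagrange`.
#[test]
fn test_g_to_lagrange() {
    use crate::halo2curves::pasta::{Eq, EqAffine};
    use group::Group;

    let k = 3u32;
    let n = 1usize << k;

    // A toy SRS: g_i = [tau^i] G.
    let tau = Fp::random(OsRng);
    let g_projective: Vec<Eq> = powers(tau)
        .take(n)
        .map(|power| Eq::generator() * power)
        .collect();
    let mut g = vec![EqAffine::identity(); n];
    Eq::batch_normalize(&g_projective, &mut g);

    let g_lagrange = g_to_lagrange::<EqAffine>(g_projective, k);

    // A random polynomial and its evaluations over the size-2^k domain.
    let poly: Vec<Fp> = (0..n).map(|_| Fp::random(OsRng)).collect();
    let mut omega = Fp::ROOT_OF_UNITY;
    for _ in k..Fp::S {
        omega = omega.square();
    }
    let evals: Vec<Fp> = powers(omega)
        .take(n)
        .map(|x| eval_polynomial(&poly, x))
        .collect();

    assert_eq!(best_multiexp(&poly, &g), best_multiexp(&evals, &g_lagrange));
}
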
/// This evaluates a provided polynomial (in coefficient form) at `point`.
pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
    // Horner's rule over a single chunk of coefficients.
    fn evaluate<F: Field>(poly: &[F], point: F) -> F {
        poly.iter()
            .rev()
            .fold(F::ZERO, |acc, coeff| acc * point + coeff)
    }
    let n = poly.len();
    let num_threads = multicore::current_num_threads();
    if n * 2 < num_threads {
        evaluate(poly, point)
    } else {
        // Evaluate each chunk with Horner's rule, then shift chunk `i` by
        // `point^(i * chunk_size)` and sum the partial results.
        let chunk_size = (n + num_threads - 1) / num_threads;
        let mut parts = vec![F::ZERO; num_threads];
        multicore::scope(|scope| {
            for (chunk_idx, (out, poly)) in
                parts.chunks_mut(1).zip(poly.chunks(chunk_size)).enumerate()
            {
                scope.spawn(move |_| {
                    let start = chunk_idx * chunk_size;
                    out[0] = evaluate(poly, point) * point.pow_vartime([start as u64, 0, 0, 0]);
                });
            }
        });
        parts.iter().fold(F::ZERO, |acc, coeff| acc + coeff)
    }
}

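// An added sanity check: `eval_polynomial` (chunked Horner's rule) should
// agree with the naive sum of `coeff_i * point^i`.
#[test]
fn test_eval_polynomial_matches_naive() {
    let poly: Vec<Fp> = (0..10).map(|_| Fp::random(OsRng)).collect();
    let point = Fp::random(OsRng);

    let naive = poly
        .iter()
        .zip(powers(point))
        .fold(Fp::ZERO, |acc, (coeff, power)| acc + *coeff * power);

    assert_eq!(eval_polynomial(&poly, point), naive);
}
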
/// This computes the inner product of two vectors `a` and `b`.
///
/// This function will panic if the two vectors are not the same size.
pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
    assert_eq!(a.len(), b.len());

    // For short vectors the rayon overhead outweighs the parallelism.
    if a.len() < 32 {
        let mut acc = F::ZERO;
        for (a, b) in a.iter().zip(b.iter()) {
            acc += (*a) * (*b);
        }
        return acc;
    }

    a.par_iter().zip(b.par_iter()).map(|(a, b)| (*a) * b).sum()
}

/// Divides polynomial `a` in `X` by `X - b` with no remainder.
pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
where
    I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
{
    b = -b;
    let a = a.into_iter();

    let mut q = vec![F::ZERO; a.len() - 1];

    // Synthetic division, proceeding from the leading coefficient downwards.
    let mut tmp = F::ZERO;
    for (q, r) in q.iter_mut().rev().zip(a.rev()) {
        let mut lead_coeff = *r;
        lead_coeff.sub_assign(&tmp);
        *q = lead_coeff;
        tmp = lead_coeff;
        tmp.mul_assign(&b);
    }

    q
}

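// An added check for `kate_division`: after shifting `poly` so that `b` is a
// root, the quotient `q` should satisfy `q(x) * (x - b) = poly(x)` at a
// random point `x`.
#[test]
fn test_kate_division() {
    let mut poly: Vec<Fp> = (0..8).map(|_| Fp::random(OsRng)).collect();
    let b = Fp::random(OsRng);

    // Shift the constant term so that `poly(b) == 0`.
    let eval = eval_polynomial(&poly, b);
    poly[0] -= eval;

    let q = kate_division(&poly, b);

    let x = Fp::random(OsRng);
    assert_eq!(eval_polynomial(&q, x) * (x - b), eval_polynomial(&poly, x));
}
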
/// This utility function will parallelize an operation that is to be
/// performed over a mutable slice.
pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
    let f = &f;
    // Divide the work into `num_threads` chunks whose sizes differ by at most
    // one, so the load stays balanced: the first `cutoff_chunk_id` chunks get
    // `base_chunk_size + 1` items each, and the rest get `base_chunk_size`.
    let total_iters = v.len();
    let num_threads = multicore::current_num_threads();
    let base_chunk_size = total_iters / num_threads;
    let cutoff_chunk_id = total_iters % num_threads;
    let split_pos = cutoff_chunk_id * (base_chunk_size + 1);
    let (v_hi, v_lo) = v.split_at_mut(split_pos);

    multicore::scope(|scope| {
        // Skip special-case: number of iterations is cleanly divided by number of threads.
        if cutoff_chunk_id != 0 {
            for (chunk_id, chunk) in v_hi.chunks_exact_mut(base_chunk_size + 1).enumerate() {
                let offset = chunk_id * (base_chunk_size + 1);
                scope.spawn(move |_| f(chunk, offset));
            }
        }
        // Skip special-case: number of iterations is less than number of threads.
        if base_chunk_size != 0 {
            for (chunk_id, chunk) in v_lo.chunks_exact_mut(base_chunk_size).enumerate() {
                let offset = split_pos + (chunk_id * base_chunk_size);
                scope.spawn(move |_| f(chunk, offset));
            }
        }
    });
}

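// An added usage sketch for `parallelize`: double every element, and use the
// chunk offset to verify that each closure sees the region it was given.
#[test]
fn test_parallelize_doubles_elements() {
    let mut v: Vec<Fp> = (0..100).map(|i| Fp::from(i as u64)).collect();
    parallelize(&mut v, |chunk, offset| {
        for (i, value) in chunk.iter_mut().enumerate() {
            // The absolute index of this element is `offset + i`.
            assert_eq!(*value, Fp::from((offset + i) as u64));
            *value = *value + *value;
        }
    });
    for (i, value) in v.iter().enumerate() {
        assert_eq!(*value, Fp::from(2 * i as u64));
    }
}
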
/// Returns the floor of the base-2 logarithm of `num`.
///
/// This function will panic if `num` is zero.
pub fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    let mut pow = 0;

    while (1 << (pow + 1)) <= num {
        pow += 1;
    }

    pow
}

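// A few added spot checks for `log2_floor`.
#[test]
fn test_log2_floor() {
    assert_eq!(log2_floor(1), 0);
    assert_eq!(log2_floor(2), 1);
    assert_eq!(log2_floor(7), 2);
    assert_eq!(log2_floor(8), 3);
    assert_eq!(log2_floor(9), 3);
}
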
/// Returns the coefficients of an $n - 1$ degree polynomial given a set of
/// $n$ points and their evaluations. This function will panic if two values
/// in `points` are the same.
pub fn lagrange_interpolate<F: Field>(points: &[F], evals: &[F]) -> Vec<F> {
    assert_eq!(points.len(), evals.len());
    if points.len() == 1 {
        // Constant polynomial.
        vec![evals[0]]
    } else {
        // For each j, collect the denominators (x_j - x_k) for all k != j.
        let mut denoms = Vec::with_capacity(points.len());
        for (j, x_j) in points.iter().enumerate() {
            let mut denom = Vec::with_capacity(points.len() - 1);
            for x_k in points
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != j)
                .map(|a| a.1)
            {
                denom.push(*x_j - x_k);
            }
            denoms.push(denom);
        }
        // Compute (x_j - x_k)^(-1) for each j != k with a single batch inversion.
        denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();

        let mut final_poly = vec![F::ZERO; points.len()];
        for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
            // Expand the Lagrange basis polynomial
            //   l_j(X) = prod_{k != j} (X - x_k) / (x_j - x_k)
            // one linear factor at a time.
            let mut tmp: Vec<F> = Vec::with_capacity(points.len());
            let mut product = Vec::with_capacity(points.len() - 1);
            tmp.push(F::ONE);
            for (x_k, denom) in points
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != j)
                .map(|a| a.1)
                .zip(denoms)
            {
                product.resize(tmp.len() + 1, F::ZERO);
                for ((a, b), product) in tmp
                    .iter()
                    .chain(std::iter::once(&F::ZERO))
                    .zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
                    .zip(product.iter_mut())
                {
                    *product = *a * (-denom * x_k) + *b * denom;
                }
                std::mem::swap(&mut tmp, &mut product);
            }
            assert_eq!(tmp.len(), points.len());
            assert_eq!(product.len(), points.len() - 1);
            // Accumulate eval_j * l_j(X) into the final polynomial.
            for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp) {
                *final_coeff += interpolation_coeff * eval;
            }
        }
        final_poly
    }
}

/// Evaluates the vanishing polynomial $\prod_i (z - \mathsf{roots}_i)$ at `z`.
pub(crate) fn evaluate_vanishing_polynomial<F: Field>(roots: &[F], z: F) -> F {
    fn evaluate<F: Field>(roots: &[F], z: F) -> F {
        roots.iter().fold(F::ONE, |acc, point| (z - point) * acc)
    }
    let n = roots.len();
    let num_threads = multicore::current_num_threads();
    if n * 2 < num_threads {
        evaluate(roots, z)
    } else {
        // Evaluate a partial product on each thread, then multiply them together.
        let chunk_size = (n + num_threads - 1) / num_threads;
        let mut parts = vec![F::ONE; num_threads];
        multicore::scope(|scope| {
            for (out, roots) in parts.chunks_mut(1).zip(roots.chunks(chunk_size)) {
                scope.spawn(move |_| out[0] = evaluate(roots, z));
            }
        });
        parts.iter().fold(F::ONE, |acc, part| acc * part)
    }
}

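// An added check: the vanishing polynomial evaluates to zero at each root
// and matches the naive product elsewhere. 64 roots is enough to exercise
// the parallel path on typical thread counts.
#[test]
fn test_evaluate_vanishing_polynomial() {
    let roots: Vec<Fp> = (0..64).map(|_| Fp::random(OsRng)).collect();
    assert_eq!(evaluate_vanishing_polynomial(&roots, roots[0]), Fp::ZERO);

    let z = Fp::random(OsRng);
    let naive = roots.iter().fold(Fp::ONE, |acc, root| acc * (z - root));
    assert_eq!(evaluate_vanishing_polynomial(&roots, z), naive);
}
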
/// Returns an iterator yielding the powers of `base`: $1, base, base^2, \dots$
pub(crate) fn powers<F: Field>(base: F) -> impl Iterator<Item = F> {
    std::iter::successors(Some(F::ONE), move |power| Some(base * power))
}

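// An added spot check: `powers(base)` yields 1, base, base^2, ...
#[test]
fn test_powers() {
    let base = Fp::random(OsRng);
    let first: Vec<Fp> = powers(base).take(4).collect();
    assert_eq!(
        first,
        vec![Fp::ONE, base, base.square(), base * base.square()]
    );
}
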
/// Reverses the `l` least significant bits of `n`.
pub fn bitreverse(mut n: usize, l: usize) -> usize {
    let mut r = 0;
    for _ in 0..l {
        r = (r << 1) | (n & 1);
        n >>= 1;
    }
    r
}

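// A few added spot checks for `bitreverse`.
#[test]
fn test_bitreverse() {
    assert_eq!(bitreverse(0b0011, 4), 0b1100);
    assert_eq!(bitreverse(0b1011, 4), 0b1101);
    assert_eq!(bitreverse(1, 1), 1);
    assert_eq!(bitreverse(0, 4), 0);
}
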
#[cfg(test)]
use rand_core::OsRng;

#[cfg(test)]
use crate::halo2curves::pasta::Fp;

#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let points = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();
    let evals = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();

    for num_points in 0..5 {
        let points = &points[0..num_points];
        let evals = &evals[0..num_points];

        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());

        for (point, eval) in points.iter().zip(evals) {
            assert_eq!(eval_polynomial(&poly, *point), *eval);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::OsRng;

    #[test]
    fn test_compute_inner_product() {
        let rng = OsRng;

        // Small case: fewer than 32 elements exercises the serial path.
        let a_small: Vec<Fp> = (0..16).map(|_| Fp::random(rng)).collect();
        let b_small: Vec<Fp> = (0..16).map(|_| Fp::random(rng)).collect();
        let result_small = compute_inner_product(&a_small, &b_small);
        let expected_small = a_small
            .iter()
            .zip(b_small.iter())
            .fold(Fp::ZERO, |acc, (a, b)| acc + (*a) * (*b));
        assert_eq!(result_small, expected_small);

        // Large case: 32 or more elements exercises the parallel path.
        let a_large: Vec<Fp> = (0..64).map(|_| Fp::random(rng)).collect();
        let b_large: Vec<Fp> = (0..64).map(|_| Fp::random(rng)).collect();
        let result_large = compute_inner_product(&a_large, &b_large);
        let expected_large = a_large
            .iter()
            .zip(b_large.iter())
            .fold(Fp::ZERO, |acc, (a, b)| acc + (*a) * (*b));
        assert_eq!(result_large, expected_large);
    }
}