1use super::multicore;
5pub use ff::Field;
6use group::{
7 ff::{BatchInvert, PrimeField},
8 Group as _,
9};
10
11pub use pasta_curves::arithmetic::*;
12
13fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
14 let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
15
16 let c = if bases.len() < 4 {
17 1
18 } else if bases.len() < 32 {
19 3
20 } else {
21 (f64::from(bases.len() as u32)).ln().ceil() as usize
22 };
23
24 fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
25 let skip_bits = segment * c;
26 let skip_bytes = skip_bits / 8;
27
28 if skip_bytes >= 32 {
29 return 0;
30 }
31
32 let mut v = [0; 8];
33 for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
34 *v = *o;
35 }
36
37 let mut tmp = u64::from_le_bytes(v);
38 tmp >>= skip_bits - (skip_bytes * 8);
39 tmp = tmp % (1 << c);
40
41 tmp as usize
42 }
43
44 let segments = (256 / c) + 1;
45
46 for current_segment in (0..segments).rev() {
47 for _ in 0..c {
48 *acc = acc.double();
49 }
50
51 #[derive(Clone, Copy)]
52 enum Bucket<C: CurveAffine> {
53 None,
54 Affine(C),
55 Projective(C::Curve),
56 }
57
58 impl<C: CurveAffine> Bucket<C> {
59 fn add_assign(&mut self, other: &C) {
60 *self = match *self {
61 Bucket::None => Bucket::Affine(*other),
62 Bucket::Affine(a) => Bucket::Projective(a + *other),
63 Bucket::Projective(mut a) => {
64 a += *other;
65 Bucket::Projective(a)
66 }
67 }
68 }
69
70 fn add(self, mut other: C::Curve) -> C::Curve {
71 match self {
72 Bucket::None => other,
73 Bucket::Affine(a) => {
74 other += a;
75 other
76 }
77 Bucket::Projective(a) => other + &a,
78 }
79 }
80 }
81
82 let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
83
84 for (coeff, base) in coeffs.iter().zip(bases.iter()) {
85 let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
86 if coeff != 0 {
87 buckets[coeff - 1].add_assign(base);
88 }
89 }
90
91 let mut running_sum = C::Curve::identity();
96 for exp in buckets.into_iter().rev() {
97 running_sum = exp.add(running_sum);
98 *acc = *acc + &running_sum;
99 }
100 }
101}
102
103pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
106 let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
107 let mut acc = C::Curve::identity();
108
109 for byte_idx in (0..32).rev() {
111 for bit_idx in (0..8).rev() {
113 acc = acc.double();
114 for coeff_idx in 0..coeffs.len() {
116 let byte = coeffs[coeff_idx].as_ref()[byte_idx];
117 if ((byte >> bit_idx) & 1) != 0 {
118 acc += bases[coeff_idx];
119 }
120 }
121 }
122 }
123
124 acc
125}
126
127pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
133 assert_eq!(coeffs.len(), bases.len());
134
135 let num_threads = multicore::current_num_threads();
136 if coeffs.len() > num_threads {
137 let chunk = coeffs.len() / num_threads;
138 let num_chunks = coeffs.chunks(chunk).len();
139 let mut results = vec![C::Curve::identity(); num_chunks];
140 multicore::scope(|scope| {
141 let chunk = coeffs.len() / num_threads;
142
143 for ((coeffs, bases), acc) in coeffs
144 .chunks(chunk)
145 .zip(bases.chunks(chunk))
146 .zip(results.iter_mut())
147 {
148 scope.spawn(move |_| {
149 multiexp_serial(coeffs, bases, acc);
150 });
151 }
152 });
153 results.iter().fold(C::Curve::identity(), |a, b| a + b)
154 } else {
155 let mut acc = C::Curve::identity();
156 multiexp_serial(coeffs, bases, &mut acc);
157 acc
158 }
159}
160
161pub fn best_fft<G: Group>(a: &mut [G], omega: G::Scalar, log_n: u32) {
172 fn bitreverse(mut n: usize, l: usize) -> usize {
173 let mut r = 0;
174 for _ in 0..l {
175 r = (r << 1) | (n & 1);
176 n >>= 1;
177 }
178 r
179 }
180
181 let threads = multicore::current_num_threads();
182 let log_threads = log2_floor(threads);
183 let n = a.len() as usize;
184 assert_eq!(n, 1 << log_n);
185
186 for k in 0..n {
187 let rk = bitreverse(k, log_n as usize);
188 if k < rk {
189 a.swap(rk, k);
190 }
191 }
192
193 let twiddles: Vec<_> = (0..(n / 2) as usize)
195 .scan(G::Scalar::one(), |w, _| {
196 let tw = *w;
197 w.group_scale(&omega);
198 Some(tw)
199 })
200 .collect();
201
202 if log_n <= log_threads {
203 let mut chunk = 2_usize;
204 let mut twiddle_chunk = (n / 2) as usize;
205 for _ in 0..log_n {
206 a.chunks_mut(chunk).for_each(|coeffs| {
207 let (left, right) = coeffs.split_at_mut(chunk / 2);
208
209 let (a, left) = left.split_at_mut(1);
211 let (b, right) = right.split_at_mut(1);
212 let t = b[0];
213 b[0] = a[0];
214 a[0].group_add(&t);
215 b[0].group_sub(&t);
216
217 left.iter_mut()
218 .zip(right.iter_mut())
219 .enumerate()
220 .for_each(|(i, (a, b))| {
221 let mut t = *b;
222 t.group_scale(&twiddles[(i + 1) * twiddle_chunk]);
223 *b = *a;
224 a.group_add(&t);
225 b.group_sub(&t);
226 });
227 });
228 chunk *= 2;
229 twiddle_chunk /= 2;
230 }
231 } else {
232 recursive_butterfly_arithmetic(a, n, 1, &twiddles)
233 }
234}
235
236pub fn recursive_butterfly_arithmetic<G: Group>(
238 a: &mut [G],
239 n: usize,
240 twiddle_chunk: usize,
241 twiddles: &[G::Scalar],
242) {
243 if n == 2 {
244 let t = a[1];
245 a[1] = a[0];
246 a[0].group_add(&t);
247 a[1].group_sub(&t);
248 } else {
249 let (left, right) = a.split_at_mut(n / 2);
250 rayon::join(
251 || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
252 || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
253 );
254
255 let (a, left) = left.split_at_mut(1);
257 let (b, right) = right.split_at_mut(1);
258 let t = b[0];
259 b[0] = a[0];
260 a[0].group_add(&t);
261 b[0].group_sub(&t);
262
263 left.iter_mut()
264 .zip(right.iter_mut())
265 .enumerate()
266 .for_each(|(i, (a, b))| {
267 let mut t = *b;
268 t.group_scale(&twiddles[(i + 1) * twiddle_chunk]);
269 *b = *a;
270 a.group_add(&t);
271 b.group_sub(&t);
272 });
273 }
274}
275
276pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
278 poly.iter()
280 .rev()
281 .fold(F::zero(), |acc, coeff| acc * point + coeff)
282}
283
284pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
288 assert_eq!(a.len(), b.len());
290
291 let mut acc = F::zero();
292 for (a, b) in a.iter().zip(b.iter()) {
293 acc += (*a) * (*b);
294 }
295
296 acc
297}
298
299pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
302where
303 I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
304{
305 b = -b;
306 let a = a.into_iter();
307
308 let mut q = vec![F::zero(); a.len() - 1];
309
310 let mut tmp = F::zero();
311 for (q, r) in q.iter_mut().rev().zip(a.rev()) {
312 let mut lead_coeff = *r;
313 lead_coeff.sub_assign(&tmp);
314 *q = lead_coeff;
315 tmp = lead_coeff;
316 tmp.mul_assign(&b);
317 }
318
319 q
320}
321
322pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
325 let n = v.len();
326 let num_threads = multicore::current_num_threads();
327 let mut chunk = (n as usize) / num_threads;
328 if chunk < num_threads {
329 chunk = n as usize;
330 }
331
332 multicore::scope(|scope| {
333 for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
334 let f = f.clone();
335 scope.spawn(move |_| {
336 let start = chunk_num * chunk;
337 f(v, start);
338 });
339 }
340 });
341}
342
/// Returns `floor(log2(num))`, i.e. the index of the highest set bit.
///
/// Computed from the leading-zero count so that it is correct for every
/// non-zero `usize`, including values with the top bit set (where a
/// `1 << (pow + 1)` probe would overflow the shift).
///
/// # Panics
///
/// Panics if `num` is zero.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    let bits = (std::mem::size_of::<usize>() * 8) as u32;
    bits - 1 - num.leading_zeros()
}
354
355pub fn lagrange_interpolate<F: FieldExt>(points: &[F], evals: &[F]) -> Vec<F> {
359 assert_eq!(points.len(), evals.len());
360 if points.len() == 1 {
361 return vec![evals[0]];
363 } else {
364 let mut denoms = Vec::with_capacity(points.len());
365 for (j, x_j) in points.iter().enumerate() {
366 let mut denom = Vec::with_capacity(points.len() - 1);
367 for x_k in points
368 .iter()
369 .enumerate()
370 .filter(|&(k, _)| k != j)
371 .map(|a| a.1)
372 {
373 denom.push(*x_j - x_k);
374 }
375 denoms.push(denom);
376 }
377 denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();
379
380 let mut final_poly = vec![F::zero(); points.len()];
381 for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
382 let mut tmp: Vec<F> = Vec::with_capacity(points.len());
383 let mut product = Vec::with_capacity(points.len() - 1);
384 tmp.push(F::one());
385 for (x_k, denom) in points
386 .iter()
387 .enumerate()
388 .filter(|&(k, _)| k != j)
389 .map(|a| a.1)
390 .zip(denoms.into_iter())
391 {
392 product.resize(tmp.len() + 1, F::zero());
393 for ((a, b), product) in tmp
394 .iter()
395 .chain(std::iter::once(&F::zero()))
396 .zip(std::iter::once(&F::zero()).chain(tmp.iter()))
397 .zip(product.iter_mut())
398 {
399 *product = *a * (-denom * x_k) + *b * denom;
400 }
401 std::mem::swap(&mut tmp, &mut product);
402 }
403 assert_eq!(tmp.len(), points.len());
404 assert_eq!(product.len(), points.len() - 1);
405 for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
406 *final_coeff += interpolation_coeff * eval;
407 }
408 }
409 final_poly
410 }
411}
412
413#[cfg(test)]
414use rand_core::OsRng;
415
416#[cfg(test)]
417use crate::pasta::Fp;
418
/// Interpolates through every prefix of five random points (lengths 0
/// through 5) and checks that the resulting polynomial reproduces each
/// evaluation.
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let points = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();
    let evals = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();

    // Inclusive upper bound so the full five-point set is exercised as well;
    // an exclusive `0..5` would stop at four points.
    for coeffs in 0..=5 {
        let points = &points[0..coeffs];
        let evals = &evals[0..coeffs];

        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());

        for (point, eval) in points.iter().zip(evals) {
            assert_eq!(eval_polynomial(&poly, *point), *eval);
        }
    }
}