use core::mem::MaybeUninit;
use core::num::Wrapping;
use core::ops::{Index, IndexMut};
use core::ptr;
use core::slice::SliceIndex;
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, Barrier};

// `mt`, `MultiPoint`, `BLST_ERROR` and the `blst_*` bindings are
// expected to be in scope from the crate that includes this file.

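// A `tile` describes one unit of work for the parallel Pippenger
// algorithm: the slice points[x..x+dx] combined with the scalar bit
// window [y, y+dy).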
struct tile {
    x: usize,
    dx: usize,
    y: usize,
    dy: usize,
}

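// Minimal stand-in for core::cell::Cell that additionally implements
// Sync, so distinct grid cells can be written through a shared slice
// by worker threads. Soundness relies on each cell being written by
// exactly one worker.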
#[repr(transparent)]
struct Cell<T: ?Sized> {
    value: T,
}
unsafe impl<T: ?Sized + Sync> Sync for Cell<T> {}
impl<T> Cell<T> {
    pub fn as_ptr(&self) -> *mut T {
        &self.value as *const T as *mut T
    }
}

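// Generates a `$points` wrapper (a vector of affine points) plus a
// `MultiPoint` implementation for point slices, parameterized over the
// matching blst FFI entry points for the p1/p2 groups.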
macro_rules! pippenger_mult_impl {
    (
        $points:ident,
        $point:ty,
        $point_affine:ty,
        $to_affines:ident,
        $scratch_sizeof:ident,
        $multi_scalar_mult:ident,
        $tile_mult:ident,
        $add_or_double:ident,
        $double:ident,
        $test_mod:ident,
        $generator:ident,
        $mult:ident,
        $add:ident,
        $is_inf:ident,
        $in_group:ident,
        $from_affine:ident,
    ) => {
        pub struct $points {
            points: Vec<$point_affine>,
        }

        impl<I: SliceIndex<[$point_affine]>> Index<I> for $points {
            type Output = I::Output;

            #[inline]
            fn index(&self, i: I) -> &Self::Output {
                &self.points[i]
            }
        }
        impl<I: SliceIndex<[$point_affine]>> IndexMut<I> for $points {
            #[inline]
            fn index_mut(&mut self, i: I) -> &mut Self::Output {
                &mut self.points[i]
            }
        }

        impl $points {
            #[inline]
            pub fn as_slice(&self) -> &[$point_affine] {
                self.points.as_slice()
            }

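            // Batch-converts projective points to affine form. Large
            // inputs are split into roughly 512-point slices converted
            // on the thread pool; the (Barrier, AtomicUsize) pair lets
            // the last worker to finish wake the caller.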
            pub fn from(points: &[$point]) -> Self {
                let npoints = points.len();
                let mut ret = Self {
                    points: Vec::with_capacity(npoints),
                };
                unsafe { ret.points.set_len(npoints) };

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 768 {
                    let p: [*const $point; 2] = [&points[0], ptr::null()];
                    unsafe { $to_affines(&mut ret.points[0], &p[0], npoints) };
                    return ret;
                }

                let mut nslices = (npoints + 511) / 512;
                nslices = core::cmp::min(nslices, ncpus);
                let wg = Arc::new((Barrier::new(2), AtomicUsize::new(nslices)));

                let (mut delta, mut rem) =
                    (npoints / nslices + 1, Wrapping(npoints % nslices));
                let mut x = 0usize;
                while x < npoints {
                    let out = &mut ret.points[x];
                    let inp = &points[x];

                    delta -= (rem == Wrapping(0)) as usize;
                    rem -= Wrapping(1);
                    x += delta;

                    let wg = wg.clone();
                    pool.joined_execute(move || {
                        let p: [*const $point; 2] = [inp, ptr::null()];
                        unsafe { $to_affines(out, &p[0], delta) };
                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }
                wg.0.wait();

                ret
            }

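            // Hypothetical usage sketch (variable names invented here;
            // assumes blst's little-endian scalar encoding, 32 bytes
            // per 255-bit scalar, so scalars.len() >= 32 * points.len()):
            //
            //     let pre = p1_affines::from(&points);
            //     let acc = pre.mult(&scalars, 255);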
            #[inline]
            pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                self.as_slice().mult(scalars, nbits)
            }

            #[inline]
            pub fn add(&self) -> $point {
                self.as_slice().add()
            }
        }

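        // Multi-scalar multiplication via Pippenger's bucket method:
        // scalars are processed in windows of `window` bits, and the
        // nx-by-ny grid of tiles is farmed out to the thread pool.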
        impl MultiPoint for [$point_affine] {
            type Output = $point;

            fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                let npoints = self.len();
                let nbytes = (nbits + 7) / 8;

                if scalars.len() < nbytes * npoints {
                    panic!("scalars length mismatch");
                }

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 {
                    let p: [*const $point_affine; 2] = [&self[0], ptr::null()];
                    let s: [*const u8; 2] = [&scalars[0], ptr::null()];

                    unsafe {
                        let mut scratch: Vec<u64> =
                            Vec::with_capacity($scratch_sizeof(npoints) / 8);
                        #[allow(clippy::uninit_vec)]
                        scratch.set_len(scratch.capacity());
                        let mut ret = <$point>::default();
                        $multi_scalar_mult(
                            &mut ret,
                            &p[0],
                            npoints,
                            &s[0],
                            nbits,
                            &mut scratch[0],
                        );
                        return ret;
                    }
                }

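                // Small batches don't amortize Pippenger's bucket
                // setup, so each worker multiplies individual points
                // and the per-worker partial sums are added at the end.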
                if npoints < 32 {
                    let (tx, rx) = channel();
                    let counter = Arc::new(AtomicUsize::new(0));
                    let n_workers = core::cmp::min(ncpus, npoints);

                    for _ in 0..n_workers {
                        let tx = tx.clone();
                        let counter = counter.clone();

                        pool.joined_execute(move || {
                            let mut acc = <$point>::default();
                            let mut tmp = <$point>::default();
                            let mut first = true;

                            loop {
                                let work =
                                    counter.fetch_add(1, Ordering::Relaxed);
                                if work >= npoints {
                                    break;
                                }

                                unsafe {
                                    $from_affine(&mut tmp, &self[work]);
                                    let scalar = &scalars[nbytes * work];
                                    if first {
                                        $mult(&mut acc, &tmp, scalar, nbits);
                                        first = false;
                                    } else {
                                        $mult(&mut tmp, &tmp, scalar, nbits);
                                        $add_or_double(&mut acc, &acc, &tmp);
                                    }
                                }
                            }

                            tx.send(acc).expect("disaster");
                        });
                    }

                    let mut ret = rx.recv().expect("disaster");
                    for _ in 1..n_workers {
                        let p = rx.recv().expect("disaster");
                        unsafe { $add_or_double(&mut ret, &ret, &p) };
                    }

                    return ret;
                }

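                // Split the job into an nx-by-ny grid of tiles: nx
                // point slices times ny scalar-bit windows. The top
                // window (which may be narrower than `window` bits) is
                // laid out first so it can be accumulated as soon as
                // it completes.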
                let (nx, ny, window) =
                    breakdown(nbits, pippenger_window_size(npoints), ncpus);

                let mut grid: Vec<(tile, Cell<$point>)> =
                    Vec::with_capacity(nx * ny);
                #[allow(clippy::uninit_vec)]
                unsafe { grid.set_len(grid.capacity()) };
                let dx = npoints / nx;
                let mut y = window * (ny - 1);
                let mut total = 0usize;

                while total < nx {
                    grid[total].0.x = total * dx;
                    grid[total].0.dx = dx;
                    grid[total].0.y = y;
                    grid[total].0.dy = nbits - y;
                    total += 1;
                }
                grid[total - 1].0.dx = npoints - grid[total - 1].0.x;
                while y != 0 {
                    y -= window;
                    for i in 0..nx {
                        grid[total].0.x = grid[i].0.x;
                        grid[total].0.dx = grid[i].0.dx;
                        grid[total].0.y = y;
                        grid[total].0.dy = window;
                        total += 1;
                    }
                }
                let grid = &grid[..];

                let points = &self[..];
                let sz = unsafe { $scratch_sizeof(0) / 8 };

                let mut row_sync: Vec<AtomicUsize> = Vec::with_capacity(ny);
                row_sync.resize_with(ny, Default::default);
                let row_sync = Arc::new(row_sync);
                let counter = Arc::new(AtomicUsize::new(0));
                let (tx, rx) = channel();
                let n_workers = core::cmp::min(ncpus, total);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();
                    let row_sync = row_sync.clone();

                    pool.joined_execute(move || {
                        let mut scratch = vec![0u64; sz << (window - 1)];
                        let mut p: [*const $point_affine; 2] =
                            [ptr::null(), ptr::null()];
                        let mut s: [*const u8; 2] = [ptr::null(), ptr::null()];

                        loop {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= total {
                                break;
                            }
                            let x = grid[work].0.x;
                            let y = grid[work].0.y;

                            p[0] = &points[x];
                            s[0] = &scalars[x * nbytes];
                            unsafe {
                                $tile_mult(
                                    grid[work].1.as_ptr(),
                                    &p[0],
                                    grid[work].0.dx,
                                    &s[0],
                                    nbits,
                                    &mut scratch[0],
                                    y,
                                    window,
                                );
                            }
                            if row_sync[y / window]
                                .fetch_add(1, Ordering::AcqRel)
                                == nx - 1
                            {
                                tx.send(y).expect("disaster");
                            }
                        }
                    });
                }

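                // Rows are reported as they complete, not necessarily
                // top-down. Starting from the highest finished row,
                // fold each completed row into the result and double
                // `window` times to shift down to the next row; stall
                // until the row below has also been reported.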
                let mut ret = <$point>::default();
                let mut rows = vec![false; ny];
                let mut row = 0usize;
                for _ in 0..ny {
                    let mut y = rx.recv().unwrap();
                    rows[y / window] = true;
                    while grid[row].0.y == y {
                        while row < total && grid[row].0.y == y {
                            unsafe {
                                $add_or_double(
                                    &mut ret,
                                    &ret,
                                    grid[row].1.as_ptr(),
                                );
                            }
                            row += 1;
                        }
                        if y == 0 {
                            break;
                        }
                        for _ in 0..window {
                            unsafe { $double(&mut ret, &ret) };
                        }
                        y -= window;
                        if !rows[y / window] {
                            break;
                        }
                    }
                }
                ret
            }

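            // Sums all points without scalars. Large inputs are
            // chopped into ~256-point chunks claimed atomically by
            // the workers.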
            fn add(&self) -> $point {
                let npoints = self.len();

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 384 {
                    let p: [*const _; 2] = [&self[0], ptr::null()];
                    let mut ret = <$point>::default();
                    unsafe { $add(&mut ret, &p[0], npoints) };
                    return ret;
                }

                let (tx, rx) = channel();
                let counter = Arc::new(AtomicUsize::new(0));
                let nchunks = (npoints + 255) / 256;
                let chunk = npoints / nchunks + 1;

                let n_workers = core::cmp::min(ncpus, nchunks);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();

                    pool.joined_execute(move || {
                        let mut acc = <$point>::default();
                        let mut chunk = chunk;
                        let mut p: [*const _; 2] = [ptr::null(), ptr::null()];

                        loop {
                            let work =
                                counter.fetch_add(chunk, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }
                            p[0] = &self[work];
                            if work + chunk > npoints {
                                chunk = npoints - work;
                            }
                            unsafe {
                                let mut t = MaybeUninit::<$point>::uninit();
                                $add(t.as_mut_ptr(), &p[0], chunk);
                                $add_or_double(&mut acc, &acc, t.as_ptr());
                            };
                        }
                        tx.send(acc).expect("disaster");
                    });
                }

                let mut ret = rx.recv().unwrap();
                for _ in 1..n_workers {
                    unsafe {
                        $add_or_double(&mut ret, &ret, &rx.recv().unwrap())
                    };
                }

                ret
            }

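            // Checks every point for non-infinity and group
            // membership, bailing out early across all workers on the
            // first failure.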
            fn validate(&self) -> Result<(), BLST_ERROR> {
                fn check(point: &$point_affine) -> Result<(), BLST_ERROR> {
                    if unsafe { $is_inf(point) } {
                        return Err(BLST_ERROR::BLST_PK_IS_INFINITY);
                    }
                    if !unsafe { $in_group(point) } {
                        return Err(BLST_ERROR::BLST_POINT_NOT_IN_GROUP);
                    }
                    Ok(())
                }

                let npoints = self.len();

                let pool = mt::da_pool();
                let n_workers = core::cmp::min(npoints, pool.max_count());
                if n_workers < 2 {
                    for i in 0..npoints {
                        check(&self[i])?;
                    }
                    return Ok(());
                }

                let counter = Arc::new(AtomicUsize::new(0));
                let valid = Arc::new(AtomicBool::new(true));
                let wg =
                    Arc::new((Barrier::new(2), AtomicUsize::new(n_workers)));

                for _ in 0..n_workers {
                    let counter = counter.clone();
                    let valid = valid.clone();
                    let wg = wg.clone();

                    pool.joined_execute(move || {
                        while valid.load(Ordering::Relaxed) {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }

                            if check(&self[work]).is_err() {
                                valid.store(false, Ordering::Relaxed);
                                break;
                            }
                        }

                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }

                wg.0.wait();

                if valid.load(Ordering::Relaxed) {
                    Ok(())
                } else {
                    Err(BLST_ERROR::BLST_POINT_NOT_IN_GROUP)
                }
            }
        }

        #[cfg(test)]
        pippenger_test_mod!(
            $test_mod,
            $points,
            $point,
            $add_or_double,
            $generator,
            $mult,
        );
    };
}

#[cfg(test)]
include!("pippenger-test_mod.rs");

pippenger_mult_impl!(
    p1_affines,
    blst_p1,
    blst_p1_affine,
    blst_p1s_to_affine,
    blst_p1s_mult_pippenger_scratch_sizeof,
    blst_p1s_mult_pippenger,
    blst_p1s_tile_pippenger,
    blst_p1_add_or_double,
    blst_p1_double,
    p1_multi_point,
    blst_p1_generator,
    blst_p1_mult,
    blst_p1s_add,
    blst_p1_affine_is_inf,
    blst_p1_affine_in_g1,
    blst_p1_from_affine,
);

pippenger_mult_impl!(
    p2_affines,
    blst_p2,
    blst_p2_affine,
    blst_p2s_to_affine,
    blst_p2s_mult_pippenger_scratch_sizeof,
    blst_p2s_mult_pippenger,
    blst_p2s_tile_pippenger,
    blst_p2_add_or_double,
    blst_p2_double,
    p2_multi_point,
    blst_p2_generator,
    blst_p2_mult,
    blst_p2s_add,
    blst_p2_affine_is_inf,
    blst_p2_affine_in_g2,
    blst_p2_from_affine,
);

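// Bit length of `l`, i.e. the position of its highest set bit;
// num_bits(0) == 0.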
fn num_bits(l: usize) -> usize {
    8 * core::mem::size_of_val(&l) - l.leading_zeros() as usize
}

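// Chooses the tile grid for `mult`: returns (nx, ny, wnd) such that nx
// point slices times ny windows of wnd bits yields enough tiles to
// keep ncpus busy, while keeping wnd close to the serial optimum.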
fn breakdown(
    nbits: usize,
    window: usize,
    ncpus: usize,
) -> (usize, usize, usize) {
    let mut nx: usize;
    let mut wnd: usize;

    if nbits > window * ncpus {
        nx = 1;
        wnd = num_bits(ncpus / 4);
        if (window + wnd) > 18 {
            wnd = window - wnd;
        } else {
            wnd = (nbits / window + ncpus - 1) / ncpus;
            if (nbits / (window + 1) + ncpus - 1) / ncpus < wnd {
                wnd = window + 1;
            } else {
                wnd = window;
            }
        }
    } else {
        nx = 2;
        wnd = window - 2;
        while (nbits / wnd + 1) * nx < ncpus {
            nx += 1;
            wnd = window - num_bits(3 * nx / 2);
        }
        nx -= 1;
        wnd = window - num_bits(3 * nx / 2);
    }
    let ny = nbits / wnd + 1;
    wnd = nbits / ny + 1;

    (nx, ny, wnd)
}

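// Standard Pippenger window heuristic: the optimal window is roughly
// log2(npoints) minus a small constant, clamped to at least 2 bits.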
fn pippenger_window_size(npoints: usize) -> usize {
    let wbits = num_bits(npoints);

    if wbits > 13 {
        return wbits - 4;
    }
    if wbits > 5 {
        return wbits - 3;
    }
    2
}