blst/
pippenger.rs

// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
4
use core::num::Wrapping;
use core::ops::{Index, IndexMut};
use core::slice::SliceIndex;
use std::sync::Barrier;
9
/// "Coordinates" of one unit of work in the multi-scalar-multiplication grid.
///
/// The grid is built in `MultiPoint::mult`: each tile covers a run of
/// consecutive points and a run of consecutive scalar bits.
#[allow(non_camel_case_types)]
struct tile {
    /// Index of the first point covered by this tile.
    x: usize,
    /// Number of points covered by this tile.
    dx: usize,
    /// Offset of the first scalar bit covered by this tile.
    y: usize,
    /// Number of scalar bits covered by this tile.
    dy: usize,
}
16
// Minimalist core::cell::Cell stand-in, but with a Sync marker, which
// makes it possible to share it between multiple threads. It works,
// because *here* each Cell is written only once and by just one thread;
// users of this type are responsible for synchronizing that write with
// any later read.
#[repr(transparent)]
struct Cell<T: ?Sized> {
    // `UnsafeCell` is required for writes through a pointer derived from
    // `&self` to be defined behaviour; casting `&T` to `*mut T` and writing
    // through it is not.
    value: core::cell::UnsafeCell<T>,
}
// SAFETY: `UnsafeCell` is `!Sync` by default. Sharing is sound only under
// the single-writer discipline described above, which callers must uphold.
unsafe impl<T: ?Sized + Sync> Sync for Cell<T> {}
impl<T> Cell<T> {
    /// Returns a raw mutable pointer to the wrapped value.
    ///
    /// Dereferencing the pointer is `unsafe`; the caller must guarantee
    /// that no other access overlaps a write made through it.
    pub fn as_ptr(&self) -> *mut T {
        self.value.get()
    }
}
30
// Generates, for one point type, an owning container of affine points
// (`$points`) and the `MultiPoint` implementation for `[$point_affine]`.
//
// Parameters, in order, and how the expansion uses them:
//   $points            - name of the generated container struct
//   $point             - point type accepted by `from` and returned by
//                        `mult`/`add`
//   $point_affine      - element type stored in the container
//   $to_affines        - bulk conversion called by `$points::from`
//   $scratch_sizeof    - returns the scratch size (in bytes) for `mult`
//   $multi_scalar_mult - single-threaded multi-scalar multiplication
//   $tile_mult         - multiplication restricted to one grid tile
//   $add_or_double     - combines two `$point` values
//   $double            - doubles a `$point` value
//   $test_mod          - forwarded to `pippenger_test_mod!`
//   $generator         - forwarded to `pippenger_test_mod!`
//   $mult              - single point-by-scalar multiplication
//   $add               - bulk summation of affine points
//   $is_inf, $in_group - predicates used by `validate`
//   $from_affine       - converts one `$point_affine` into a `$point`
macro_rules! pippenger_mult_impl {
    (
        $points:ident,
        $point:ty,
        $point_affine:ty,
        $to_affines:ident,
        $scratch_sizeof:ident,
        $multi_scalar_mult:ident,
        $tile_mult:ident,
        $add_or_double:ident,
        $double:ident,
        $test_mod:ident,
        $generator:ident,
        $mult:ident,
        $add:ident,
        $is_inf:ident,
        $in_group:ident,
        $from_affine:ident,
    ) => {
        /// Owned vector of affine points, usable as input to
        /// multi-scalar multiplication and bulk addition.
        pub struct $points {
            points: Vec<$point_affine>,
        }

        impl<I: SliceIndex<[$point_affine]>> Index<I> for $points {
            type Output = I::Output;

            #[inline]
            fn index(&self, i: I) -> &Self::Output {
                &self.points[i]
            }
        }
        impl<I: SliceIndex<[$point_affine]>> IndexMut<I> for $points {
            #[inline]
            fn index_mut(&mut self, i: I) -> &mut Self::Output {
                &mut self.points[i]
            }
        }

        impl $points {
            /// Borrows the stored affine points as a slice.
            #[inline]
            pub fn as_slice(&self) -> &[$point_affine] {
                self.points.as_slice()
            }

            /// Builds the container from `points` by calling the bulk
            /// conversion routine, split across the thread pool when the
            /// input has at least 768 elements and more than one CPU is
            /// available. An empty input yields an empty container.
            pub fn from(points: &[$point]) -> Self {
                let npoints = points.len();
                let mut ret = Self {
                    points: Vec::with_capacity(npoints),
                };
                // Nothing to convert; taking `&points[0]` below would panic.
                if npoints == 0 {
                    return ret;
                }
                // SAFETY: the capacity is `npoints`; the conversion calls
                // below are expected to write every one of those elements
                // before `ret` is handed out.
                #[allow(clippy::uninit_vec)]
                unsafe { ret.points.set_len(npoints) };

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 768 {
                    // Second array slot is a null pointer, as in every
                    // other call of this shape in this file.
                    let p: [*const $point; 2] = [&points[0], ptr::null()];
                    unsafe { $to_affines(&mut ret.points[0], &p[0], npoints) };
                    return ret;
                }

                // One slice per ~512 points, capped by the CPU count.
                let mut nslices = (npoints + 511) / 512;
                nslices = core::cmp::min(nslices, ncpus);
                // The barrier has two parties: this thread and whichever
                // worker finishes last (the one that brings the counter
                // down from 1).
                let wg = Arc::new((Barrier::new(2), AtomicUsize::new(nslices)));

                // The first `npoints % nslices` slices get one extra
                // element. `rem` is a wrapping counter, so after passing
                // through zero it never equals zero again and `delta` is
                // decremented exactly once.
                let (mut delta, mut rem) =
                    (npoints / nslices + 1, Wrapping(npoints % nslices));
                let mut x = 0usize;
                while x < npoints {
                    let out = &mut ret.points[x];
                    let inp = &points[x];

                    delta -= (rem == Wrapping(0)) as usize;
                    rem -= Wrapping(1);
                    x += delta;

                    let wg = wg.clone();
                    pool.joined_execute(move || {
                        let p: [*const $point; 2] = [inp, ptr::null()];
                        unsafe { $to_affines(out, &p[0], delta) };
                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }
                wg.0.wait();

                ret
            }

            /// Multi-scalar multiplication over the stored points; see
            /// `MultiPoint::mult` for `[$point_affine]`.
            #[inline]
            pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                self.as_slice().mult(scalars, nbits)
            }

            /// Sum of the stored points; see `MultiPoint::add` for
            /// `[$point_affine]`.
            #[inline]
            pub fn add(&self) -> $point {
                self.as_slice().add()
            }
        }

        impl MultiPoint for [$point_affine] {
            type Output = $point;

            /// Multiplies every point by its scalar and combines the
            /// results.
            ///
            /// `scalars` holds one scalar per point, `(nbits + 7) / 8`
            /// bytes each, back to back. An empty slice of points yields
            /// `<$point>::default()`, the same value this function uses
            /// as its initial accumulator.
            ///
            /// # Panics
            ///
            /// Panics with "scalars length mismatch" if `scalars` is
            /// shorter than `npoints * (nbits + 7) / 8` bytes.
            fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                let npoints = self.len();
                let nbytes = (nbits + 7) / 8;

                if scalars.len() < nbytes * npoints {
                    panic!("scalars length mismatch");
                }

                // Without this guard the serial path panics on `&self[0]`
                // and the small-input parallel path spawns zero workers
                // and then blocks forever on `rx.recv()`.
                if npoints == 0 {
                    return <$point>::default();
                }

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 {
                    // Single-threaded: one call over the whole input.
                    let p: [*const $point_affine; 2] = [&self[0], ptr::null()];
                    let s: [*const u8; 2] = [&scalars[0], ptr::null()];

                    unsafe {
                        // Scratch size is reported in bytes; it is
                        // allocated as u64 words and left uninitialized.
                        let mut scratch: Vec<u64> =
                            Vec::with_capacity($scratch_sizeof(npoints) / 8);
                        #[allow(clippy::uninit_vec)]
                        scratch.set_len(scratch.capacity());
                        let mut ret = <$point>::default();
                        $multi_scalar_mult(
                            &mut ret,
                            &p[0],
                            npoints,
                            &s[0],
                            nbits,
                            &mut scratch[0],
                        );
                        return ret;
                    }
                }

                if npoints < 32 {
                    // Few points: workers pull individual points off a
                    // shared counter, multiply each by its scalar and send
                    // their partial sums back.
                    let (tx, rx) = channel();
                    let counter = Arc::new(AtomicUsize::new(0));
                    let n_workers = core::cmp::min(ncpus, npoints);

                    for _ in 0..n_workers {
                        let tx = tx.clone();
                        let counter = counter.clone();

                        pool.joined_execute(move || {
                            let mut acc = <$point>::default();
                            let mut tmp = <$point>::default();
                            let mut first = true;

                            loop {
                                let work =
                                    counter.fetch_add(1, Ordering::Relaxed);
                                if work >= npoints {
                                    break;
                                }

                                unsafe {
                                    $from_affine(&mut tmp, &self[work]);
                                    let scalar = &scalars[nbytes * work];
                                    if first {
                                        // First product goes straight
                                        // into the accumulator.
                                        $mult(&mut acc, &tmp, scalar, nbits);
                                        first = false;
                                    } else {
                                        $mult(&mut tmp, &tmp, scalar, nbits);
                                        $add_or_double(&mut acc, &acc, &tmp);
                                    }
                                }
                            }

                            tx.send(acc).expect("disaster");
                        });
                    }

                    // Exactly one message per worker.
                    let mut ret = rx.recv().expect("disaster");
                    for _ in 1..n_workers {
                        let p = rx.recv().expect("disaster");
                        unsafe { $add_or_double(&mut ret, &ret, &p) };
                    }

                    return ret;
                }

                let (nx, ny, window) =
                    breakdown(nbits, pippenger_window_size(npoints), ncpus);

                // |grid[]| holds "coordinates" and place for result
                let mut grid: Vec<(tile, Cell<$point>)> =
                    Vec::with_capacity(nx * ny);
                #[allow(clippy::uninit_vec)]
                unsafe { grid.set_len(grid.capacity()) };
                let dx = npoints / nx;
                let mut y = window * (ny - 1);
                let mut total = 0usize;

                // Topmost row first: it covers whatever bits remain above
                // the last full window.
                while total < nx {
                    grid[total].0.x = total * dx;
                    grid[total].0.dx = dx;
                    grid[total].0.y = y;
                    grid[total].0.dy = nbits - y;
                    total += 1;
                }
                // The last column absorbs the remainder of the points.
                grid[total - 1].0.dx = npoints - grid[total - 1].0.x;
                // Remaining rows, top to bottom, reuse the column layout
                // of the first row.
                while y != 0 {
                    y -= window;
                    for i in 0..nx {
                        grid[total].0.x = grid[i].0.x;
                        grid[total].0.dx = grid[i].0.dx;
                        grid[total].0.y = y;
                        grid[total].0.dy = window;
                        total += 1;
                    }
                }
                let grid = &grid[..];

                let points = &self[..];
                let sz = unsafe { $scratch_sizeof(0) / 8 };

                // row_sync[r] counts finished tiles of row r; the worker
                // that completes a row reports its `y` over the channel.
                let mut row_sync: Vec<AtomicUsize> = Vec::with_capacity(ny);
                row_sync.resize_with(ny, Default::default);
                let row_sync = Arc::new(row_sync);
                let counter = Arc::new(AtomicUsize::new(0));
                let (tx, rx) = channel();
                let n_workers = core::cmp::min(ncpus, total);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();
                    let row_sync = row_sync.clone();

                    pool.joined_execute(move || {
                        let mut scratch = vec![0u64; sz << (window - 1)];
                        let mut p: [*const $point_affine; 2] =
                            [ptr::null(), ptr::null()];
                        let mut s: [*const u8; 2] = [ptr::null(), ptr::null()];

                        loop {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= total {
                                break;
                            }
                            let x = grid[work].0.x;
                            let y = grid[work].0.y;

                            p[0] = &points[x];
                            s[0] = &scalars[x * nbytes];
                            unsafe {
                                // Each tile index is handed out once, so
                                // each result cell has a single writer.
                                $tile_mult(
                                    grid[work].1.as_ptr(),
                                    &p[0],
                                    grid[work].0.dx,
                                    &s[0],
                                    nbits,
                                    &mut scratch[0],
                                    y,
                                    window,
                                );
                            }
                            if row_sync[y / window]
                                .fetch_add(1, Ordering::AcqRel)
                                == nx - 1
                            {
                                tx.send(y).expect("disaster");
                            }
                        }
                    });
                }

                // Rows complete in arbitrary order, but must be folded
                // from the top down. `rows[]` records which rows are done,
                // `row` is the next grid entry to fold in.
                let mut ret = <$point>::default();
                let mut rows = vec![false; ny];
                let mut row = 0usize;
                for _ in 0..ny {
                    let mut y = rx.recv().unwrap();
                    rows[y / window] = true;
                    while grid[row].0.y == y {
                        while row < total && grid[row].0.y == y {
                            unsafe {
                                $add_or_double(
                                    &mut ret,
                                    &ret,
                                    grid[row].1.as_ptr(),
                                );
                            }
                            row += 1;
                        }
                        if y == 0 {
                            break;
                        }
                        // Shift the accumulator by one window before the
                        // next lower row is added.
                        for _ in 0..window {
                            unsafe { $double(&mut ret, &ret) };
                        }
                        y -= window;
                        // Stop cascading at the first unfinished row.
                        if !rows[y / window] {
                            break;
                        }
                    }
                }
                ret
            }

            /// Sums all points of the slice.
            ///
            /// Runs single-threaded for fewer than 384 points or when only
            /// one CPU is available; otherwise workers sum chunks and the
            /// partial sums are combined. An empty slice yields
            /// `<$point>::default()`, the value used here as the initial
            /// accumulator.
            fn add(&self) -> $point {
                let npoints = self.len();

                // Taking `&self[0]` below would panic on empty input.
                if npoints == 0 {
                    return <$point>::default();
                }

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 384 {
                    let p: [*const _; 2] = [&self[0], ptr::null()];
                    let mut ret = <$point>::default();
                    unsafe { $add(&mut ret, &p[0], npoints) };
                    return ret;
                }

                let (tx, rx) = channel();
                let counter = Arc::new(AtomicUsize::new(0));
                let nchunks = (npoints + 255) / 256;
                let chunk = npoints / nchunks + 1;

                let n_workers = core::cmp::min(ncpus, nchunks);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();

                    pool.joined_execute(move || {
                        let mut acc = <$point>::default();
                        let mut chunk = chunk;
                        let mut p: [*const _; 2] = [ptr::null(), ptr::null()];

                        loop {
                            // The counter advances by whole chunks; the
                            // returned value is this worker's start index.
                            let work =
                                counter.fetch_add(chunk, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }
                            p[0] = &self[work];
                            // Trim the final chunk to the end of input.
                            if work + chunk > npoints {
                                chunk = npoints - work;
                            }
                            unsafe {
                                let mut t = MaybeUninit::<$point>::uninit();
                                $add(t.as_mut_ptr(), &p[0], chunk);
                                $add_or_double(&mut acc, &acc, t.as_ptr());
                            };
                        }
                        tx.send(acc).expect("disaster");
                    });
                }

                // Exactly one message per worker.
                let mut ret = rx.recv().unwrap();
                for _ in 1..n_workers {
                    unsafe {
                        $add_or_double(&mut ret, &ret, &rx.recv().unwrap())
                    };
                }

                ret
            }

            /// Checks every point with the infinity and group-membership
            /// predicates.
            ///
            /// # Errors
            ///
            /// Returns `BLST_PK_IS_INFINITY` if the infinity predicate
            /// holds for a point, `BLST_POINT_NOT_IN_GROUP` if the group
            /// predicate fails. Single-threaded, the error belongs to the
            /// first offending point by index; multi-threaded, to
            /// whichever offending point was recorded first.
            fn validate(&self) -> Result<(), BLST_ERROR> {
                fn check(point: &$point_affine) -> Result<(), BLST_ERROR> {
                    if unsafe { $is_inf(point) } {
                        return Err(BLST_ERROR::BLST_PK_IS_INFINITY);
                    }
                    if !unsafe { $in_group(point) } {
                        return Err(BLST_ERROR::BLST_POINT_NOT_IN_GROUP);
                    }
                    Ok(())
                }

                // Outcome codes shared between the workers.
                const VALID: usize = 0;
                const IS_INFINITY: usize = 1;
                const NOT_IN_GROUP: usize = 2;

                let npoints = self.len();

                let pool = mt::da_pool();
                let n_workers = core::cmp::min(npoints, pool.max_count());
                if n_workers < 2 {
                    for point in self.iter() {
                        check(point)?;
                    }
                    return Ok(());
                }

                let counter = Arc::new(AtomicUsize::new(0));
                let status = Arc::new(AtomicUsize::new(VALID));
                // Two-party barrier: this thread and the last worker to
                // finish.
                let wg =
                    Arc::new((Barrier::new(2), AtomicUsize::new(n_workers)));

                for _ in 0..n_workers {
                    let counter = counter.clone();
                    let status = status.clone();
                    let wg = wg.clone();

                    pool.joined_execute(move || {
                        // Stop early once any worker recorded a failure.
                        while status.load(Ordering::Relaxed) == VALID {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }

                            if let Err(err) = check(&self[work]) {
                                let code = match err {
                                    BLST_ERROR::BLST_PK_IS_INFINITY => {
                                        IS_INFINITY
                                    }
                                    _ => NOT_IN_GROUP,
                                };
                                // Keep the first recorded failure; a lost
                                // race just means another one is kept.
                                let _ = status.compare_exchange(
                                    VALID,
                                    code,
                                    Ordering::Relaxed,
                                    Ordering::Relaxed,
                                );
                                break;
                            }
                        }

                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }

                wg.0.wait();

                match status.load(Ordering::Relaxed) {
                    VALID => Ok(()),
                    IS_INFINITY => Err(BLST_ERROR::BLST_PK_IS_INFINITY),
                    _ => Err(BLST_ERROR::BLST_POINT_NOT_IN_GROUP),
                }
            }
        }

        #[cfg(test)]
        pippenger_test_mod!(
            $test_mod,
            $points,
            $point,
            $add_or_double,
            $generator,
            $mult,
        );
    };
}
457
// Test-only: textually includes "pippenger-test_mod.rs" ahead of the
// instantiations below, whose expansions invoke `pippenger_test_mod!`
// under `#[cfg(test)]`.
#[cfg(test)]
include!("pippenger-test_mod.rs");
460
461pippenger_mult_impl!(
462    p1_affines,
463    blst_p1,
464    blst_p1_affine,
465    blst_p1s_to_affine,
466    blst_p1s_mult_pippenger_scratch_sizeof,
467    blst_p1s_mult_pippenger,
468    blst_p1s_tile_pippenger,
469    blst_p1_add_or_double,
470    blst_p1_double,
471    p1_multi_point,
472    blst_p1_generator,
473    blst_p1_mult,
474    blst_p1s_add,
475    blst_p1_affine_is_inf,
476    blst_p1_affine_in_g1,
477    blst_p1_from_affine,
478);
479
480pippenger_mult_impl!(
481    p2_affines,
482    blst_p2,
483    blst_p2_affine,
484    blst_p2s_to_affine,
485    blst_p2s_mult_pippenger_scratch_sizeof,
486    blst_p2s_mult_pippenger,
487    blst_p2s_tile_pippenger,
488    blst_p2_add_or_double,
489    blst_p2_double,
490    p2_multi_point,
491    blst_p2_generator,
492    blst_p2_mult,
493    blst_p2s_add,
494    blst_p2_affine_is_inf,
495    blst_p2_affine_in_g2,
496    blst_p2_from_affine,
497);
498
/// Returns the number of bits needed to represent `l`, i.e. the position
/// of its highest set bit counted from 1; `num_bits(0)` is 0.
fn num_bits(l: usize) -> usize {
    let width = 8 * core::mem::size_of::<usize>();
    width - l.leading_zeros() as usize
}
502
503fn breakdown(
504    nbits: usize,
505    window: usize,
506    ncpus: usize,
507) -> (usize, usize, usize) {
508    let mut nx: usize;
509    let mut wnd: usize;
510
511    if nbits > window * ncpus {
512        nx = 1;
513        wnd = num_bits(ncpus / 4);
514        if (window + wnd) > 18 {
515            wnd = window - wnd;
516        } else {
517            wnd = (nbits / window + ncpus - 1) / ncpus;
518            if (nbits / (window + 1) + ncpus - 1) / ncpus < wnd {
519                wnd = window + 1;
520            } else {
521                wnd = window;
522            }
523        }
524    } else {
525        nx = 2;
526        wnd = window - 2;
527        while (nbits / wnd + 1) * nx < ncpus {
528            nx += 1;
529            wnd = window - num_bits(3 * nx / 2);
530        }
531        nx -= 1;
532        wnd = window - num_bits(3 * nx / 2);
533    }
534    let ny = nbits / wnd + 1;
535    wnd = nbits / ny + 1;
536
537    (nx, ny, wnd)
538}
539
540fn pippenger_window_size(npoints: usize) -> usize {
541    let wbits = num_bits(npoints);
542
543    if wbits > 13 {
544        return wbits - 4;
545    }
546    if wbits > 5 {
547        return wbits - 3;
548    }
549    2
550}