1use std::ops::Neg;
2
3use crate::CurveAffine;
4use ff::Field;
5use ff::PrimeField;
6use group::Group;
7use rayon::iter::{
8 IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
9};
10
/// Capacity of the pending-operation array held by a `Schedule`: once this
/// many affine additions have been queued, `Schedule::add` flushes them with
/// `Schedule::execute`.
const BATCH_SIZE: usize = 64;
12
/// Returns the signed Booth digit of window `window_index` for the
/// little-endian scalar bytes `el`, using windows of `window_size` bits.
///
/// Each window is read as `window_size + 1` bits: its own bits plus the top
/// bit of the window below it (a zero bit is used below window 0). The
/// result lies in `[-2^(window_size - 1), 2^(window_size - 1)]`, and summing
/// `digit_i * 2^(window_size * i)` over all windows yields the scalar again.
///
/// Bytes past the end of `el` are read as zero.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Every window except the first starts one bit early so that it overlaps
    // the most significant bit of its predecessor.
    let first_bit = (window_index * window_size).saturating_sub(1);
    let first_byte = first_bit / 8;

    // Load up to four bytes, starting at the byte holding `first_bit`.
    let mut raw = [0u8; 4];
    let available = el.iter().skip(first_byte);
    raw.iter_mut()
        .zip(available)
        .for_each(|(slot, byte)| *slot = *byte);
    let mut bits = u32::from_le_bytes(raw);

    // The lowest window has no predecessor: supply a zero overlap bit.
    if window_index == 0 {
        bits <<= 1;
    }

    // Align the window to bit 0, then keep exactly `window_size + 1` bits.
    bits >>= first_bit % 8;
    bits &= (1 << (window_size + 1)) - 1;

    // The top bit of the slice selects the sign of the digit.
    let non_negative = bits & (1 << window_size) == 0;

    // Halve, rounding up.
    let magnitude = (bits + 1) >> 1;

    if non_negative {
        magnitude as i32
    } else {
        ((!(magnitude - 1) & ((1 << window_size) - 1)) as i32).neg()
    }
}
55
56fn batch_add<C: CurveAffine>(
58 size: usize,
59 buckets: &mut [BucketAffine<C>],
60 points: &[SchedulePoint],
61 bases: &[Affine<C>],
62) {
63 let mut t = vec![C::Base::ZERO; size]; let mut z = vec![C::Base::ZERO; size]; let mut acc = C::Base::ONE;
66
67 for (
68 (
69 SchedulePoint {
70 base_idx,
71 buck_idx,
72 sign,
73 },
74 t,
75 ),
76 z,
77 ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut())
78 {
79 if buckets[*buck_idx].is_inf() {
80 continue;
82 }
83
84 if buckets[*buck_idx].x() == bases[*base_idx].x {
85 if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign {
91 let x_squared = bases[*base_idx].x.square();
93 *z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); *t = acc * (x_squared + x_squared + x_squared); acc *= *z;
96 continue;
97 }
98 buckets[*buck_idx].set_inf();
100 continue;
101 }
102 *z = buckets[*buck_idx].x() - bases[*base_idx].x; if *sign {
105 *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y);
106 } else {
107 *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y);
108 } acc *= *z;
110 }
111
112 acc = acc
113 .invert()
114 .expect("Some edge case has not been handled properly");
115
116 for (
117 (
118 SchedulePoint {
119 base_idx,
120 buck_idx,
121 sign,
122 },
123 t,
124 ),
125 z,
126 ) in points.iter().zip(t.iter()).zip(z.iter()).rev()
127 {
128 if buckets[*buck_idx].is_inf() {
129 continue;
131 }
132 let lambda = acc * t;
133 acc *= z; let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); if *sign {
136 buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y));
137 } else {
138 buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y));
139 } buckets[*buck_idx].set_x(&x);
141 }
142}
143
144#[derive(Debug, Clone, Copy)]
145struct Affine<C: CurveAffine> {
146 x: C::Base,
147 y: C::Base,
148}
149
150impl<C: CurveAffine> Affine<C> {
151 fn from(point: &C) -> Self {
152 let coords = point.coordinates().unwrap();
153
154 Self {
155 x: *coords.x(),
156 y: *coords.y(),
157 }
158 }
159
160 fn neg(&self) -> Self {
161 Self {
162 x: self.x,
163 y: -self.y,
164 }
165 }
166
167 fn eval(&self) -> C {
168 C::from_xy(self.x, self.y).unwrap()
169 }
170}
171
172#[derive(Debug, Clone)]
173enum BucketAffine<C: CurveAffine> {
174 None,
175 Point(Affine<C>),
176}
177
178#[derive(Debug, Clone)]
179enum Bucket<C: CurveAffine> {
180 None,
181 Point(C::Curve),
182}
183
184impl<C: CurveAffine> Bucket<C> {
185 fn add_assign(&mut self, point: &C, sign: bool) {
186 *self = match *self {
187 Bucket::None => Bucket::Point({
188 if sign {
189 point.to_curve()
190 } else {
191 point.to_curve().neg()
192 }
193 }),
194 Bucket::Point(a) => {
195 if sign {
196 Self::Point(a + point)
197 } else {
198 Self::Point(a - point)
199 }
200 }
201 }
202 }
203
204 fn add(&self, other: &BucketAffine<C>) -> C::Curve {
205 match (self, other) {
206 (Self::Point(this), BucketAffine::Point(other)) => *this + other.eval(),
207 (Self::Point(this), BucketAffine::None) => *this,
208 (Self::None, BucketAffine::Point(other)) => other.eval().to_curve(),
209 (Self::None, BucketAffine::None) => C::Curve::identity(),
210 }
211 }
212}
213
214impl<C: CurveAffine> BucketAffine<C> {
215 fn assign(&mut self, point: &Affine<C>, sign: bool) -> bool {
216 match *self {
217 Self::None => {
218 *self = Self::Point(if sign { *point } else { point.neg() });
219 true
220 }
221 Self::Point(_) => false,
222 }
223 }
224
225 fn x(&self) -> C::Base {
226 match self {
227 Self::None => panic!("::x None"),
228 Self::Point(a) => a.x,
229 }
230 }
231
232 fn y(&self) -> C::Base {
233 match self {
234 Self::None => panic!("::y None"),
235 Self::Point(a) => a.y,
236 }
237 }
238
239 fn is_inf(&self) -> bool {
240 match self {
241 Self::None => true,
242 Self::Point(_) => false,
243 }
244 }
245
246 fn set_x(&mut self, x: &C::Base) {
247 match self {
248 Self::None => panic!("::set_x None"),
249 Self::Point(ref mut a) => a.x = *x,
250 }
251 }
252
253 fn set_y(&mut self, y: &C::Base) {
254 match self {
255 Self::None => panic!("::set_y None"),
256 Self::Point(ref mut a) => a.y = *y,
257 }
258 }
259
260 fn set_inf(&mut self) {
261 match self {
262 Self::None => {}
263 Self::Point(_) => *self = Self::None,
264 }
265 }
266}
267
/// Collects pending affine bucket additions so they can be applied together
/// by `batch_add`.
struct Schedule<C: CurveAffine> {
    /// Affine buckets that the queued operations are applied to.
    buckets: Vec<BucketAffine<C>>,
    /// Fixed-size queue of pending operations. `add` writes at index `ptr`;
    /// `new` and `execute` leave unused slots as `SchedulePoint::default()`.
    set: [SchedulePoint; BATCH_SIZE],
    /// Number of operations currently queued in `set`.
    ptr: usize,
}
273
/// One queued operation: add (or subtract) a base to (from) a bucket.
#[derive(Clone, Debug, Default)]
struct SchedulePoint {
    /// Index of the base in the bases slice.
    base_idx: usize,
    /// Index of the target bucket.
    buck_idx: usize,
    /// `true` to add the base, `false` to subtract it.
    sign: bool,
}

impl SchedulePoint {
    /// Builds an operation from its three components.
    fn new(base_idx: usize, buck_idx: usize, sign: bool) -> Self {
        Self { base_idx, buck_idx, sign }
    }
}
290
291impl<C: CurveAffine> Schedule<C> {
292 fn new(c: usize) -> Self {
293 let set = (0..BATCH_SIZE)
294 .map(|_| SchedulePoint::default())
295 .collect::<Vec<_>>()
296 .try_into()
297 .unwrap();
298
299 Self {
300 buckets: vec![BucketAffine::None; 1 << (c - 1)],
301 set,
302 ptr: 0,
303 }
304 }
305
306 fn contains(&self, buck_idx: usize) -> bool {
307 self.set.iter().any(|sch| sch.buck_idx == buck_idx)
308 }
309
310 fn execute(&mut self, bases: &[Affine<C>]) {
311 if self.ptr != 0 {
312 batch_add(self.ptr, &mut self.buckets, &self.set, bases);
313 self.ptr = 0;
314 self.set
315 .iter_mut()
316 .for_each(|sch| *sch = SchedulePoint::default());
317 }
318 }
319
320 fn add(&mut self, bases: &[Affine<C>], base_idx: usize, buck_idx: usize, sign: bool) {
321 if !self.buckets[buck_idx].assign(&bases[base_idx], sign) {
322 self.set[self.ptr] = SchedulePoint::new(base_idx, buck_idx, sign);
323 self.ptr += 1;
324 }
325
326 if self.ptr == self.set.len() {
327 self.execute(bases);
328 }
329 }
330}
331
332pub fn msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
336 let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
337
338 let c = if bases.len() < 4 {
339 1
340 } else if bases.len() < 32 {
341 3
342 } else {
343 (f64::from(bases.len() as u32)).ln().ceil() as usize
344 };
345
346 let field_byte_size = C::Scalar::NUM_BITS.div_ceil(8u32) as usize;
347 let mut acc_or = vec![0; field_byte_size];
350 for coeff in &coeffs {
351 for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
352 *acc_limb |= *limb;
353 }
354 }
355 let max_byte_size = field_byte_size
356 - acc_or
357 .iter()
358 .rev()
359 .position(|v| *v != 0)
360 .unwrap_or(field_byte_size);
361 if max_byte_size == 0 {
362 return;
363 }
364 let number_of_windows = max_byte_size * 8_usize / c + 1;
365
366 for current_window in (0..number_of_windows).rev() {
367 for _ in 0..c {
368 *acc = acc.double();
369 }
370
371 #[derive(Clone, Copy)]
372 enum Bucket<C: CurveAffine> {
373 None,
374 Affine(C),
375 Projective(C::Curve),
376 }
377
378 impl<C: CurveAffine> Bucket<C> {
379 fn add_assign(&mut self, other: &C) {
380 *self = match *self {
381 Bucket::None => Bucket::Affine(*other),
382 Bucket::Affine(a) => Bucket::Projective(a + *other),
383 Bucket::Projective(mut a) => {
384 a += *other;
385 Bucket::Projective(a)
386 }
387 }
388 }
389
390 fn add(self, mut other: C::Curve) -> C::Curve {
391 match self {
392 Bucket::None => other,
393 Bucket::Affine(a) => {
394 other += a;
395 other
396 }
397 Bucket::Projective(a) => other + a,
398 }
399 }
400 }
401
402 let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];
403
404 for (coeff, base) in coeffs.iter().zip(bases.iter()) {
405 let coeff = get_booth_index(current_window, c, coeff.as_ref());
406 if coeff.is_positive() {
407 buckets[coeff as usize - 1].add_assign(base);
408 }
409 if coeff.is_negative() {
410 buckets[coeff.unsigned_abs() as usize - 1].add_assign(&base.neg());
411 }
412 }
413
414 let mut running_sum = C::Curve::identity();
419 for exp in buckets.into_iter().rev() {
420 running_sum = exp.add(running_sum);
421 *acc += &running_sum;
422 }
423 }
424}
425
426pub fn msm_parallel<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
432 assert_eq!(coeffs.len(), bases.len());
433
434 let num_threads = rayon::current_num_threads();
435 if coeffs.len() > num_threads {
436 let chunk = coeffs.len() / num_threads;
437 let num_chunks = coeffs.chunks(chunk).len();
438 let mut results = vec![C::Curve::identity(); num_chunks];
439 rayon::scope(|scope| {
440 let chunk = coeffs.len() / num_threads;
441
442 for ((coeffs, bases), acc) in coeffs
443 .chunks(chunk)
444 .zip(bases.chunks(chunk))
445 .zip(results.iter_mut())
446 {
447 scope.spawn(move |_| {
448 msm_serial(coeffs, bases, acc);
449 });
450 }
451 });
452 results.iter().fold(C::Curve::identity(), |a, b| a + b)
453 } else {
454 let mut acc = C::Curve::identity();
455 msm_serial(coeffs, bases, &mut acc);
456 acc
457 }
458}
459
460pub fn msm_best<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
464 assert_eq!(coeffs.len(), bases.len());
465
466 let c = if bases.len() < 4 {
468 1
469 } else if bases.len() < 32 {
470 3
471 } else {
472 (f64::from(bases.len() as u32)).ln().ceil() as usize
473 };
474
475 if c < 10 {
476 return msm_parallel(coeffs, bases);
477 }
478
479 let coeffs: Vec<_> = coeffs.par_iter().map(|a| a.to_repr()).collect();
481 let bases_local: Vec<_> = bases.par_iter().map(Affine::from).collect();
483
484 let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
486 let mut acc = vec![C::Curve::identity(); number_of_windows];
488 acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| {
489 let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];
491
492 let mut sched = Schedule::new(c);
494
495 for (base_idx, coeff) in coeffs.iter().enumerate() {
496 let buck_idx = get_booth_index(w, c, coeff.as_ref());
497
498 if buck_idx != 0 {
499 let sign = buck_idx.is_positive();
501 let buck_idx = buck_idx.unsigned_abs() as usize - 1;
502
503 if sched.contains(buck_idx) {
504 j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
507 } else {
508 sched.add(&bases_local, base_idx, buck_idx, sign);
510 }
511 }
512 }
513
514 sched.execute(&bases_local);
516
517 let mut running_sum = C::Curve::identity();
522 for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
523 running_sum += j_buck.add(a_buck);
524 *acc += running_sum;
525 }
526
527 for _ in 0..c * w {
529 *acc = acc.double();
530 }
531 });
532 acc.into_iter().sum::<_>()
533}
534
#[cfg(test)]
mod test {
    use std::ops::Neg;

    use crate::bn256::{Fr, G1Affine, G1};
    use ark_std::{end_timer, start_timer};
    use ff::{Field, PrimeField};
    use group::{Curve, Group};
    use pasta_curves::arithmetic::CurveAffine;
    use rand_core::OsRng;

    /// Checks that scalar multiplication driven by Booth digits matches the
    /// library's own scalar multiplication for window sizes 1 through 9.
    #[test]
    fn test_booth_encoding() {
        /// Windowed double-and-add using `get_booth_index` digits.
        fn mul(scalar: &Fr, point: &G1Affine, window: usize) -> G1Affine {
            let repr = scalar.to_repr();
            let window_count = Fr::NUM_BITS as usize / window + 1;

            // Multiples 0·P ..= 2^(window-1)·P, indexed by digit magnitude.
            let multiples: Vec<_> = (0..=1 << (window - 1))
                .map(|k| point * Fr::from(k as u64))
                .collect();

            let mut acc = G1::identity();
            for i in (0..window_count).rev() {
                for _ in 0..window {
                    acc = acc.double();
                }

                let digit = super::get_booth_index(i, window, repr.as_ref());
                let magnitude = digit.unsigned_abs() as usize;

                if digit < 0 {
                    acc += multiples[magnitude].neg();
                } else if digit > 0 {
                    acc += multiples[magnitude];
                }
            }

            acc.to_affine()
        }

        let cases: Vec<(Fr, G1Affine)> = (0..10)
            .map(|_| {
                let scalar = Fr::random(OsRng);
                let point = G1Affine::random(OsRng);
                (scalar, point)
            })
            .collect();

        for window in 1..10 {
            for (scalar, point) in &cases {
                let via_booth = mul(scalar, point, window);
                let reference = point * scalar;
                assert_eq!(via_booth, reference.to_affine());
            }
        }
    }

    /// Compares `msm_best` with `msm_parallel` on random inputs of size
    /// `2^k` for every `k` in `min_k..=max_k`, timing both.
    fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        let projective = (0..1 << max_k)
            .into_par_iter()
            .map(|_| C::Curve::random(OsRng))
            .collect::<Vec<_>>();
        let mut all_points = vec![C::identity(); 1 << max_k];
        C::Curve::batch_normalize(&projective[..], &mut all_points[..]);

        let all_scalars = (0..1 << max_k)
            .into_par_iter()
            .map(|_| C::Scalar::random(OsRng))
            .collect::<Vec<_>>();

        for k in min_k..=max_k {
            let points = &all_points[..1 << k];
            let scalars = &all_scalars[..1 << k];

            let best_timer = start_timer!(|| format!("cyclone indep k={}", k));
            let best = super::msm_best(scalars, points);
            end_timer!(best_timer);

            let parallel_timer = start_timer!(|| format!("older k={}", k));
            let parallel = super::msm_parallel(scalars, points);
            end_timer!(parallel_timer);

            assert_eq!(best, parallel);
        }
    }

    /// Cross-checks the two MSM implementations on bn256 G1 for
    /// `k = 14..=22`.
    #[test]
    fn test_msm_cross() {
        run_msm_cross::<G1Affine>(14, 22);
    }
}