1use std::ops::Neg;
2
3use ff::{Field, PrimeField};
4use group::Group;
5use rayon::iter::{
6 IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
7};
8
9use crate::CurveAffine;
10
/// Capacity of the `set` array in `Schedule`: once this many additions are
/// pending, `Schedule::add` flushes them through `execute`.
const BATCH_SIZE: usize = 64;
12
/// Returns the signed Booth digit of window `window_index` for the
/// little-endian scalar bytes `el`.
///
/// Each window looks at `window_size + 1` bits: the `window_size` bits of the
/// window itself plus the top bit of the window below it (windows overlap by
/// one bit). For the lowest window a zero bit is appended below the scalar.
/// The result lies in `[-2^(window_size - 1), 2^(window_size - 1)]`.
///
/// Bytes past the end of `el` are treated as zero.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Position of the overlap bit, i.e. one bit below the window start
    // (clamped to 0 for the lowest window).
    let bit_offset = (window_index * window_size).saturating_sub(1);
    let byte_offset = bit_offset / 8;

    // Load up to four bytes starting at `byte_offset`; missing bytes stay zero.
    let mut raw = [0u8; 4];
    el.iter()
        .skip(byte_offset)
        .zip(raw.iter_mut())
        .for_each(|(src, dst)| *dst = *src);
    let mut slice = u32::from_le_bytes(raw);

    // The lowest window has no bit below it: pad with a zero.
    if window_index == 0 {
        slice <<= 1;
    }

    // Drop the bits below the overlap bit, then keep `window_size + 1` bits.
    slice >>= bit_offset % 8;
    slice &= (1 << (window_size + 1)) - 1;

    // The top bit of the slice decides the sign of the digit.
    let is_negative = slice & (1 << window_size) != 0;

    // Ceiling division by two.
    let half = (slice + 1) >> 1;

    if is_negative {
        let magnitude = !(half - 1) & ((1 << window_size) - 1);
        (magnitude as i32).neg()
    } else {
        half as i32
    }
}
56
/// Applies the first `size` scheduled additions in `points` to `buckets` in
/// place, using one shared field inversion for all of them (batch inversion:
/// accumulate the product of all slope denominators, invert once, then peel
/// the individual inverses off in reverse order).
///
/// For each entry, `buckets[buck_idx]` becomes `bucket + base` when `sign` is
/// `true` and `bucket - base` when it is `false`, with `base = bases[base_idx]`.
///
/// Entries beyond `size` are ignored because `points` is zipped with the
/// `size`-long scratch vectors.
///
/// NOTE: an entry whose bucket is empty (`is_inf`) is skipped without adding
/// the base; `Schedule::add` only schedules an entry after `assign` reported
/// that the bucket already holds a point.
///
/// NOTE(review): each bucket is read in the first pass and written in the
/// second, so this looks like it expects every `buck_idx` to occur at most
/// once among the processed entries — confirm against callers.
///
/// # Panics
/// Panics if the accumulated product of denominators is not invertible (for
/// example a doubling where `y == 0`), or if an index is out of range.
fn batch_add<C: CurveAffine>(
    size: usize,
    buckets: &mut [BucketAffine<C>],
    points: &[SchedulePoint],
    bases: &[Affine<C>],
) {
    // t[i]: slope numerator multiplied by the product of all earlier denominators.
    let mut t = vec![C::Base::ZERO; size];
    // z[i]: slope denominator of entry i.
    let mut z = vec![C::Base::ZERO; size];
    // Running product of the denominators seen so far.
    let mut acc = C::Base::ONE;

    // Forward pass: collect numerators and denominators.
    for (
        (
            SchedulePoint {
                base_idx,
                buck_idx,
                sign,
            },
            t,
        ),
        z,
    ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut())
    {
        if buckets[*buck_idx].is_inf() {
            // Empty bucket: nothing is computed for this entry.
            continue;
        }

        if buckets[*buck_idx].x() == bases[*base_idx].x {
            // Equal x: the operands are either the same point or negatives of
            // each other. After applying `sign` to the base they coincide when
            // (y equal and sign) or (y different and !sign).
            if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign {
                // Doubling: tangent slope 3x^2 / (2y).
                // NOTE(review): no curve `a` coefficient appears here, so this
                // presumably assumes a curve with a = 0 — confirm.
                let x_squared = bases[*base_idx].x.square();
                *z = buckets[*buck_idx].y() + buckets[*buck_idx].y();
                *t = acc * (x_squared + x_squared + x_squared);
                acc *= *z;
                continue;
            }
            // P + (-P): the bucket becomes empty.
            buckets[*buck_idx].set_inf();
            continue;
        }
        // Generic addition: chord slope (y_bucket -/+ y_base) / (x_bucket - x_base).
        *z = buckets[*buck_idx].x() - bases[*base_idx].x;
        if *sign {
            *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y);
        } else {
            // Subtracting the base means adding the point with negated y.
            *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y);
        }
        acc *= *z;
    }

    // Single inversion of the product of all denominators.
    acc = acc
        .invert()
        .expect("Some edge case has not been handled properly");

    // Backward pass: recover each slope and write the resulting point.
    for (
        (
            SchedulePoint {
                base_idx,
                buck_idx,
                sign,
            },
            t,
        ),
        z,
    ) in points.iter().zip(t.iter()).zip(z.iter()).rev()
    {
        if buckets[*buck_idx].is_inf() {
            // Skipped in the forward pass, or emptied there by P + (-P).
            continue;
        }
        // acc currently holds the inverse of the product of z[0..=i], so
        // acc * t[i] is numerator_i / z[i].
        let lambda = acc * t;
        // Remove z[i] from the inverse for the next (earlier) entry.
        acc *= z;
        // x3 = lambda^2 - x_bucket - x_base
        let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x);
        // y3 = lambda * (x_base - x3) - (signed y_base)
        if *sign {
            buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y));
        } else {
            buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y));
        }
        // x is written last because the y computation above still reads the
        // bucket-independent base x and the freshly computed x3 only.
        buckets[*buck_idx].set_x(&x);
    }
}
144
145#[derive(Debug, Clone, Copy)]
146struct Affine<C: CurveAffine> {
147 x: C::Base,
148 y: C::Base,
149}
150
151impl<C: CurveAffine> Affine<C> {
152 fn from(point: &C) -> Self {
153 let coords = point.coordinates().unwrap();
154
155 Self {
156 x: *coords.x(),
157 y: *coords.y(),
158 }
159 }
160
161 fn neg(&self) -> Self {
162 Self {
163 x: self.x,
164 y: -self.y,
165 }
166 }
167
168 fn eval(&self) -> C {
169 C::from_xy(self.x, self.y).unwrap()
170 }
171}
172
/// A bucket held in affine coordinates.
#[derive(Debug, Clone)]
enum BucketAffine<C: CurveAffine> {
    /// Empty bucket; `is_inf` returns `true` for this variant.
    None,
    /// Bucket currently holding an affine point.
    Point(Affine<C>),
}
178
179#[derive(Debug, Clone)]
180enum Bucket<C: CurveAffine> {
181 None,
182 Point(C::Curve),
183}
184
185impl<C: CurveAffine> Bucket<C> {
186 fn add_assign(&mut self, point: &C, sign: bool) {
187 *self = match *self {
188 Bucket::None => Bucket::Point({
189 if sign {
190 point.to_curve()
191 } else {
192 point.to_curve().neg()
193 }
194 }),
195 Bucket::Point(a) => {
196 if sign {
197 Self::Point(a + point)
198 } else {
199 Self::Point(a - point)
200 }
201 }
202 }
203 }
204
205 fn add(&self, other: &BucketAffine<C>) -> C::Curve {
206 match (self, other) {
207 (Self::Point(this), BucketAffine::Point(other)) => *this + other.eval(),
208 (Self::Point(this), BucketAffine::None) => *this,
209 (Self::None, BucketAffine::Point(other)) => other.eval().to_curve(),
210 (Self::None, BucketAffine::None) => C::Curve::identity(),
211 }
212 }
213}
214
215impl<C: CurveAffine> BucketAffine<C> {
216 fn assign(&mut self, point: &Affine<C>, sign: bool) -> bool {
217 match *self {
218 Self::None => {
219 *self = Self::Point(if sign { *point } else { point.neg() });
220 true
221 }
222 Self::Point(_) => false,
223 }
224 }
225
226 fn x(&self) -> C::Base {
227 match self {
228 Self::None => panic!("::x None"),
229 Self::Point(a) => a.x,
230 }
231 }
232
233 fn y(&self) -> C::Base {
234 match self {
235 Self::None => panic!("::y None"),
236 Self::Point(a) => a.y,
237 }
238 }
239
240 fn is_inf(&self) -> bool {
241 match self {
242 Self::None => true,
243 Self::Point(_) => false,
244 }
245 }
246
247 fn set_x(&mut self, x: &C::Base) {
248 match self {
249 Self::None => panic!("::set_x None"),
250 Self::Point(ref mut a) => a.x = *x,
251 }
252 }
253
254 fn set_y(&mut self, y: &C::Base) {
255 match self {
256 Self::None => panic!("::set_y None"),
257 Self::Point(ref mut a) => a.y = *y,
258 }
259 }
260
261 fn set_inf(&mut self) {
262 match self {
263 Self::None => {}
264 Self::Point(_) => *self = Self::None,
265 }
266 }
267}
268
/// Affine buckets together with a fixed-capacity queue of additions that are
/// waiting to be applied by `batch_add`.
struct Schedule<C: CurveAffine> {
    /// Affine buckets; `new(c)` allocates `1 << (c - 1)` of them.
    buckets: Vec<BucketAffine<C>>,
    /// Queue of pending additions; `add` writes the next entry at index `ptr`.
    set: [SchedulePoint; BATCH_SIZE],
    /// Number of pending entries in `set`; handed to `batch_add` as its `size`
    /// and reset to 0 by `execute`.
    ptr: usize,
}
274
/// One pending addition: add (or subtract) `bases[base_idx]` into
/// `buckets[buck_idx]`.
#[derive(Debug, Clone, Default)]
struct SchedulePoint {
    /// Index of the base point.
    base_idx: usize,
    /// Index of the target bucket.
    buck_idx: usize,
    /// `true` to add the base, `false` to subtract it.
    sign: bool,
}

impl SchedulePoint {
    /// Bundles the three fields into a schedule entry.
    fn new(base_idx: usize, buck_idx: usize, sign: bool) -> Self {
        SchedulePoint { sign, buck_idx, base_idx }
    }
}
291
292impl<C: CurveAffine> Schedule<C> {
293 fn new(c: usize) -> Self {
294 let set = (0..BATCH_SIZE)
295 .map(|_| SchedulePoint::default())
296 .collect::<Vec<_>>()
297 .try_into()
298 .unwrap();
299
300 Self {
301 buckets: vec![BucketAffine::None; 1 << (c - 1)],
302 set,
303 ptr: 0,
304 }
305 }
306
307 fn contains(&self, buck_idx: usize) -> bool {
308 self.set.iter().any(|sch| sch.buck_idx == buck_idx)
309 }
310
311 fn execute(&mut self, bases: &[Affine<C>]) {
312 if self.ptr != 0 {
313 batch_add(self.ptr, &mut self.buckets, &self.set, bases);
314 self.ptr = 0;
315 self.set
316 .iter_mut()
317 .for_each(|sch| *sch = SchedulePoint::default());
318 }
319 }
320
321 fn add(&mut self, bases: &[Affine<C>], base_idx: usize, buck_idx: usize, sign: bool) {
322 if !self.buckets[buck_idx].assign(&bases[base_idx], sign) {
323 self.set[self.ptr] = SchedulePoint::new(base_idx, buck_idx, sign);
324 self.ptr += 1;
325 }
326
327 if self.ptr == self.set.len() {
328 self.execute(bases);
329 }
330 }
331}
332
333pub fn msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
337 let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
338
339 let c = if bases.len() < 4 {
340 1
341 } else if bases.len() < 32 {
342 3
343 } else {
344 (f64::from(bases.len() as u32)).ln().ceil() as usize
345 };
346
347 let field_byte_size = C::Scalar::NUM_BITS.div_ceil(8u32) as usize;
348 let mut acc_or = vec![0; field_byte_size];
351 for coeff in &coeffs {
352 for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
353 *acc_limb |= *limb;
354 }
355 }
356 let max_byte_size = field_byte_size
357 - acc_or
358 .iter()
359 .rev()
360 .position(|v| *v != 0)
361 .unwrap_or(field_byte_size);
362 if max_byte_size == 0 {
363 return;
364 }
365 let number_of_windows = max_byte_size * 8_usize / c + 1;
366
367 for current_window in (0..number_of_windows).rev() {
368 for _ in 0..c {
369 *acc = acc.double();
370 }
371
372 #[derive(Clone, Copy)]
373 enum Bucket<C: CurveAffine> {
374 None,
375 Affine(C),
376 Projective(C::Curve),
377 }
378
379 impl<C: CurveAffine> Bucket<C> {
380 fn add_assign(&mut self, other: &C) {
381 *self = match *self {
382 Bucket::None => Bucket::Affine(*other),
383 Bucket::Affine(a) => Bucket::Projective(a + *other),
384 Bucket::Projective(mut a) => {
385 a += *other;
386 Bucket::Projective(a)
387 }
388 }
389 }
390
391 fn add(self, mut other: C::Curve) -> C::Curve {
392 match self {
393 Bucket::None => other,
394 Bucket::Affine(a) => {
395 other += a;
396 other
397 }
398 Bucket::Projective(a) => other + a,
399 }
400 }
401 }
402
403 let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];
404
405 for (coeff, base) in coeffs.iter().zip(bases.iter()) {
406 let coeff = get_booth_index(current_window, c, coeff.as_ref());
407 if coeff.is_positive() {
408 buckets[coeff as usize - 1].add_assign(base);
409 }
410 if coeff.is_negative() {
411 buckets[coeff.unsigned_abs() as usize - 1].add_assign(&base.neg());
412 }
413 }
414
415 let mut running_sum = C::Curve::identity();
420 for exp in buckets.into_iter().rev() {
421 running_sum = exp.add(running_sum);
422 *acc += &running_sum;
423 }
424 }
425}
426
427pub fn msm_parallel<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
433 assert_eq!(coeffs.len(), bases.len());
434
435 let num_threads = rayon::current_num_threads();
436 if coeffs.len() > num_threads {
437 let chunk = coeffs.len() / num_threads;
438 let num_chunks = coeffs.chunks(chunk).len();
439 let mut results = vec![C::Curve::identity(); num_chunks];
440 rayon::scope(|scope| {
441 let chunk = coeffs.len() / num_threads;
442
443 for ((coeffs, bases), acc) in coeffs
444 .chunks(chunk)
445 .zip(bases.chunks(chunk))
446 .zip(results.iter_mut())
447 {
448 scope.spawn(move |_| {
449 msm_serial(coeffs, bases, acc);
450 });
451 }
452 });
453 results.iter().fold(C::Curve::identity(), |a, b| a + b)
454 } else {
455 let mut acc = C::Curve::identity();
456 msm_serial(coeffs, bases, &mut acc);
457 acc
458 }
459}
460
461pub fn msm_best<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
465 assert_eq!(coeffs.len(), bases.len());
466
467 let c = if bases.len() < 4 {
469 1
470 } else if bases.len() < 32 {
471 3
472 } else {
473 (f64::from(bases.len() as u32)).ln().ceil() as usize
474 };
475
476 if c < 10 {
477 return msm_parallel(coeffs, bases);
478 }
479
480 let coeffs: Vec<_> = coeffs.par_iter().map(|a| a.to_repr()).collect();
482 let bases_local: Vec<_> = bases.par_iter().map(Affine::from).collect();
484
485 let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
487 let mut acc = vec![C::Curve::identity(); number_of_windows];
489 acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| {
490 let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];
492
493 let mut sched = Schedule::new(c);
495
496 for (base_idx, coeff) in coeffs.iter().enumerate() {
497 let buck_idx = get_booth_index(w, c, coeff.as_ref());
498
499 if buck_idx != 0 {
500 let sign = buck_idx.is_positive();
502 let buck_idx = buck_idx.unsigned_abs() as usize - 1;
503
504 if sched.contains(buck_idx) {
505 j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
508 } else {
509 sched.add(&bases_local, base_idx, buck_idx, sign);
511 }
512 }
513 }
514
515 sched.execute(&bases_local);
517
518 let mut running_sum = C::Curve::identity();
523 for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
524 running_sum += j_buck.add(a_buck);
525 *acc += running_sum;
526 }
527
528 for _ in 0..c * w {
530 *acc = acc.double();
531 }
532 });
533 acc.into_iter().sum::<_>()
534}
535
#[cfg(test)]
mod test {
    use std::ops::Neg;

    use ark_std::{end_timer, start_timer};
    use ff::{Field, PrimeField};
    use group::{Curve, Group};
    use rand_core::OsRng;

    use crate::{
        bn256::{Fr, G1Affine, G1},
        CurveAffine,
    };

    /// Checks that a windowed scalar multiplication driven by the Booth digits
    /// agrees with the plain `point * scalar` for window sizes 1 through 9.
    #[test]
    fn test_booth_encoding() {
        /// Scalar multiplication built from `get_booth_index` digits.
        fn mul(scalar: &Fr, point: &G1Affine, window: usize) -> G1Affine {
            let repr = scalar.to_repr();
            let window_count = Fr::NUM_BITS as usize / window + 1;

            // Multiples 0 * P ..= 2^(window - 1) * P.
            let table: Vec<_> = (0..=1 << (window - 1))
                .map(|i| point * Fr::from(i as u64))
                .collect();

            let mut acc = G1::identity();
            for i in (0..window_count).rev() {
                for _ in 0..window {
                    acc = acc.double();
                }

                let digit = super::get_booth_index(i, window, repr.as_ref());
                let magnitude = digit.unsigned_abs() as usize;

                if digit < 0 {
                    acc += table[magnitude].neg();
                } else if digit > 0 {
                    acc += table[magnitude];
                }
            }

            acc.to_affine()
        }

        let (scalars, points): (Vec<_>, Vec<_>) = (0..10)
            .map(|_| (Fr::random(OsRng), G1Affine::random(OsRng)))
            .unzip();

        for window in 1..10 {
            for (scalar, point) in scalars.iter().zip(points.iter()) {
                let via_booth = mul(scalar, point, window);
                let expected = (point * scalar).to_affine();
                assert_eq!(via_booth, expected);
            }
        }
    }

    /// Compares `msm_best` against `msm_parallel` on random inputs of size
    /// `2^k` for every `k` in `min_k..=max_k`, timing both.
    fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        let max_len = 1 << max_k;

        let projective = (0..max_len)
            .into_par_iter()
            .map(|_| C::Curve::random(OsRng))
            .collect::<Vec<_>>();
        let mut points = vec![C::identity(); max_len];
        C::Curve::batch_normalize(&projective[..], &mut points[..]);

        let scalars = (0..max_len)
            .into_par_iter()
            .map(|_| C::Scalar::random(OsRng))
            .collect::<Vec<_>>();

        for k in min_k..=max_k {
            let len = 1 << k;
            let points = &points[..len];
            let scalars = &scalars[..len];

            let t0 = start_timer!(|| format!("cyclone indep k={}", k));
            let e0 = super::msm_best(scalars, points);
            end_timer!(t0);

            let t1 = start_timer!(|| format!("older k={}", k));
            let e1 = super::msm_parallel(scalars, points);
            end_timer!(t1);

            assert_eq!(e0, e1);
        }
    }

    #[test]
    fn test_msm_cross() {
        run_msm_cross::<G1Affine>(14, 18);
    }
}