// halo2curves/msm.rs

1use std::ops::Neg;
2
3use ff::{Field, PrimeField};
4use group::Group;
5use rayon::iter::{
6    IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
7};
8
9use crate::CurveAffine;
10
/// Capacity of the pending-addition queue held by a `Schedule`; once this many
/// additions are queued, the schedule is executed as one batch.
const BATCH_SIZE: usize = 64;
12
/// Returns the signed Booth digit for window `window_index` of a scalar.
///
/// `el` holds the scalar as little-endian bytes and `window_size` is the
/// window width in bits. The digit lies in
/// `-(2^(window_size - 1))..=2^(window_size - 1)`, and the digits of all
/// windows, weighted by `2^(window_size * window_index)`, sum to the scalar.
/// Bytes past the end of `el` are read as zero.
///
/// Only four bytes are loaded per call, so `window_size` must be small enough
/// that `window_size + 1` bits plus up to seven bits of misalignment fit in a
/// `u32`.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Booth encoding:
    // * windows advance in steps of `window_size` bits,
    // * each slice is `window_size + 1` bits wide, so neighbouring windows
    //   overlap by one bit,
    // * the least significant window is padded with a zero bit below it.
    // Lookup for window size 3 (slices of 4 bits):
    // `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`
    // This halves the number of buckets without preprocessing the scalars.

    // The slice starts one bit below the window (except for window 0).
    let bit_offset = (window_index * window_size).saturating_sub(1);
    let byte_offset = bit_offset / 8;

    // Load up to four bytes starting at `byte_offset` into a u32.
    let mut raw = [0u8; 4];
    let tail = el.get(byte_offset..).unwrap_or(&[]);
    let take = tail.len().min(raw.len());
    raw[..take].copy_from_slice(&tail[..take]);
    let mut word = u32::from_le_bytes(raw);

    // The lowest window has no real bit below it: pad with a zero.
    if window_index == 0 {
        word <<= 1;
    }

    // Drop the bits below the slice, then keep `window_size + 1` bits.
    word >>= bit_offset % 8;
    word &= (1 << (window_size + 1)) - 1;

    // The top bit of the slice decides the sign of the digit.
    let non_negative = (word & (1 << window_size)) == 0;

    // Ceiling division by two.
    let magnitude = (word + 1) >> 1;

    if non_negative {
        magnitude as i32
    } else {
        let low_mask: u32 = (1 << window_size) - 1;
        ((!(magnitude - 1) & low_mask) as i32).neg()
    }
}
56
57/// Batch addition.
58fn batch_add<C: CurveAffine>(
59    size: usize,
60    buckets: &mut [BucketAffine<C>],
61    points: &[SchedulePoint],
62    bases: &[Affine<C>],
63) {
64    let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1
65    let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1
66    let mut acc = C::Base::ONE;
67
68    for (
69        (
70            SchedulePoint {
71                base_idx,
72                buck_idx,
73                sign,
74            },
75            t,
76        ),
77        z,
78    ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut())
79    {
80        if buckets[*buck_idx].is_inf() {
81            // We assume bases[*base_idx] != infinity always.
82            continue;
83        }
84
85        if buckets[*buck_idx].x() == bases[*base_idx].x {
86            // y-coordinate matches:
87            //  1. y1 == y2 and sign = false or
88            //  2. y1 != y2 and sign = true
89            //  => ( y1 == y2) xor !sign
90            //  (This uses the fact that x1 == x2 and both points satisfy the curve eq.)
91            if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign {
92                // Doubling
93                let x_squared = bases[*base_idx].x.square();
94                *z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y
95                *t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2
96                acc *= *z;
97                continue;
98            }
99            // P + (-P)
100            buckets[*buck_idx].set_inf();
101            continue;
102        }
103        // Addition
104        *z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1
105        if *sign {
106            *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y);
107        } else {
108            *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y);
109        } // y2 - y1
110        acc *= *z;
111    }
112
113    acc = acc
114        .invert()
115        .expect("Some edge case has not been handled properly");
116
117    for (
118        (
119            SchedulePoint {
120                base_idx,
121                buck_idx,
122                sign,
123            },
124            t,
125        ),
126        z,
127    ) in points.iter().zip(t.iter()).zip(z.iter()).rev()
128    {
129        if buckets[*buck_idx].is_inf() {
130            // We assume bases[*base_idx] != infinity always.
131            continue;
132        }
133        let lambda = acc * t;
134        acc *= z; // update acc
135        let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result
136        if *sign {
137            buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y));
138        } else {
139            buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y));
140        } // y_result = lambda * (x1 - x_result) - y1
141        buckets[*buck_idx].set_x(&x);
142    }
143}
144
145#[derive(Debug, Clone, Copy)]
146struct Affine<C: CurveAffine> {
147    x: C::Base,
148    y: C::Base,
149}
150
151impl<C: CurveAffine> Affine<C> {
152    fn from(point: &C) -> Self {
153        let coords = point.coordinates().unwrap();
154
155        Self {
156            x: *coords.x(),
157            y: *coords.y(),
158        }
159    }
160
161    fn neg(&self) -> Self {
162        Self {
163            x: self.x,
164            y: -self.y,
165        }
166    }
167
168    fn eval(&self) -> C {
169        C::from_xy(self.x, self.y).unwrap()
170    }
171}
172
173#[derive(Debug, Clone)]
174enum BucketAffine<C: CurveAffine> {
175    None,
176    Point(Affine<C>),
177}
178
179#[derive(Debug, Clone)]
180enum Bucket<C: CurveAffine> {
181    None,
182    Point(C::Curve),
183}
184
185impl<C: CurveAffine> Bucket<C> {
186    fn add_assign(&mut self, point: &C, sign: bool) {
187        *self = match *self {
188            Bucket::None => Bucket::Point({
189                if sign {
190                    point.to_curve()
191                } else {
192                    point.to_curve().neg()
193                }
194            }),
195            Bucket::Point(a) => {
196                if sign {
197                    Self::Point(a + point)
198                } else {
199                    Self::Point(a - point)
200                }
201            }
202        }
203    }
204
205    fn add(&self, other: &BucketAffine<C>) -> C::Curve {
206        match (self, other) {
207            (Self::Point(this), BucketAffine::Point(other)) => *this + other.eval(),
208            (Self::Point(this), BucketAffine::None) => *this,
209            (Self::None, BucketAffine::Point(other)) => other.eval().to_curve(),
210            (Self::None, BucketAffine::None) => C::Curve::identity(),
211        }
212    }
213}
214
215impl<C: CurveAffine> BucketAffine<C> {
216    fn assign(&mut self, point: &Affine<C>, sign: bool) -> bool {
217        match *self {
218            Self::None => {
219                *self = Self::Point(if sign { *point } else { point.neg() });
220                true
221            }
222            Self::Point(_) => false,
223        }
224    }
225
226    fn x(&self) -> C::Base {
227        match self {
228            Self::None => panic!("::x None"),
229            Self::Point(a) => a.x,
230        }
231    }
232
233    fn y(&self) -> C::Base {
234        match self {
235            Self::None => panic!("::y None"),
236            Self::Point(a) => a.y,
237        }
238    }
239
240    fn is_inf(&self) -> bool {
241        match self {
242            Self::None => true,
243            Self::Point(_) => false,
244        }
245    }
246
247    fn set_x(&mut self, x: &C::Base) {
248        match self {
249            Self::None => panic!("::set_x None"),
250            Self::Point(ref mut a) => a.x = *x,
251        }
252    }
253
254    fn set_y(&mut self, y: &C::Base) {
255        match self {
256            Self::None => panic!("::set_y None"),
257            Self::Point(ref mut a) => a.y = *y,
258        }
259    }
260
261    fn set_inf(&mut self) {
262        match self {
263            Self::None => {}
264            Self::Point(_) => *self = Self::None,
265        }
266    }
267}
268
269struct Schedule<C: CurveAffine> {
270    buckets: Vec<BucketAffine<C>>,
271    set: [SchedulePoint; BATCH_SIZE],
272    ptr: usize,
273}
274
/// One queued addition: which base goes into which bucket, and with what sign.
#[derive(Debug, Clone, Default)]
struct SchedulePoint {
    /// Index of the base to add.
    base_idx: usize,
    /// Index of the bucket to add it to.
    buck_idx: usize,
    /// `true` to add the base, `false` to subtract it.
    sign: bool,
}

impl SchedulePoint {
    /// Bundles the three fields into a schedule entry.
    fn new(base_idx: usize, buck_idx: usize, sign: bool) -> Self {
        SchedulePoint {
            sign,
            buck_idx,
            base_idx,
        }
    }
}
291
292impl<C: CurveAffine> Schedule<C> {
293    fn new(c: usize) -> Self {
294        let set = (0..BATCH_SIZE)
295            .map(|_| SchedulePoint::default())
296            .collect::<Vec<_>>()
297            .try_into()
298            .unwrap();
299
300        Self {
301            buckets: vec![BucketAffine::None; 1 << (c - 1)],
302            set,
303            ptr: 0,
304        }
305    }
306
307    fn contains(&self, buck_idx: usize) -> bool {
308        self.set.iter().any(|sch| sch.buck_idx == buck_idx)
309    }
310
311    fn execute(&mut self, bases: &[Affine<C>]) {
312        if self.ptr != 0 {
313            batch_add(self.ptr, &mut self.buckets, &self.set, bases);
314            self.ptr = 0;
315            self.set
316                .iter_mut()
317                .for_each(|sch| *sch = SchedulePoint::default());
318        }
319    }
320
321    fn add(&mut self, bases: &[Affine<C>], base_idx: usize, buck_idx: usize, sign: bool) {
322        if !self.buckets[buck_idx].assign(&bases[base_idx], sign) {
323            self.set[self.ptr] = SchedulePoint::new(base_idx, buck_idx, sign);
324            self.ptr += 1;
325        }
326
327        if self.ptr == self.set.len() {
328            self.execute(bases);
329        }
330    }
331}
332
333/// Performs a multi-scalar multiplication operation.
334///
335/// This function will panic if coeffs and bases have a different length.
336pub fn msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
337    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
338
339    let c = if bases.len() < 4 {
340        1
341    } else if bases.len() < 32 {
342        3
343    } else {
344        (f64::from(bases.len() as u32)).ln().ceil() as usize
345    };
346
347    let field_byte_size = C::Scalar::NUM_BITS.div_ceil(8u32) as usize;
348    // OR all coefficients in order to make a mask to figure out the maximum number
349    // of bytes used among all coefficients.
350    let mut acc_or = vec![0; field_byte_size];
351    for coeff in &coeffs {
352        for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
353            *acc_limb |= *limb;
354        }
355    }
356    let max_byte_size = field_byte_size
357        - acc_or
358            .iter()
359            .rev()
360            .position(|v| *v != 0)
361            .unwrap_or(field_byte_size);
362    if max_byte_size == 0 {
363        return;
364    }
365    let number_of_windows = max_byte_size * 8_usize / c + 1;
366
367    for current_window in (0..number_of_windows).rev() {
368        for _ in 0..c {
369            *acc = acc.double();
370        }
371
372        #[derive(Clone, Copy)]
373        enum Bucket<C: CurveAffine> {
374            None,
375            Affine(C),
376            Projective(C::Curve),
377        }
378
379        impl<C: CurveAffine> Bucket<C> {
380            fn add_assign(&mut self, other: &C) {
381                *self = match *self {
382                    Bucket::None => Bucket::Affine(*other),
383                    Bucket::Affine(a) => Bucket::Projective(a + *other),
384                    Bucket::Projective(mut a) => {
385                        a += *other;
386                        Bucket::Projective(a)
387                    }
388                }
389            }
390
391            fn add(self, mut other: C::Curve) -> C::Curve {
392                match self {
393                    Bucket::None => other,
394                    Bucket::Affine(a) => {
395                        other += a;
396                        other
397                    }
398                    Bucket::Projective(a) => other + a,
399                }
400            }
401        }
402
403        let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];
404
405        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
406            let coeff = get_booth_index(current_window, c, coeff.as_ref());
407            if coeff.is_positive() {
408                buckets[coeff as usize - 1].add_assign(base);
409            }
410            if coeff.is_negative() {
411                buckets[coeff.unsigned_abs() as usize - 1].add_assign(&base.neg());
412            }
413        }
414
415        // Summation by parts
416        // e.g. 3a + 2b + 1c = a +
417        //                    (a) + b +
418        //                    ((a) + b) + c
419        let mut running_sum = C::Curve::identity();
420        for exp in buckets.into_iter().rev() {
421            running_sum = exp.add(running_sum);
422            *acc += &running_sum;
423        }
424    }
425}
426
427/// Performs a multi-scalar multiplication operation.
428///
429/// This function will panic if coeffs and bases have a different length.
430///
431/// This will use multithreading if beneficial.
432pub fn msm_parallel<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
433    assert_eq!(coeffs.len(), bases.len());
434
435    let num_threads = rayon::current_num_threads();
436    if coeffs.len() > num_threads {
437        let chunk = coeffs.len() / num_threads;
438        let num_chunks = coeffs.chunks(chunk).len();
439        let mut results = vec![C::Curve::identity(); num_chunks];
440        rayon::scope(|scope| {
441            let chunk = coeffs.len() / num_threads;
442
443            for ((coeffs, bases), acc) in coeffs
444                .chunks(chunk)
445                .zip(bases.chunks(chunk))
446                .zip(results.iter_mut())
447            {
448                scope.spawn(move |_| {
449                    msm_serial(coeffs, bases, acc);
450                });
451            }
452        });
453        results.iter().fold(C::Curve::identity(), |a, b| a + b)
454    } else {
455        let mut acc = C::Curve::identity();
456        msm_serial(coeffs, bases, &mut acc);
457        acc
458    }
459}
460
461/// This function will panic if coeffs and bases have a different length.
462///
463/// This will use multithreading if beneficial.
464pub fn msm_best<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
465    assert_eq!(coeffs.len(), bases.len());
466
467    // TODO: consider adjusting it with empirical data?
468    let c = if bases.len() < 4 {
469        1
470    } else if bases.len() < 32 {
471        3
472    } else {
473        (f64::from(bases.len() as u32)).ln().ceil() as usize
474    };
475
476    if c < 10 {
477        return msm_parallel(coeffs, bases);
478    }
479
480    // coeffs to byte representation
481    let coeffs: Vec<_> = coeffs.par_iter().map(|a| a.to_repr()).collect();
482    // copy bases into `Affine` to skip in on curve check for every access
483    let bases_local: Vec<_> = bases.par_iter().map(Affine::from).collect();
484
485    // number of windows
486    let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
487    // accumumator for each window
488    let mut acc = vec![C::Curve::identity(); number_of_windows];
489    acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| {
490        // jacobian buckets for already scheduled points
491        let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];
492
493        // schedular for affine addition
494        let mut sched = Schedule::new(c);
495
496        for (base_idx, coeff) in coeffs.iter().enumerate() {
497            let buck_idx = get_booth_index(w, c, coeff.as_ref());
498
499            if buck_idx != 0 {
500                // parse bucket index
501                let sign = buck_idx.is_positive();
502                let buck_idx = buck_idx.unsigned_abs() as usize - 1;
503
504                if sched.contains(buck_idx) {
505                    // greedy accumulation
506                    // we use original bases here
507                    j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
508                } else {
509                    // also flushes the schedule if full
510                    sched.add(&bases_local, base_idx, buck_idx, sign);
511                }
512            }
513        }
514
515        // flush the schedule
516        sched.execute(&bases_local);
517
518        // summation by parts
519        // e.g. 3a + 2b + 1c = a +
520        //                    (a) + b +
521        //                    ((a) + b) + c
522        let mut running_sum = C::Curve::identity();
523        for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
524            running_sum += j_buck.add(a_buck);
525            *acc += running_sum;
526        }
527
528        // shift accumulator to the window position
529        for _ in 0..c * w {
530            *acc = acc.double();
531        }
532    });
533    acc.into_iter().sum::<_>()
534}
535
#[cfg(test)]
mod test {
    use std::ops::Neg;

    use ark_std::{end_timer, start_timer};
    use ff::{Field, PrimeField};
    use group::{Curve, Group};
    use rand_core::OsRng;

    use crate::{
        bn256::{Fr, G1Affine, G1},
        CurveAffine,
    };

    /// Checks that a windowed scalar multiplication driven by
    /// `get_booth_index` agrees with the curve's own scalar multiplication
    /// for window sizes 1 through 9.
    #[test]
    fn test_booth_encoding() {
        /// Computes `scalar * point` from Booth digits of width `window`.
        fn mul(scalar: &Fr, point: &G1Affine, window: usize) -> G1Affine {
            let repr = scalar.to_repr();
            let num_windows = Fr::NUM_BITS as usize / window + 1;

            // table[i] = i * point for every possible digit magnitude.
            let table: Vec<_> = (0..=1 << (window - 1))
                .map(|i| point * Fr::from(i as u64))
                .collect();

            let mut acc = G1::identity();
            for i in (0..num_windows).rev() {
                for _ in 0..window {
                    acc = acc.double();
                }

                let digit = super::get_booth_index(i, window, repr.as_ref());
                if digit != 0 {
                    let multiple = table[digit.unsigned_abs() as usize];
                    acc += if digit.is_negative() {
                        multiple.neg()
                    } else {
                        multiple
                    };
                }
            }

            acc.to_affine()
        }

        let (scalars, points): (Vec<_>, Vec<_>) = (0..10)
            .map(|_| (Fr::random(OsRng), G1Affine::random(OsRng)))
            .unzip();

        for window in 1..10 {
            for (scalar, point) in scalars.iter().zip(points.iter()) {
                let expected = (point * scalar).to_affine();
                assert_eq!(mul(scalar, point, window), expected);
            }
        }
    }

    /// Compares `msm_best` against `msm_parallel` on random inputs of size
    /// `2^k` for every `k` in `min_k..=max_k`, timing both.
    fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        let max_len = 1 << max_k;

        // Random points, normalized to affine in one batch.
        let projective = (0..max_len)
            .into_par_iter()
            .map(|_| C::Curve::random(OsRng))
            .collect::<Vec<_>>();
        let mut points = vec![C::identity(); max_len];
        C::Curve::batch_normalize(&projective[..], &mut points[..]);

        let scalars = (0..max_len)
            .into_par_iter()
            .map(|_| C::Scalar::random(OsRng))
            .collect::<Vec<_>>();

        for k in min_k..=max_k {
            let len = 1 << k;
            let (points, scalars) = (&points[..len], &scalars[..len]);

            let t0 = start_timer!(|| format!("cyclone indep k={}", k));
            let e0 = super::msm_best(scalars, points);
            end_timer!(t0);

            let t1 = start_timer!(|| format!("older k={}", k));
            let e1 = super::msm_parallel(scalars, points);
            end_timer!(t1);

            assert_eq!(e0, e1);
        }
    }

    #[test]
    fn test_msm_cross() {
        run_msm_cross::<G1Affine>(14, 18);
    }
}