halo2curves_axiom/
msm.rs

1use std::ops::Neg;
2
3use crate::CurveAffine;
4use ff::Field;
5use ff::PrimeField;
6use group::Group;
7use rayon::iter::{
8    IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
9};
10
/// Number of pending affine additions a `Schedule` queues up before it
/// flushes them through a single `batch_add` call.
const BATCH_SIZE: usize = 64;
12
/// Returns the signed Booth digit of window `window_index` for the
/// little-endian scalar bytes `el`, using windows of `window_size` bits.
///
/// Each digit is read from a `window_size + 1` bit slice that overlaps the
/// previous window by one bit; the least significant window is padded with a
/// zero bit below bit 0. For a window size of 3 the 4-bit slice maps to
/// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`,
/// so digits lie in `-2^(window_size - 1) ..= 2^(window_size - 1)` and only
/// half as many buckets are needed, without pre-processing the scalars.
///
/// Bytes past the end of `el` are read as zero.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Position of the overlap bit: one below the window start, except for
    // window 0 where the overlap bit is the implicit zero padding.
    let bit_offset = (window_index * window_size).saturating_sub(1);
    let byte_offset = bit_offset / 8;

    // Assemble up to four little-endian bytes starting at `byte_offset`.
    let mut word = el
        .iter()
        .skip(byte_offset)
        .take(4)
        .enumerate()
        .fold(0u32, |word, (i, byte)| word | (u32::from(*byte) << (8 * i)));

    // The least significant window gets its zero padding bit here.
    if window_index == 0 {
        word <<= 1;
    }

    // Drop the bits below the slice, then keep `window_size + 1` bits.
    word >>= bit_offset % 8;
    word &= (1 << (window_size + 1)) - 1;

    // The top bit of the slice selects the negative half of the table.
    let negative = word & (1 << window_size) != 0;

    // Ceiling division by two.
    let magnitude = (word + 1) >> 1;

    if negative {
        ((!(magnitude - 1) & ((1 << window_size) - 1)) as i32).neg()
    } else {
        magnitude as i32
    }
}
55
56/// Batch addition.
57fn batch_add<C: CurveAffine>(
58    size: usize,
59    buckets: &mut [BucketAffine<C>],
60    points: &[SchedulePoint],
61    bases: &[Affine<C>],
62) {
63    let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1
64    let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1
65    let mut acc = C::Base::ONE;
66
67    for (
68        (
69            SchedulePoint {
70                base_idx,
71                buck_idx,
72                sign,
73            },
74            t,
75        ),
76        z,
77    ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut())
78    {
79        if buckets[*buck_idx].is_inf() {
80            // We assume bases[*base_idx] != infinity always.
81            continue;
82        }
83
84        if buckets[*buck_idx].x() == bases[*base_idx].x {
85            // y-coordinate matches:
86            //  1. y1 == y2 and sign = false or
87            //  2. y1 != y2 and sign = true
88            //  => ( y1 == y2) xor !sign
89            //  (This uses the fact that x1 == x2 and both points satisfy the curve eq.)
90            if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign {
91                // Doubling
92                let x_squared = bases[*base_idx].x.square();
93                *z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y
94                *t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2
95                acc *= *z;
96                continue;
97            }
98            // P + (-P)
99            buckets[*buck_idx].set_inf();
100            continue;
101        }
102        // Addition
103        *z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1
104        if *sign {
105            *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y);
106        } else {
107            *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y);
108        } // y2 - y1
109        acc *= *z;
110    }
111
112    acc = acc
113        .invert()
114        .expect("Some edge case has not been handled properly");
115
116    for (
117        (
118            SchedulePoint {
119                base_idx,
120                buck_idx,
121                sign,
122            },
123            t,
124        ),
125        z,
126    ) in points.iter().zip(t.iter()).zip(z.iter()).rev()
127    {
128        if buckets[*buck_idx].is_inf() {
129            // We assume bases[*base_idx] != infinity always.
130            continue;
131        }
132        let lambda = acc * t;
133        acc *= z; // update acc
134        let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result
135        if *sign {
136            buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y));
137        } else {
138            buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y));
139        } // y_result = lambda * (x1 - x_result) - y1
140        buckets[*buck_idx].set_x(&x);
141    }
142}
143
/// Plain `(x, y)` coordinate pair of a curve point, stored as base-field
/// elements so they can be read without going through `C::coordinates()`.
#[derive(Debug, Clone, Copy)]
struct Affine<C: CurveAffine> {
    // x-coordinate
    x: C::Base,
    // y-coordinate
    y: C::Base,
}
149
150impl<C: CurveAffine> Affine<C> {
151    fn from(point: &C) -> Self {
152        let coords = point.coordinates().unwrap();
153
154        Self {
155            x: *coords.x(),
156            y: *coords.y(),
157        }
158    }
159
160    fn neg(&self) -> Self {
161        Self {
162            x: self.x,
163            y: -self.y,
164        }
165    }
166
167    fn eval(&self) -> C {
168        C::from_xy(self.x, self.y).unwrap()
169    }
170}
171
/// Bucket accumulator kept in affine coordinates.
#[derive(Debug, Clone)]
enum BucketAffine<C: CurveAffine> {
    // Empty bucket; `is_inf()` reports true for this variant.
    None,
    // Bucket currently holding the given affine coordinates.
    Point(Affine<C>),
}
177
/// Bucket accumulator kept as a `C::Curve` value.
#[derive(Debug, Clone)]
enum Bucket<C: CurveAffine> {
    // Nothing has been added to this bucket yet.
    None,
    // Running sum of the points added so far.
    Point(C::Curve),
}
183
184impl<C: CurveAffine> Bucket<C> {
185    fn add_assign(&mut self, point: &C, sign: bool) {
186        *self = match *self {
187            Bucket::None => Bucket::Point({
188                if sign {
189                    point.to_curve()
190                } else {
191                    point.to_curve().neg()
192                }
193            }),
194            Bucket::Point(a) => {
195                if sign {
196                    Self::Point(a + point)
197                } else {
198                    Self::Point(a - point)
199                }
200            }
201        }
202    }
203
204    fn add(&self, other: &BucketAffine<C>) -> C::Curve {
205        match (self, other) {
206            (Self::Point(this), BucketAffine::Point(other)) => *this + other.eval(),
207            (Self::Point(this), BucketAffine::None) => *this,
208            (Self::None, BucketAffine::Point(other)) => other.eval().to_curve(),
209            (Self::None, BucketAffine::None) => C::Curve::identity(),
210        }
211    }
212}
213
214impl<C: CurveAffine> BucketAffine<C> {
215    fn assign(&mut self, point: &Affine<C>, sign: bool) -> bool {
216        match *self {
217            Self::None => {
218                *self = Self::Point(if sign { *point } else { point.neg() });
219                true
220            }
221            Self::Point(_) => false,
222        }
223    }
224
225    fn x(&self) -> C::Base {
226        match self {
227            Self::None => panic!("::x None"),
228            Self::Point(a) => a.x,
229        }
230    }
231
232    fn y(&self) -> C::Base {
233        match self {
234            Self::None => panic!("::y None"),
235            Self::Point(a) => a.y,
236        }
237    }
238
239    fn is_inf(&self) -> bool {
240        match self {
241            Self::None => true,
242            Self::Point(_) => false,
243        }
244    }
245
246    fn set_x(&mut self, x: &C::Base) {
247        match self {
248            Self::None => panic!("::set_x None"),
249            Self::Point(ref mut a) => a.x = *x,
250        }
251    }
252
253    fn set_y(&mut self, y: &C::Base) {
254        match self {
255            Self::None => panic!("::set_y None"),
256            Self::Point(ref mut a) => a.y = *y,
257        }
258    }
259
260    fn set_inf(&mut self) {
261        match self {
262            Self::None => {}
263            Self::Point(_) => *self = Self::None,
264        }
265    }
266}
267
/// Affine buckets for one window plus a fixed-size queue of additions that
/// are waiting to be applied to them by `batch_add`.
struct Schedule<C: CurveAffine> {
    // One affine bucket per bucket index.
    buckets: Vec<BucketAffine<C>>,
    // Queue of pending additions; holds at most `BATCH_SIZE` entries.
    set: [SchedulePoint; BATCH_SIZE],
    // Number of queued entries, i.e. the next free slot in `set`.
    ptr: usize,
}
273
/// One queued addition: which base goes into which bucket, and with what sign.
#[derive(Debug, Clone, Default)]
struct SchedulePoint {
    /// Index of the base point to add.
    base_idx: usize,
    /// Index of the destination bucket.
    buck_idx: usize,
    /// `true` to add the base, `false` to add its negation.
    sign: bool,
}

impl SchedulePoint {
    /// Bundles the three components of a queued addition.
    fn new(base_idx: usize, buck_idx: usize, sign: bool) -> Self {
        Self {
            sign,
            buck_idx,
            base_idx,
        }
    }
}
290
291impl<C: CurveAffine> Schedule<C> {
292    fn new(c: usize) -> Self {
293        let set = (0..BATCH_SIZE)
294            .map(|_| SchedulePoint::default())
295            .collect::<Vec<_>>()
296            .try_into()
297            .unwrap();
298
299        Self {
300            buckets: vec![BucketAffine::None; 1 << (c - 1)],
301            set,
302            ptr: 0,
303        }
304    }
305
306    fn contains(&self, buck_idx: usize) -> bool {
307        self.set.iter().any(|sch| sch.buck_idx == buck_idx)
308    }
309
310    fn execute(&mut self, bases: &[Affine<C>]) {
311        if self.ptr != 0 {
312            batch_add(self.ptr, &mut self.buckets, &self.set, bases);
313            self.ptr = 0;
314            self.set
315                .iter_mut()
316                .for_each(|sch| *sch = SchedulePoint::default());
317        }
318    }
319
320    fn add(&mut self, bases: &[Affine<C>], base_idx: usize, buck_idx: usize, sign: bool) {
321        if !self.buckets[buck_idx].assign(&bases[base_idx], sign) {
322            self.set[self.ptr] = SchedulePoint::new(base_idx, buck_idx, sign);
323            self.ptr += 1;
324        }
325
326        if self.ptr == self.set.len() {
327            self.execute(bases);
328        }
329    }
330}
331
332/// Performs a multi-scalar multiplication operation.
333///
334/// This function will panic if coeffs and bases have a different length.
335pub fn msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
336    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
337
338    let c = if bases.len() < 4 {
339        1
340    } else if bases.len() < 32 {
341        3
342    } else {
343        (f64::from(bases.len() as u32)).ln().ceil() as usize
344    };
345
346    let field_byte_size = C::Scalar::NUM_BITS.div_ceil(8u32) as usize;
347    // OR all coefficients in order to make a mask to figure out the maximum number of bytes used
348    // among all coefficients.
349    let mut acc_or = vec![0; field_byte_size];
350    for coeff in &coeffs {
351        for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
352            *acc_limb |= *limb;
353        }
354    }
355    let max_byte_size = field_byte_size
356        - acc_or
357            .iter()
358            .rev()
359            .position(|v| *v != 0)
360            .unwrap_or(field_byte_size);
361    if max_byte_size == 0 {
362        return;
363    }
364    let number_of_windows = max_byte_size * 8_usize / c + 1;
365
366    for current_window in (0..number_of_windows).rev() {
367        for _ in 0..c {
368            *acc = acc.double();
369        }
370
371        #[derive(Clone, Copy)]
372        enum Bucket<C: CurveAffine> {
373            None,
374            Affine(C),
375            Projective(C::Curve),
376        }
377
378        impl<C: CurveAffine> Bucket<C> {
379            fn add_assign(&mut self, other: &C) {
380                *self = match *self {
381                    Bucket::None => Bucket::Affine(*other),
382                    Bucket::Affine(a) => Bucket::Projective(a + *other),
383                    Bucket::Projective(mut a) => {
384                        a += *other;
385                        Bucket::Projective(a)
386                    }
387                }
388            }
389
390            fn add(self, mut other: C::Curve) -> C::Curve {
391                match self {
392                    Bucket::None => other,
393                    Bucket::Affine(a) => {
394                        other += a;
395                        other
396                    }
397                    Bucket::Projective(a) => other + a,
398                }
399            }
400        }
401
402        let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];
403
404        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
405            let coeff = get_booth_index(current_window, c, coeff.as_ref());
406            if coeff.is_positive() {
407                buckets[coeff as usize - 1].add_assign(base);
408            }
409            if coeff.is_negative() {
410                buckets[coeff.unsigned_abs() as usize - 1].add_assign(&base.neg());
411            }
412        }
413
414        // Summation by parts
415        // e.g. 3a + 2b + 1c = a +
416        //                    (a) + b +
417        //                    ((a) + b) + c
418        let mut running_sum = C::Curve::identity();
419        for exp in buckets.into_iter().rev() {
420            running_sum = exp.add(running_sum);
421            *acc += &running_sum;
422        }
423    }
424}
425
426/// Performs a multi-scalar multiplication operation.
427///
428/// This function will panic if coeffs and bases have a different length.
429///
430/// This will use multithreading if beneficial.
431pub fn msm_parallel<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
432    assert_eq!(coeffs.len(), bases.len());
433
434    let num_threads = rayon::current_num_threads();
435    if coeffs.len() > num_threads {
436        let chunk = coeffs.len() / num_threads;
437        let num_chunks = coeffs.chunks(chunk).len();
438        let mut results = vec![C::Curve::identity(); num_chunks];
439        rayon::scope(|scope| {
440            let chunk = coeffs.len() / num_threads;
441
442            for ((coeffs, bases), acc) in coeffs
443                .chunks(chunk)
444                .zip(bases.chunks(chunk))
445                .zip(results.iter_mut())
446            {
447                scope.spawn(move |_| {
448                    msm_serial(coeffs, bases, acc);
449                });
450            }
451        });
452        results.iter().fold(C::Curve::identity(), |a, b| a + b)
453    } else {
454        let mut acc = C::Curve::identity();
455        msm_serial(coeffs, bases, &mut acc);
456        acc
457    }
458}
459
460/// This function will panic if coeffs and bases have a different length.
461///
462/// This will use multithreading if beneficial.
463pub fn msm_best<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
464    assert_eq!(coeffs.len(), bases.len());
465
466    // TODO: consider adjusting it with emprical data?
467    let c = if bases.len() < 4 {
468        1
469    } else if bases.len() < 32 {
470        3
471    } else {
472        (f64::from(bases.len() as u32)).ln().ceil() as usize
473    };
474
475    if c < 10 {
476        return msm_parallel(coeffs, bases);
477    }
478
479    // coeffs to byte representation
480    let coeffs: Vec<_> = coeffs.par_iter().map(|a| a.to_repr()).collect();
481    // copy bases into `Affine` to skip in on curve check for every access
482    let bases_local: Vec<_> = bases.par_iter().map(Affine::from).collect();
483
484    // number of windows
485    let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
486    // accumumator for each window
487    let mut acc = vec![C::Curve::identity(); number_of_windows];
488    acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| {
489        // jacobian buckets for already scheduled points
490        let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];
491
492        // schedular for affine addition
493        let mut sched = Schedule::new(c);
494
495        for (base_idx, coeff) in coeffs.iter().enumerate() {
496            let buck_idx = get_booth_index(w, c, coeff.as_ref());
497
498            if buck_idx != 0 {
499                // parse bucket index
500                let sign = buck_idx.is_positive();
501                let buck_idx = buck_idx.unsigned_abs() as usize - 1;
502
503                if sched.contains(buck_idx) {
504                    // greedy accumulation
505                    // we use original bases here
506                    j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
507                } else {
508                    // also flushes the schedule if full
509                    sched.add(&bases_local, base_idx, buck_idx, sign);
510                }
511            }
512        }
513
514        // flush the schedule
515        sched.execute(&bases_local);
516
517        // summation by parts
518        // e.g. 3a + 2b + 1c = a +
519        //                    (a) + b +
520        //                    ((a) + b) + c
521        let mut running_sum = C::Curve::identity();
522        for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
523            running_sum += j_buck.add(a_buck);
524            *acc += running_sum;
525        }
526
527        // shift accumulator to the window position
528        for _ in 0..c * w {
529            *acc = acc.double();
530        }
531    });
532    acc.into_iter().sum::<_>()
533}
534
#[cfg(test)]
mod test {
    use std::ops::Neg;

    use crate::bn256::{Fr, G1Affine, G1};
    use ark_std::{end_timer, start_timer};
    use ff::{Field, PrimeField};
    use group::{Curve, Group};
    use pasta_curves::arithmetic::CurveAffine;
    use rand_core::OsRng;

    /// Checks that the Booth digits produced by `get_booth_index` reproduce
    /// an ordinary scalar multiplication for window sizes 1 through 9.
    #[test]
    fn test_booth_encoding() {
        // Windowed scalar multiplication driven by the signed Booth digits.
        fn mul(scalar: &Fr, point: &G1Affine, window: usize) -> G1Affine {
            let u = scalar.to_repr();
            let n = Fr::NUM_BITS as usize / window + 1;

            // table[i] = i * point for every possible digit magnitude.
            let table = (0..=1 << (window - 1))
                .map(|i| point * Fr::from(i as u64))
                .collect::<Vec<_>>();

            let mut acc = G1::identity();
            // Most significant window first.
            for i in (0..n).rev() {
                for _ in 0..window {
                    acc = acc.double();
                }

                let idx = super::get_booth_index(i, window, u.as_ref());

                if idx.is_negative() {
                    acc += table[idx.unsigned_abs() as usize].neg();
                }
                if idx.is_positive() {
                    acc += table[idx.unsigned_abs() as usize];
                }
            }

            acc.to_affine()
        }

        // Ten random (scalar, point) pairs.
        let (scalars, points): (Vec<_>, Vec<_>) = (0..10)
            .map(|_| {
                let scalar = Fr::random(OsRng);
                let point = G1Affine::random(OsRng);
                (scalar, point)
            })
            .unzip();

        for window in 1..10 {
            for (scalar, point) in scalars.iter().zip(points.iter()) {
                let c0 = mul(scalar, point, window);
                let c1 = point * scalar;
                assert_eq!(c0, c1.to_affine());
            }
        }
    }

    /// Compares `msm_best` against `msm_parallel` on random inputs of size
    /// `2^k` for every `k` in `min_k..=max_k`, timing both.
    fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        // Sample 2^max_k random points and convert them to affine in one batch.
        let points = (0..1 << max_k)
            .into_par_iter()
            .map(|_| C::Curve::random(OsRng))
            .collect::<Vec<_>>();
        let mut affine_points = vec![C::identity(); 1 << max_k];
        C::Curve::batch_normalize(&points[..], &mut affine_points[..]);
        let points = affine_points;

        let scalars = (0..1 << max_k)
            .into_par_iter()
            .map(|_| C::Scalar::random(OsRng))
            .collect::<Vec<_>>();

        for k in min_k..=max_k {
            // Each size reuses a prefix of the same random data.
            let points = &points[..1 << k];
            let scalars = &scalars[..1 << k];

            let t0 = start_timer!(|| format!("cyclone indep k={}", k));
            let e0 = super::msm_best(scalars, points);
            end_timer!(t0);

            let t1 = start_timer!(|| format!("older k={}", k));
            let e1 = super::msm_parallel(scalars, points);
            end_timer!(t1);
            assert_eq!(e0, e1);
        }
    }

    /// Cross-checks both implementations on bn256 G1 for sizes 2^14 to 2^22.
    #[test]
    fn test_msm_cross() {
        run_msm_cross::<G1Affine>(14, 22);
    }
}