openvm_ecc_guest/
msm.rs

1use alloc::{vec, vec::Vec};
2use core::ops::{Add, Neg};
3
4use openvm_algebra_guest::IntMod;
5
6use super::Group;
7
8/// Multi-scalar multiplication via Pippenger's algorithm
9// Reference: https://github.com/privacy-scaling-explorations/halo2curves/blob/8771fe5a5d54fc03e74dbc8915db5dad3ab46a83/src/msm.rs#L335
10pub fn msm<EcPoint: Group, Scalar: IntMod>(coeffs: &[Scalar], bases: &[EcPoint]) -> EcPoint
11where
12    for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
13{
14    let coeffs: Vec<_> = coeffs.iter().map(|c| c.as_le_bytes()).collect();
15    let mut acc = EcPoint::IDENTITY;
16
17    // c: window size. Will group scalars into c-bit windows
18    let c = if bases.len() < 4 {
19        1
20    } else if bases.len() < 32 {
21        3
22    } else {
23        // finetune this if needed
24        bases.len().ilog2() as usize
25    };
26
27    let field_byte_size = Scalar::NUM_LIMBS;
28
29    // OR all coefficients in order to make a mask to figure out the maximum number of bytes used
30    // among all coefficients.
31    let mut acc_or = vec![0; field_byte_size];
32    for coeff in &coeffs {
33        for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
34            *acc_limb |= *limb;
35        }
36    }
37    let max_byte_size = field_byte_size
38        - acc_or
39            .iter()
40            .rev()
41            .position(|v| *v != 0)
42            .unwrap_or(field_byte_size);
43    if max_byte_size == 0 {
44        return EcPoint::IDENTITY;
45    }
46    let number_of_windows = max_byte_size * 8_usize / c + 1;
47
48    for current_window in (0..number_of_windows).rev() {
49        for _ in 0..c {
50            acc.double_assign();
51        }
52        #[derive(Clone)]
53        enum Bucket<EcPoint: Group> {
54            None,
55            Affine(EcPoint),
56        }
57
58        impl<EcPoint: Group> Bucket<EcPoint>
59        where
60            for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
61        {
62            fn add_assign(&mut self, other: &EcPoint) {
63                match self {
64                    Bucket::None => {
65                        *self = Bucket::Affine(other.clone());
66                    }
67                    Bucket::Affine(a) => {
68                        a.add_assign(other);
69                    }
70                }
71            }
72
73            fn sub_assign(&mut self, other: &EcPoint) {
74                match self {
75                    Bucket::None => {
76                        *self = Bucket::Affine(other.clone().neg());
77                    }
78                    Bucket::Affine(a) => {
79                        a.sub_assign(other);
80                    }
81                }
82            }
83
84            fn add(self, mut other: EcPoint) -> EcPoint {
85                match self {
86                    Bucket::None => other.clone(),
87                    Bucket::Affine(a) => {
88                        other += a;
89                        other
90                    }
91                }
92            }
93        }
94
95        let mut buckets: Vec<Bucket<EcPoint>> = vec![Bucket::None; 1 << (c - 1)];
96
97        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
98            let coeff = get_booth_index(current_window, c, coeff);
99            if coeff.is_positive() {
100                buckets[coeff as usize - 1].add_assign(base);
101            }
102            if coeff.is_negative() {
103                buckets[coeff.unsigned_abs() as usize - 1].sub_assign(base);
104            }
105        }
106
107        // Summation by parts
108        // e.g. 3a + 2b + 1c = a +
109        //                    (a) + b +
110        //                    ((a) + b) + c
111        let mut running_sum = EcPoint::IDENTITY;
112        for exp in buckets.into_iter().rev() {
113            running_sum = exp.add(running_sum);
114            acc = acc.add(&running_sum);
115        }
116    }
117    acc
118}
119
/// Returns the signed Booth digit of window `window_index` of the little-endian scalar `el`.
///
/// Booth encoding:
/// * windows step by `window_size` bits;
/// * each window is sliced as `window_size + 1` bits, so consecutive windows overlap by 1 bit;
/// * a zero bit is appended below the least significant end (this only affects window 0).
///
/// Mapping from the sliced bits to the digit, for `window_size = 3` (4-bit slices
/// `0b0000..=0b1111`):
/// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`
///
/// Every digit lies in `-2^(window_size-1)..=2^(window_size-1)`, so only `2^(window_size-1)`
/// buckets are needed. The digits are extracted on the fly, without first converting the
/// scalars into a classic signed-digit representation. Summing `digit_i * 2^(i * window_size)`
/// over `i = 0..=(8 * el.len()) / window_size` reconstructs the scalar. Bytes past the end of
/// `el` are treated as zero.
///
/// # Panics
/// In debug builds, if `window_size` is not in `1..=24`. The window is read through a 4-byte
/// buffer, which only holds `window_size + 1` bits after an intra-byte shift of up to 7 bits
/// when `window_size <= 24`. Larger values would produce wrong digits.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    debug_assert!(
        (1..=24).contains(&window_size),
        "Booth window size must be in 1..=24"
    );

    // Index of the lowest sliced bit: window i covers scalar bits i*w - 1 ..= i*w + w - 1.
    // For window 0, "bit -1" is the appended zero, produced by the left shift below.
    let skip_bits = (window_index * window_size).saturating_sub(1);
    let skip_bytes = skip_bits / 8;

    // Load up to 4 bytes starting at the byte holding the lowest sliced bit; bytes beyond the
    // end of `el` stay zero.
    let mut buf = [0u8; 4];
    for (dst, src) in buf.iter_mut().zip(el.iter().skip(skip_bytes)) {
        *dst = *src;
    }
    let mut window = u32::from_le_bytes(buf);

    if window_index == 0 {
        // Append the implicit zero bit below the least significant bit.
        window <<= 1;
    }

    // Drop the bits below the window, then keep exactly `window_size + 1` bits.
    window >>= skip_bits % 8;
    window &= (1u32 << (window_size + 1)) - 1;

    // The top sliced bit carries the (negative) weight -2^window_size.
    let is_negative = window & (1 << window_size) != 0;

    // Divide by 2, rounding up.
    let magnitude = (window + 1) >> 1;

    if is_negative {
        // Here 2^(window_size-1) <= magnitude <= 2^window_size. The masked complement equals
        // (2^window_size - magnitude) mod 2^window_size, so an all-ones slice maps to 0.
        -((!(magnitude - 1) & ((1 << window_size) - 1)) as i32)
    } else {
        magnitude as i32
    }
}