openvm_ecc_guest/
msm.rs

1use alloc::{vec, vec::Vec};
2use core::ops::{Add, Neg};
3
4use openvm_algebra_guest::IntMod;
5
6use super::Group;
7
8/// Multi-scalar multiplication via Pippenger's algorithm
9// Reference: https://github.com/privacy-scaling-explorations/halo2curves/blob/8771fe5a5d54fc03e74dbc8915db5dad3ab46a83/src/msm.rs#L335
10pub fn msm<EcPoint: Group, Scalar: IntMod>(coeffs: &[Scalar], bases: &[EcPoint]) -> EcPoint
11where
12    for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
13{
14    debug_assert_eq!(
15        coeffs.len(),
16        bases.len(),
17        "msm: coefficients and bases must have the same length"
18    );
19    let coeffs: Vec<_> = coeffs.iter().map(|c| c.as_le_bytes()).collect();
20    let mut acc = EcPoint::IDENTITY;
21
22    // c: window size. Will group scalars into c-bit windows
23    let c = if bases.len() < 4 {
24        1
25    } else if bases.len() < 32 {
26        3
27    } else {
28        // finetune this if needed
29        bases.len().ilog2() as usize
30    };
31
32    let field_byte_size = Scalar::NUM_LIMBS;
33
34    // OR all coefficients in order to make a mask to figure out the maximum number of bytes used
35    // among all coefficients.
36    let mut acc_or = vec![0; field_byte_size];
37    for coeff in &coeffs {
38        for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
39            *acc_limb |= *limb;
40        }
41    }
42    let max_byte_size = field_byte_size
43        - acc_or
44            .iter()
45            .rev()
46            .position(|v| *v != 0)
47            .unwrap_or(field_byte_size);
48    if max_byte_size == 0 {
49        return EcPoint::IDENTITY;
50    }
51    let number_of_windows = max_byte_size * 8_usize / c + 1;
52
53    for current_window in (0..number_of_windows).rev() {
54        for _ in 0..c {
55            acc.double_assign();
56        }
57        #[derive(Clone)]
58        enum Bucket<EcPoint: Group> {
59            None,
60            Affine(EcPoint),
61        }
62
63        impl<EcPoint: Group> Bucket<EcPoint>
64        where
65            for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
66        {
67            fn add_assign(&mut self, other: &EcPoint) {
68                match self {
69                    Bucket::None => {
70                        *self = Bucket::Affine(other.clone());
71                    }
72                    Bucket::Affine(a) => {
73                        a.add_assign(other);
74                    }
75                }
76            }
77
78            fn sub_assign(&mut self, other: &EcPoint) {
79                match self {
80                    Bucket::None => {
81                        *self = Bucket::Affine(other.clone().neg());
82                    }
83                    Bucket::Affine(a) => {
84                        a.sub_assign(other);
85                    }
86                }
87            }
88
89            fn add(self, mut other: EcPoint) -> EcPoint {
90                match self {
91                    Bucket::None => other.clone(),
92                    Bucket::Affine(a) => {
93                        other += a;
94                        other
95                    }
96                }
97            }
98        }
99
100        let mut buckets: Vec<Bucket<EcPoint>> = vec![Bucket::None; 1 << (c - 1)];
101
102        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
103            let coeff = get_booth_index(current_window, c, coeff);
104            if coeff.is_positive() {
105                buckets[coeff as usize - 1].add_assign(base);
106            }
107            if coeff.is_negative() {
108                buckets[coeff.unsigned_abs() as usize - 1].sub_assign(base);
109            }
110        }
111
112        // Summation by parts
113        // e.g. 3a + 2b + 1c = a +
114        //                    (a) + b +
115        //                    ((a) + b) + c
116        let mut running_sum = EcPoint::IDENTITY;
117        for exp in buckets.into_iter().rev() {
118            running_sum = exp.add(running_sum);
119            acc = acc.add(&running_sum);
120        }
121    }
122    acc
123}
124
/// Returns the signed Booth digit of one window of a little-endian scalar.
///
/// * `window_index` - which window to decode; 0 is the least significant one.
/// * `window_size` - number of scalar bits covered by each window.
/// * `el` - scalar bytes, least significant byte first. Bytes past the end read as zero.
///
/// Each window is decoded from `window_size + 1` bits: its own bits plus the top bit of the
/// window below it (a zero bit for window 0). For a window size of 3 the 4-bit slice maps to
/// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`, so digit magnitudes never
/// exceed `2^(window_size - 1)` and scalars need no preprocessing.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Position of the lowest bit of the slice; one bit below the window itself, except for
    // window 0 where there is no lower bit.
    let bit_offset = (window_index * window_size).saturating_sub(1);
    let byte_offset = bit_offset / 8;

    // Pack up to four bytes starting at `byte_offset` into a little-endian u32.
    let mut word = el
        .iter()
        .skip(byte_offset)
        .take(4)
        .enumerate()
        .fold(0u32, |packed, (i, &byte)| packed | (u32::from(byte) << (8 * i)));

    // The least significant window has no predecessor: append a zero bit below it.
    if window_index == 0 {
        word <<= 1;
    }

    // Drop the bits below the slice, then keep exactly `window_size + 1` bits.
    word >>= bit_offset % 8;
    word &= (1 << (window_size + 1)) - 1;

    // The top bit of the slice decides the sign of the digit.
    let is_negative = word & (1 << window_size) != 0;

    // Ceiling division by two folds the overlap bit into the magnitude.
    let half = (word + 1) >> 1;

    if is_negative {
        // Two's-complement negate within `window_size` bits to get the magnitude.
        let low_mask: u32 = (1 << window_size) - 1;
        -((!(half - 1) & low_mask) as i32)
    } else {
        half as i32
    }
}