1use alloc::{vec, vec::Vec};
2use core::ops::{Add, Neg};
3
4use openvm_algebra_guest::IntMod;
5
6use super::Group;
7
8pub fn msm<EcPoint: Group, Scalar: IntMod>(coeffs: &[Scalar], bases: &[EcPoint]) -> EcPoint
11where
12 for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
13{
14 debug_assert_eq!(
15 coeffs.len(),
16 bases.len(),
17 "msm: coefficients and bases must have the same length"
18 );
19 let coeffs: Vec<_> = coeffs.iter().map(|c| c.as_le_bytes()).collect();
20 let mut acc = EcPoint::IDENTITY;
21
22 let c = if bases.len() < 4 {
24 1
25 } else if bases.len() < 32 {
26 3
27 } else {
28 bases.len().ilog2() as usize
30 };
31
32 let field_byte_size = Scalar::NUM_LIMBS;
33
34 let mut acc_or = vec![0; field_byte_size];
37 for coeff in &coeffs {
38 for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
39 *acc_limb |= *limb;
40 }
41 }
42 let max_byte_size = field_byte_size
43 - acc_or
44 .iter()
45 .rev()
46 .position(|v| *v != 0)
47 .unwrap_or(field_byte_size);
48 if max_byte_size == 0 {
49 return EcPoint::IDENTITY;
50 }
51 let number_of_windows = max_byte_size * 8_usize / c + 1;
52
53 for current_window in (0..number_of_windows).rev() {
54 for _ in 0..c {
55 acc.double_assign();
56 }
57 #[derive(Clone)]
58 enum Bucket<EcPoint: Group> {
59 None,
60 Affine(EcPoint),
61 }
62
63 impl<EcPoint: Group> Bucket<EcPoint>
64 where
65 for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
66 {
67 fn add_assign(&mut self, other: &EcPoint) {
68 match self {
69 Bucket::None => {
70 *self = Bucket::Affine(other.clone());
71 }
72 Bucket::Affine(a) => {
73 a.add_assign(other);
74 }
75 }
76 }
77
78 fn sub_assign(&mut self, other: &EcPoint) {
79 match self {
80 Bucket::None => {
81 *self = Bucket::Affine(other.clone().neg());
82 }
83 Bucket::Affine(a) => {
84 a.sub_assign(other);
85 }
86 }
87 }
88
89 fn add(self, mut other: EcPoint) -> EcPoint {
90 match self {
91 Bucket::None => other.clone(),
92 Bucket::Affine(a) => {
93 other += a;
94 other
95 }
96 }
97 }
98 }
99
100 let mut buckets: Vec<Bucket<EcPoint>> = vec![Bucket::None; 1 << (c - 1)];
101
102 for (coeff, base) in coeffs.iter().zip(bases.iter()) {
103 let coeff = get_booth_index(current_window, c, coeff);
104 if coeff.is_positive() {
105 buckets[coeff as usize - 1].add_assign(base);
106 }
107 if coeff.is_negative() {
108 buckets[coeff.unsigned_abs() as usize - 1].sub_assign(base);
109 }
110 }
111
112 let mut running_sum = EcPoint::IDENTITY;
117 for exp in buckets.into_iter().rev() {
118 running_sum = exp.add(running_sum);
119 acc = acc.add(&running_sum);
120 }
121 }
122 acc
123}
124
/// Returns the signed Booth digit of `el` for window number `window_index`.
///
/// `el` is read as a little-endian integer.
/// - Window `w` spans the `window_size + 1` bits from bit `w * window_size - 1`
///   up to bit `w * window_size + window_size - 1`, so neighbouring windows share
///   one bit.
/// - For window 0, the bit below bit 0 is taken to be zero.
/// - Bytes beyond the end of `el` read as zero.
///
/// The digit lies in `[-2^(window_size - 1), 2^(window_size - 1)]`. A caller can
/// therefore index half as many buckets as an unsigned window would need, without
/// re-encoding the scalars first. Summing `digit_w * 2^(w * window_size)` over
/// enough windows reconstructs the integer.
///
/// Only four bytes are decoded per call, and the window may begin up to 7 bits
/// into the first byte. `window_size` must therefore be at most 24.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Lowest bit of this window (the shared overlap bit), clamped to 0 for window 0.
    let start_bit = (window_index * window_size).saturating_sub(1);
    let first_byte = start_bit / 8;
    let bit_in_byte = start_bit % 8;

    // Load up to four bytes beginning at `first_byte`; missing bytes stay zero.
    let mut word = [0u8; 4];
    el.iter()
        .skip(first_byte)
        .zip(word.iter_mut())
        .for_each(|(byte, slot)| *slot = *byte);
    let mut bits = u32::from_le_bytes(word);

    // Window 0 has no lower neighbour: slide in an implicit zero bit at the bottom.
    if window_index == 0 {
        bits <<= 1;
    }

    // Keep exactly the `window_size + 1` bits that belong to this window.
    let window = (bits >> bit_in_byte) & ((1 << (window_size + 1)) - 1);
    // The top bit of the window selects a negative digit.
    let is_negative = (window & (1 << window_size)) != 0;
    // Halve, rounding up, to fold each pair of adjacent patterns onto one digit.
    let magnitude = (window + 1) >> 1;

    if is_negative {
        let low_mask: u32 = (1 << window_size) - 1;
        -((!(magnitude - 1) & low_mask) as i32)
    } else {
        magnitude as i32
    }
}