openvm_ecc_guest/
msm.rs
1use alloc::{vec, vec::Vec};
2use core::ops::{Add, Neg};
3
4use openvm_algebra_guest::IntMod;
5
6use super::Group;
7
8pub fn msm<EcPoint: Group, Scalar: IntMod>(coeffs: &[Scalar], bases: &[EcPoint]) -> EcPoint
11where
12 for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
13{
14 let coeffs: Vec<_> = coeffs.iter().map(|c| c.as_le_bytes()).collect();
15 let mut acc = EcPoint::IDENTITY;
16
17 let c = if bases.len() < 4 {
19 1
20 } else if bases.len() < 32 {
21 3
22 } else {
23 bases.len().ilog2() as usize
25 };
26
27 let field_byte_size = Scalar::NUM_LIMBS;
28
29 let mut acc_or = vec![0; field_byte_size];
32 for coeff in &coeffs {
33 for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
34 *acc_limb |= *limb;
35 }
36 }
37 let max_byte_size = field_byte_size
38 - acc_or
39 .iter()
40 .rev()
41 .position(|v| *v != 0)
42 .unwrap_or(field_byte_size);
43 if max_byte_size == 0 {
44 return EcPoint::IDENTITY;
45 }
46 let number_of_windows = max_byte_size * 8_usize / c + 1;
47
48 for current_window in (0..number_of_windows).rev() {
49 for _ in 0..c {
50 acc.double_assign();
51 }
52 #[derive(Clone)]
53 enum Bucket<EcPoint: Group> {
54 None,
55 Affine(EcPoint),
56 }
57
58 impl<EcPoint: Group> Bucket<EcPoint>
59 where
60 for<'a> &'a EcPoint: Add<&'a EcPoint, Output = EcPoint>,
61 {
62 fn add_assign(&mut self, other: &EcPoint) {
63 match self {
64 Bucket::None => {
65 *self = Bucket::Affine(other.clone());
66 }
67 Bucket::Affine(a) => {
68 a.add_assign(other);
69 }
70 }
71 }
72
73 fn sub_assign(&mut self, other: &EcPoint) {
74 match self {
75 Bucket::None => {
76 *self = Bucket::Affine(other.clone().neg());
77 }
78 Bucket::Affine(a) => {
79 a.sub_assign(other);
80 }
81 }
82 }
83
84 fn add(self, mut other: EcPoint) -> EcPoint {
85 match self {
86 Bucket::None => other.clone(),
87 Bucket::Affine(a) => {
88 other += a;
89 other
90 }
91 }
92 }
93 }
94
95 let mut buckets: Vec<Bucket<EcPoint>> = vec![Bucket::None; 1 << (c - 1)];
96
97 for (coeff, base) in coeffs.iter().zip(bases.iter()) {
98 let coeff = get_booth_index(current_window, c, coeff);
99 if coeff.is_positive() {
100 buckets[coeff as usize - 1].add_assign(base);
101 }
102 if coeff.is_negative() {
103 buckets[coeff.unsigned_abs() as usize - 1].sub_assign(base);
104 }
105 }
106
107 let mut running_sum = EcPoint::IDENTITY;
112 for exp in buckets.into_iter().rev() {
113 running_sum = exp.add(running_sum);
114 acc = acc.add(&running_sum);
115 }
116 }
117 acc
118}
119
/// Returns the signed Booth digit of the little-endian scalar `el` for the
/// window `window_index` of width `window_size` bits.
///
/// Each window spans `window_size + 1` bits: its own `window_size` bits plus
/// the top bit of the window below it. The lowest window has no bit below, so
/// a zero bit is appended instead. The resulting digit lies in
/// `-2^(window_size - 1) ..= 2^(window_size - 1)`, and the digits weighted by
/// `2^(window_index * window_size)` sum to the scalar.
///
/// Bytes past the end of `el` are treated as zero. Only four bytes are loaded,
/// so the result is meaningful only while the in-byte shift (at most 7) plus
/// `window_size + 1` bits fit in 32 bits.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Start one bit below the window so that neighbouring windows overlap.
    let bit_offset = (window_index * window_size).saturating_sub(1);
    let byte_offset = bit_offset / 8;
    let shift_in_byte = bit_offset % 8;

    // Assemble up to four bytes, little-endian, starting at `byte_offset`.
    let word = el
        .iter()
        .skip(byte_offset)
        .take(4)
        .enumerate()
        .fold(0u32, |word, (i, &byte)| word | (u32::from(byte) << (8 * i)));

    // The least significant window gets a zero appended below bit 0.
    let padded = if window_index == 0 { word << 1 } else { word };

    let window_mask: u32 = (1 << (window_size + 1)) - 1;
    let sign_bit: u32 = 1 << window_size;

    // Drop the bits below the window, then keep `window_size + 1` bits.
    let window = (padded >> shift_in_byte) & window_mask;

    // Ceiling division by two.
    let magnitude = (window + 1) >> 1;

    if window & sign_bit == 0 {
        magnitude as i32
    } else {
        // Top bit set: the digit is negative; take the two's complement
        // within `window_size` bits.
        let wrapped = !(magnitude - 1) & (sign_bit - 1);
        (wrapped as i32).neg()
    }
}