1use crate::{
4 loader::{LoadedEcPoint, Loader},
5 util::{
6 arithmetic::{CurveAffine, Group, PrimeField},
7 Itertools,
8 },
9};
10use num_integer::Integer;
11use std::{
12 default::Default,
13 iter::{self, Sum},
14 mem::size_of,
15 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
16};
17
18#[derive(Clone, Debug)]
19pub struct Msm<'a, C: CurveAffine, L: Loader<C>> {
21 constant: Option<L::LoadedScalar>,
22 scalars: Vec<L::LoadedScalar>,
23 bases: Vec<&'a L::LoadedEcPoint>,
24}
25
26impl<C, L> Default for Msm<'_, C, L>
27where
28 C: CurveAffine,
29 L: Loader<C>,
30{
31 fn default() -> Self {
32 Self { constant: None, scalars: Vec::new(), bases: Vec::new() }
33 }
34}
35
36impl<'a, C, L> Msm<'a, C, L>
37where
38 C: CurveAffine,
39 L: Loader<C>,
40{
41 pub fn constant(constant: L::LoadedScalar) -> Self {
43 Msm { constant: Some(constant), ..Default::default() }
44 }
45
46 pub fn base<'b: 'a>(base: &'b L::LoadedEcPoint) -> Self {
48 let one = base.loader().load_one();
49 Msm { scalars: vec![one], bases: vec![base], ..Default::default() }
50 }
51
52 pub(crate) fn size(&self) -> usize {
53 self.bases.len()
54 }
55
56 pub(crate) fn split(mut self) -> (Self, Option<L::LoadedScalar>) {
57 let constant = self.constant.take();
58 (self, constant)
59 }
60
61 pub(crate) fn try_into_constant(self) -> Option<L::LoadedScalar> {
62 self.bases.is_empty().then(|| self.constant.unwrap())
63 }
64
65 pub fn evaluate(self, gen: Option<C>) -> L::LoadedEcPoint {
71 let gen = gen.map(|gen| self.bases.first().unwrap().loader().ec_point_load_const(&gen));
72 let pairs = iter::empty()
73 .chain(self.constant.as_ref().map(|constant| (constant, gen.as_ref().unwrap())))
74 .chain(self.scalars.iter().zip(self.bases))
75 .collect_vec();
76 L::multi_scalar_multiplication(&pairs)
77 }
78
79 fn scale(&mut self, factor: &L::LoadedScalar) {
80 if let Some(constant) = self.constant.as_mut() {
81 *constant *= factor;
82 }
83 for scalar in self.scalars.iter_mut() {
84 *scalar *= factor
85 }
86 }
87
88 fn push<'b: 'a>(&mut self, scalar: L::LoadedScalar, base: &'b L::LoadedEcPoint) {
89 if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) {
90 self.scalars[pos] += &scalar;
91 } else {
92 self.scalars.push(scalar);
93 self.bases.push(base);
94 }
95 }
96
97 fn extend<'b: 'a>(&mut self, mut other: Msm<'b, C, L>) {
98 match (self.constant.as_mut(), other.constant.as_ref()) {
99 (Some(lhs), Some(rhs)) => *lhs += rhs,
100 (None, Some(_)) => self.constant = other.constant.take(),
101 _ => {}
102 };
103 for (scalar, base) in other.scalars.into_iter().zip(other.bases) {
104 self.push(scalar, base);
105 }
106 }
107}
108
109impl<'a, 'b, C, L> Add<Msm<'b, C, L>> for Msm<'a, C, L>
110where
111 'b: 'a,
112 C: CurveAffine,
113 L: Loader<C>,
114{
115 type Output = Msm<'a, C, L>;
116
117 fn add(mut self, rhs: Msm<'b, C, L>) -> Self::Output {
118 self.extend(rhs);
119 self
120 }
121}
122
123impl<'a, 'b, C, L> AddAssign<Msm<'b, C, L>> for Msm<'a, C, L>
124where
125 'b: 'a,
126 C: CurveAffine,
127 L: Loader<C>,
128{
129 fn add_assign(&mut self, rhs: Msm<'b, C, L>) {
130 self.extend(rhs);
131 }
132}
133
134impl<'a, 'b, C, L> Sub<Msm<'b, C, L>> for Msm<'a, C, L>
135where
136 'b: 'a,
137 C: CurveAffine,
138 L: Loader<C>,
139{
140 type Output = Msm<'a, C, L>;
141
142 fn sub(mut self, rhs: Msm<'b, C, L>) -> Self::Output {
143 self.extend(-rhs);
144 self
145 }
146}
147
148impl<'a, 'b, C, L> SubAssign<Msm<'b, C, L>> for Msm<'a, C, L>
149where
150 'b: 'a,
151 C: CurveAffine,
152 L: Loader<C>,
153{
154 fn sub_assign(&mut self, rhs: Msm<'b, C, L>) {
155 self.extend(-rhs);
156 }
157}
158
159impl<'a, C, L> Mul<&L::LoadedScalar> for Msm<'a, C, L>
160where
161 C: CurveAffine,
162 L: Loader<C>,
163{
164 type Output = Msm<'a, C, L>;
165
166 fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output {
167 self.scale(rhs);
168 self
169 }
170}
171
172impl<C, L> MulAssign<&L::LoadedScalar> for Msm<'_, C, L>
173where
174 C: CurveAffine,
175 L: Loader<C>,
176{
177 fn mul_assign(&mut self, rhs: &L::LoadedScalar) {
178 self.scale(rhs);
179 }
180}
181
182impl<'a, C, L> Neg for Msm<'a, C, L>
183where
184 C: CurveAffine,
185 L: Loader<C>,
186{
187 type Output = Msm<'a, C, L>;
188 fn neg(mut self) -> Msm<'a, C, L> {
189 self.constant = self.constant.map(|constant| -constant);
190 for scalar in self.scalars.iter_mut() {
191 *scalar = -scalar.clone();
192 }
193 self
194 }
195}
196
197impl<C, L> Sum for Msm<'_, C, L>
198where
199 C: CurveAffine,
200 L: Loader<C>,
201{
202 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
203 iter.reduce(|acc, item| acc + item).unwrap_or_default()
204 }
205}
206
207#[derive(Clone, Copy)]
208enum Bucket<C: CurveAffine> {
209 None,
210 Affine(C),
211 Projective(C::Curve),
212}
213
214impl<C: CurveAffine> Bucket<C> {
215 fn add_assign(&mut self, rhs: &C) {
216 *self = match *self {
217 Bucket::None => Bucket::Affine(*rhs),
218 Bucket::Affine(lhs) => Bucket::Projective(lhs + *rhs),
219 Bucket::Projective(mut lhs) => {
220 lhs += *rhs;
221 Bucket::Projective(lhs)
222 }
223 }
224 }
225
226 fn add(self, mut rhs: C::Curve) -> C::Curve {
227 match self {
228 Bucket::None => rhs,
229 Bucket::Affine(lhs) => {
230 rhs += lhs;
231 rhs
232 }
233 Bucket::Projective(lhs) => lhs + rhs,
234 }
235 }
236}
237
238fn multi_scalar_multiplication_serial<C: CurveAffine>(
239 scalars: &[C::Scalar],
240 bases: &[C],
241 result: &mut C::Curve,
242) {
243 let scalars = scalars.iter().map(|scalar| scalar.to_repr()).collect_vec();
244 let num_bytes = scalars[0].as_ref().len();
245 let num_bits = 8 * num_bytes;
246
247 let window_size = (scalars.len() as f64).ln().ceil() as usize + 2;
248 let num_buckets = (1 << window_size) - 1;
249
250 let windowed_scalar = |idx: usize, bytes: &<C::Scalar as PrimeField>::Repr| {
251 let skip_bits = idx * window_size;
252 let skip_bytes = skip_bits / 8;
253
254 let mut value = [0; size_of::<usize>()];
255 for (dst, src) in value.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
256 *dst = *src;
257 }
258
259 (usize::from_le_bytes(value) >> (skip_bits - (skip_bytes * 8))) & num_buckets
260 };
261
262 let num_window = Integer::div_ceil(&num_bits, &window_size);
263 for idx in (0..num_window).rev() {
264 for _ in 0..window_size {
265 *result = result.double();
266 }
267
268 let mut buckets = vec![Bucket::None; num_buckets];
269
270 for (scalar, base) in scalars.iter().zip(bases.iter()) {
271 let scalar = windowed_scalar(idx, scalar);
272 if scalar != 0 {
273 buckets[scalar - 1].add_assign(base);
274 }
275 }
276
277 let mut running_sum = C::Curve::identity();
278 for bucket in buckets.into_iter().rev() {
279 running_sum = bucket.add(running_sum);
280 *result += &running_sum;
281 }
282 }
283}
284
285pub fn multi_scalar_multiplication<C: CurveAffine>(scalars: &[C::Scalar], bases: &[C]) -> C::Curve {
288 assert_eq!(scalars.len(), bases.len());
289
290 #[cfg(feature = "parallel")]
291 {
292 use crate::util::{current_num_threads, parallelize_iter};
293
294 let num_threads = current_num_threads();
295 if scalars.len() < num_threads {
296 let mut result = C::Curve::identity();
297 multi_scalar_multiplication_serial(scalars, bases, &mut result);
298 return result;
299 }
300
301 let chunk_size = Integer::div_ceil(&scalars.len(), &num_threads);
302 let mut results = vec![C::Curve::identity(); num_threads];
303 parallelize_iter(
304 scalars.chunks(chunk_size).zip(bases.chunks(chunk_size)).zip(results.iter_mut()),
305 |((scalars, bases), result)| {
306 multi_scalar_multiplication_serial(scalars, bases, result);
307 },
308 );
309 results.iter().fold(C::Curve::identity(), |acc, result| acc + result)
310 }
311 #[cfg(not(feature = "parallel"))]
312 {
313 let mut result = C::Curve::identity();
314 multi_scalar_multiplication_serial(scalars, bases, &mut result);
315 result
316 }
317}