snark_verifier/util/
msm.rs

1//! Multi-scalar multiplication algorithm.
2
3use crate::{
4    loader::{LoadedEcPoint, Loader},
5    util::{
6        arithmetic::{CurveAffine, Group, PrimeField},
7        Itertools,
8    },
9};
10use num_integer::Integer;
11use std::{
12    default::Default,
13    iter::{self, Sum},
14    mem::size_of,
15    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
16};
17
18#[derive(Clone, Debug)]
19/// Contains unevaluated multi-scalar multiplication.
20pub struct Msm<'a, C: CurveAffine, L: Loader<C>> {
21    constant: Option<L::LoadedScalar>,
22    scalars: Vec<L::LoadedScalar>,
23    bases: Vec<&'a L::LoadedEcPoint>,
24}
25
26impl<C, L> Default for Msm<'_, C, L>
27where
28    C: CurveAffine,
29    L: Loader<C>,
30{
31    fn default() -> Self {
32        Self { constant: None, scalars: Vec::new(), bases: Vec::new() }
33    }
34}
35
36impl<'a, C, L> Msm<'a, C, L>
37where
38    C: CurveAffine,
39    L: Loader<C>,
40{
41    /// Initialize with a constant.
42    pub fn constant(constant: L::LoadedScalar) -> Self {
43        Msm { constant: Some(constant), ..Default::default() }
44    }
45
46    /// Initialize with a base.
47    pub fn base<'b: 'a>(base: &'b L::LoadedEcPoint) -> Self {
48        let one = base.loader().load_one();
49        Msm { scalars: vec![one], bases: vec![base], ..Default::default() }
50    }
51
52    pub(crate) fn size(&self) -> usize {
53        self.bases.len()
54    }
55
56    pub(crate) fn split(mut self) -> (Self, Option<L::LoadedScalar>) {
57        let constant = self.constant.take();
58        (self, constant)
59    }
60
61    pub(crate) fn try_into_constant(self) -> Option<L::LoadedScalar> {
62        self.bases.is_empty().then(|| self.constant.unwrap())
63    }
64
65    /// Evaluate multi-scalar multiplication.
66    ///
67    /// # Panic
68    ///
69    /// If given `gen` is `None` but there `constant` has some value.
70    pub fn evaluate(self, gen: Option<C>) -> L::LoadedEcPoint {
71        let gen = gen.map(|gen| self.bases.first().unwrap().loader().ec_point_load_const(&gen));
72        let pairs = iter::empty()
73            .chain(self.constant.as_ref().map(|constant| (constant, gen.as_ref().unwrap())))
74            .chain(self.scalars.iter().zip(self.bases))
75            .collect_vec();
76        L::multi_scalar_multiplication(&pairs)
77    }
78
79    fn scale(&mut self, factor: &L::LoadedScalar) {
80        if let Some(constant) = self.constant.as_mut() {
81            *constant *= factor;
82        }
83        for scalar in self.scalars.iter_mut() {
84            *scalar *= factor
85        }
86    }
87
88    fn push<'b: 'a>(&mut self, scalar: L::LoadedScalar, base: &'b L::LoadedEcPoint) {
89        if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) {
90            self.scalars[pos] += &scalar;
91        } else {
92            self.scalars.push(scalar);
93            self.bases.push(base);
94        }
95    }
96
97    fn extend<'b: 'a>(&mut self, mut other: Msm<'b, C, L>) {
98        match (self.constant.as_mut(), other.constant.as_ref()) {
99            (Some(lhs), Some(rhs)) => *lhs += rhs,
100            (None, Some(_)) => self.constant = other.constant.take(),
101            _ => {}
102        };
103        for (scalar, base) in other.scalars.into_iter().zip(other.bases) {
104            self.push(scalar, base);
105        }
106    }
107}
108
109impl<'a, 'b, C, L> Add<Msm<'b, C, L>> for Msm<'a, C, L>
110where
111    'b: 'a,
112    C: CurveAffine,
113    L: Loader<C>,
114{
115    type Output = Msm<'a, C, L>;
116
117    fn add(mut self, rhs: Msm<'b, C, L>) -> Self::Output {
118        self.extend(rhs);
119        self
120    }
121}
122
123impl<'a, 'b, C, L> AddAssign<Msm<'b, C, L>> for Msm<'a, C, L>
124where
125    'b: 'a,
126    C: CurveAffine,
127    L: Loader<C>,
128{
129    fn add_assign(&mut self, rhs: Msm<'b, C, L>) {
130        self.extend(rhs);
131    }
132}
133
134impl<'a, 'b, C, L> Sub<Msm<'b, C, L>> for Msm<'a, C, L>
135where
136    'b: 'a,
137    C: CurveAffine,
138    L: Loader<C>,
139{
140    type Output = Msm<'a, C, L>;
141
142    fn sub(mut self, rhs: Msm<'b, C, L>) -> Self::Output {
143        self.extend(-rhs);
144        self
145    }
146}
147
148impl<'a, 'b, C, L> SubAssign<Msm<'b, C, L>> for Msm<'a, C, L>
149where
150    'b: 'a,
151    C: CurveAffine,
152    L: Loader<C>,
153{
154    fn sub_assign(&mut self, rhs: Msm<'b, C, L>) {
155        self.extend(-rhs);
156    }
157}
158
159impl<'a, C, L> Mul<&L::LoadedScalar> for Msm<'a, C, L>
160where
161    C: CurveAffine,
162    L: Loader<C>,
163{
164    type Output = Msm<'a, C, L>;
165
166    fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output {
167        self.scale(rhs);
168        self
169    }
170}
171
172impl<C, L> MulAssign<&L::LoadedScalar> for Msm<'_, C, L>
173where
174    C: CurveAffine,
175    L: Loader<C>,
176{
177    fn mul_assign(&mut self, rhs: &L::LoadedScalar) {
178        self.scale(rhs);
179    }
180}
181
182impl<'a, C, L> Neg for Msm<'a, C, L>
183where
184    C: CurveAffine,
185    L: Loader<C>,
186{
187    type Output = Msm<'a, C, L>;
188    fn neg(mut self) -> Msm<'a, C, L> {
189        self.constant = self.constant.map(|constant| -constant);
190        for scalar in self.scalars.iter_mut() {
191            *scalar = -scalar.clone();
192        }
193        self
194    }
195}
196
197impl<C, L> Sum for Msm<'_, C, L>
198where
199    C: CurveAffine,
200    L: Loader<C>,
201{
202    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
203        iter.reduce(|acc, item| acc + item).unwrap_or_default()
204    }
205}
206
207#[derive(Clone, Copy)]
208enum Bucket<C: CurveAffine> {
209    None,
210    Affine(C),
211    Projective(C::Curve),
212}
213
214impl<C: CurveAffine> Bucket<C> {
215    fn add_assign(&mut self, rhs: &C) {
216        *self = match *self {
217            Bucket::None => Bucket::Affine(*rhs),
218            Bucket::Affine(lhs) => Bucket::Projective(lhs + *rhs),
219            Bucket::Projective(mut lhs) => {
220                lhs += *rhs;
221                Bucket::Projective(lhs)
222            }
223        }
224    }
225
226    fn add(self, mut rhs: C::Curve) -> C::Curve {
227        match self {
228            Bucket::None => rhs,
229            Bucket::Affine(lhs) => {
230                rhs += lhs;
231                rhs
232            }
233            Bucket::Projective(lhs) => lhs + rhs,
234        }
235    }
236}
237
238fn multi_scalar_multiplication_serial<C: CurveAffine>(
239    scalars: &[C::Scalar],
240    bases: &[C],
241    result: &mut C::Curve,
242) {
243    let scalars = scalars.iter().map(|scalar| scalar.to_repr()).collect_vec();
244    let num_bytes = scalars[0].as_ref().len();
245    let num_bits = 8 * num_bytes;
246
247    let window_size = (scalars.len() as f64).ln().ceil() as usize + 2;
248    let num_buckets = (1 << window_size) - 1;
249
250    let windowed_scalar = |idx: usize, bytes: &<C::Scalar as PrimeField>::Repr| {
251        let skip_bits = idx * window_size;
252        let skip_bytes = skip_bits / 8;
253
254        let mut value = [0; size_of::<usize>()];
255        for (dst, src) in value.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
256            *dst = *src;
257        }
258
259        (usize::from_le_bytes(value) >> (skip_bits - (skip_bytes * 8))) & num_buckets
260    };
261
262    let num_window = Integer::div_ceil(&num_bits, &window_size);
263    for idx in (0..num_window).rev() {
264        for _ in 0..window_size {
265            *result = result.double();
266        }
267
268        let mut buckets = vec![Bucket::None; num_buckets];
269
270        for (scalar, base) in scalars.iter().zip(bases.iter()) {
271            let scalar = windowed_scalar(idx, scalar);
272            if scalar != 0 {
273                buckets[scalar - 1].add_assign(base);
274            }
275        }
276
277        let mut running_sum = C::Curve::identity();
278        for bucket in buckets.into_iter().rev() {
279            running_sum = bucket.add(running_sum);
280            *result += &running_sum;
281        }
282    }
283}
284
285/// Multi-scalar multiplication algorithm copied from
286/// <https://github.com/zcash/halo2/blob/main/halo2_proofs/src/arithmetic.rs>.
287pub fn multi_scalar_multiplication<C: CurveAffine>(scalars: &[C::Scalar], bases: &[C]) -> C::Curve {
288    assert_eq!(scalars.len(), bases.len());
289
290    #[cfg(feature = "parallel")]
291    {
292        use crate::util::{current_num_threads, parallelize_iter};
293
294        let num_threads = current_num_threads();
295        if scalars.len() < num_threads {
296            let mut result = C::Curve::identity();
297            multi_scalar_multiplication_serial(scalars, bases, &mut result);
298            return result;
299        }
300
301        let chunk_size = Integer::div_ceil(&scalars.len(), &num_threads);
302        let mut results = vec![C::Curve::identity(); num_threads];
303        parallelize_iter(
304            scalars.chunks(chunk_size).zip(bases.chunks(chunk_size)).zip(results.iter_mut()),
305            |((scalars, bases), result)| {
306                multi_scalar_multiplication_serial(scalars, bases, result);
307            },
308        );
309        results.iter().fold(C::Curve::identity(), |acc, result| acc + result)
310    }
311    #[cfg(not(feature = "parallel"))]
312    {
313        let mut result = C::Curve::identity();
314        multi_scalar_multiplication_serial(scalars, bases, &mut result);
315        result
316    }
317}