// halo2_proofs/poly/commitment/msm.rs

1use super::Params;
2use crate::arithmetic::{best_multiexp, parallelize, CurveAffine};
3use ff::Field;
4use group::Group;
5
6/// A multiscalar multiplication in the polynomial commitment scheme
7#[derive(Debug, Clone)]
8pub struct MSM<'a, C: CurveAffine> {
9    pub(crate) params: &'a Params<C>,
10    g_scalars: Option<Vec<C::Scalar>>,
11    w_scalar: Option<C::Scalar>,
12    u_scalar: Option<C::Scalar>,
13    other_scalars: Vec<C::Scalar>,
14    other_bases: Vec<C>,
15}
16
17impl<'a, C: CurveAffine> MSM<'a, C> {
18    /// Create a new, empty MSM using the provided parameters.
19    pub fn new(params: &'a Params<C>) -> Self {
20        let g_scalars = None;
21        let w_scalar = None;
22        let u_scalar = None;
23        let other_scalars = vec![];
24        let other_bases = vec![];
25
26        MSM {
27            params,
28            g_scalars,
29            w_scalar,
30            u_scalar,
31            other_scalars,
32            other_bases,
33        }
34    }
35
36    /// Add another multiexp into this one
37    pub fn add_msm(&mut self, other: &Self) {
38        self.other_scalars.extend(other.other_scalars.iter());
39        self.other_bases.extend(other.other_bases.iter());
40
41        if let Some(g_scalars) = &other.g_scalars {
42            self.add_to_g_scalars(g_scalars);
43        }
44
45        if let Some(w_scalar) = &other.w_scalar {
46            self.add_to_w_scalar(*w_scalar);
47        }
48
49        if let Some(u_scalar) = &other.u_scalar {
50            self.add_to_u_scalar(*u_scalar);
51        }
52    }
53
54    /// Add arbitrary term (the scalar and the point)
55    pub fn append_term(&mut self, scalar: C::Scalar, point: C) {
56        self.other_scalars.push(scalar);
57        self.other_bases.push(point);
58    }
59
60    /// Add a value to the first entry of `g_scalars`.
61    pub fn add_constant_term(&mut self, constant: C::Scalar) {
62        if let Some(g_scalars) = self.g_scalars.as_mut() {
63            g_scalars[0] += &constant;
64        } else {
65            let mut g_scalars = vec![C::Scalar::zero(); self.params.n as usize];
66            g_scalars[0] += &constant;
67            self.g_scalars = Some(g_scalars);
68        }
69    }
70
71    /// Add a vector of scalars to `g_scalars`. This function will panic if the
72    /// caller provides a slice of scalars that is not of length `params.n`.
73    pub fn add_to_g_scalars(&mut self, scalars: &[C::Scalar]) {
74        assert_eq!(scalars.len(), self.params.n as usize);
75        if let Some(g_scalars) = &mut self.g_scalars {
76            parallelize(g_scalars, |g_scalars, start| {
77                for (g_scalar, scalar) in g_scalars.iter_mut().zip(scalars[start..].iter()) {
78                    *g_scalar += scalar;
79                }
80            })
81        } else {
82            self.g_scalars = Some(scalars.to_vec());
83        }
84    }
85
86    /// Add to `w_scalar`
87    pub fn add_to_w_scalar(&mut self, scalar: C::Scalar) {
88        self.w_scalar = self.w_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
89    }
90
91    /// Add to `u_scalar`
92    pub fn add_to_u_scalar(&mut self, scalar: C::Scalar) {
93        self.u_scalar = self.u_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
94    }
95
96    /// Scale all scalars in the MSM by some scaling factor
97    pub fn scale(&mut self, factor: C::Scalar) {
98        if let Some(g_scalars) = &mut self.g_scalars {
99            parallelize(g_scalars, |g_scalars, _| {
100                for g_scalar in g_scalars {
101                    *g_scalar *= &factor;
102                }
103            })
104        }
105
106        if !self.other_scalars.is_empty() {
107            parallelize(&mut self.other_scalars, |other_scalars, _| {
108                for other_scalar in other_scalars {
109                    *other_scalar *= &factor;
110                }
111            })
112        }
113
114        self.w_scalar = self.w_scalar.map(|a| a * &factor);
115        self.u_scalar = self.u_scalar.map(|a| a * &factor);
116    }
117
118    /// Perform multiexp and check that it results in zero
119    pub fn eval(self) -> bool {
120        let len = self.g_scalars.as_ref().map(|v| v.len()).unwrap_or(0)
121            + self.w_scalar.map(|_| 1).unwrap_or(0)
122            + self.u_scalar.map(|_| 1).unwrap_or(0)
123            + self.other_scalars.len();
124        let mut scalars: Vec<C::Scalar> = Vec::with_capacity(len);
125        let mut bases: Vec<C> = Vec::with_capacity(len);
126
127        scalars.extend(&self.other_scalars);
128        bases.extend(&self.other_bases);
129
130        if let Some(w_scalar) = self.w_scalar {
131            scalars.push(w_scalar);
132            bases.push(self.params.w);
133        }
134
135        if let Some(u_scalar) = self.u_scalar {
136            scalars.push(u_scalar);
137            bases.push(self.params.u);
138        }
139
140        if let Some(g_scalars) = &self.g_scalars {
141            scalars.extend(g_scalars);
142            bases.extend(self.params.g.iter());
143        }
144
145        assert_eq!(scalars.len(), len);
146
147        bool::from(best_multiexp(&scalars, &bases).is_identity())
148    }
149}