// halo2_proofs/poly/commitment/msm.rs
1use super::Params;
2use crate::arithmetic::{best_multiexp, parallelize, CurveAffine};
3use ff::Field;
4use group::Group;
5
6#[derive(Debug, Clone)]
8pub struct MSM<'a, C: CurveAffine> {
9 pub(crate) params: &'a Params<C>,
10 g_scalars: Option<Vec<C::Scalar>>,
11 w_scalar: Option<C::Scalar>,
12 u_scalar: Option<C::Scalar>,
13 other_scalars: Vec<C::Scalar>,
14 other_bases: Vec<C>,
15}
16
17impl<'a, C: CurveAffine> MSM<'a, C> {
18 pub fn new(params: &'a Params<C>) -> Self {
20 let g_scalars = None;
21 let w_scalar = None;
22 let u_scalar = None;
23 let other_scalars = vec![];
24 let other_bases = vec![];
25
26 MSM {
27 params,
28 g_scalars,
29 w_scalar,
30 u_scalar,
31 other_scalars,
32 other_bases,
33 }
34 }
35
36 pub fn add_msm(&mut self, other: &Self) {
38 self.other_scalars.extend(other.other_scalars.iter());
39 self.other_bases.extend(other.other_bases.iter());
40
41 if let Some(g_scalars) = &other.g_scalars {
42 self.add_to_g_scalars(g_scalars);
43 }
44
45 if let Some(w_scalar) = &other.w_scalar {
46 self.add_to_w_scalar(*w_scalar);
47 }
48
49 if let Some(u_scalar) = &other.u_scalar {
50 self.add_to_u_scalar(*u_scalar);
51 }
52 }
53
54 pub fn append_term(&mut self, scalar: C::Scalar, point: C) {
56 self.other_scalars.push(scalar);
57 self.other_bases.push(point);
58 }
59
60 pub fn add_constant_term(&mut self, constant: C::Scalar) {
62 if let Some(g_scalars) = self.g_scalars.as_mut() {
63 g_scalars[0] += &constant;
64 } else {
65 let mut g_scalars = vec![C::Scalar::zero(); self.params.n as usize];
66 g_scalars[0] += &constant;
67 self.g_scalars = Some(g_scalars);
68 }
69 }
70
71 pub fn add_to_g_scalars(&mut self, scalars: &[C::Scalar]) {
74 assert_eq!(scalars.len(), self.params.n as usize);
75 if let Some(g_scalars) = &mut self.g_scalars {
76 parallelize(g_scalars, |g_scalars, start| {
77 for (g_scalar, scalar) in g_scalars.iter_mut().zip(scalars[start..].iter()) {
78 *g_scalar += scalar;
79 }
80 })
81 } else {
82 self.g_scalars = Some(scalars.to_vec());
83 }
84 }
85
86 pub fn add_to_w_scalar(&mut self, scalar: C::Scalar) {
88 self.w_scalar = self.w_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
89 }
90
91 pub fn add_to_u_scalar(&mut self, scalar: C::Scalar) {
93 self.u_scalar = self.u_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
94 }
95
96 pub fn scale(&mut self, factor: C::Scalar) {
98 if let Some(g_scalars) = &mut self.g_scalars {
99 parallelize(g_scalars, |g_scalars, _| {
100 for g_scalar in g_scalars {
101 *g_scalar *= &factor;
102 }
103 })
104 }
105
106 if !self.other_scalars.is_empty() {
107 parallelize(&mut self.other_scalars, |other_scalars, _| {
108 for other_scalar in other_scalars {
109 *other_scalar *= &factor;
110 }
111 })
112 }
113
114 self.w_scalar = self.w_scalar.map(|a| a * &factor);
115 self.u_scalar = self.u_scalar.map(|a| a * &factor);
116 }
117
118 pub fn eval(self) -> bool {
120 let len = self.g_scalars.as_ref().map(|v| v.len()).unwrap_or(0)
121 + self.w_scalar.map(|_| 1).unwrap_or(0)
122 + self.u_scalar.map(|_| 1).unwrap_or(0)
123 + self.other_scalars.len();
124 let mut scalars: Vec<C::Scalar> = Vec::with_capacity(len);
125 let mut bases: Vec<C> = Vec::with_capacity(len);
126
127 scalars.extend(&self.other_scalars);
128 bases.extend(&self.other_bases);
129
130 if let Some(w_scalar) = self.w_scalar {
131 scalars.push(w_scalar);
132 bases.push(self.params.w);
133 }
134
135 if let Some(u_scalar) = self.u_scalar {
136 scalars.push(u_scalar);
137 bases.push(self.params.u);
138 }
139
140 if let Some(g_scalars) = &self.g_scalars {
141 scalars.extend(g_scalars);
142 bases.extend(self.params.g.iter());
143 }
144
145 assert_eq!(scalars.len(), len);
146
147 bool::from(best_multiexp(&scalars, &bases).is_identity())
148 }
149}