// halo2_axiom/poly/ipa/msm.rs
1use crate::arithmetic::{best_multiexp, CurveAffine};
2use crate::poly::{commitment::MSM, ipa::commitment::ParamsVerifierIPA};
3use ff::Field;
4use group::Group;
5use std::collections::BTreeMap;
6
7#[derive(Debug, Clone)]
9pub struct MSMIPA<'params, C: CurveAffine> {
10 pub(crate) params: &'params ParamsVerifierIPA<C>,
11 g_scalars: Option<Vec<C::Scalar>>,
12 w_scalar: Option<C::Scalar>,
13 u_scalar: Option<C::Scalar>,
14 other: BTreeMap<C::Base, (C::Scalar, C::Base)>,
16}
17
18impl<'a, C: CurveAffine> MSMIPA<'a, C> {
19 pub fn new(params: &'a ParamsVerifierIPA<C>) -> Self {
21 let g_scalars = None;
22 let w_scalar = None;
23 let u_scalar = None;
24 let other = BTreeMap::new();
25
26 Self {
27 g_scalars,
28 w_scalar,
29 u_scalar,
30 other,
31
32 params,
33 }
34 }
35
36 pub fn add_msm(&mut self, other: &Self) {
38 for (x, (scalar, y)) in other.other.iter() {
39 self.other
40 .entry(*x)
41 .and_modify(|(our_scalar, our_y)| {
42 if our_y == y {
43 *our_scalar += *scalar;
44 } else {
45 assert!(*our_y == -*y);
46 *our_scalar -= *scalar;
47 }
48 })
49 .or_insert((*scalar, *y));
50 }
51
52 if let Some(g_scalars) = &other.g_scalars {
53 self.add_to_g_scalars(g_scalars);
54 }
55
56 if let Some(w_scalar) = &other.w_scalar {
57 self.add_to_w_scalar(*w_scalar);
58 }
59
60 if let Some(u_scalar) = &other.u_scalar {
61 self.add_to_u_scalar(*u_scalar);
62 }
63 }
64}
65
66impl<'a, C: CurveAffine> MSM<C> for MSMIPA<'a, C> {
67 fn append_term(&mut self, scalar: C::Scalar, point: C::Curve) {
68 if !bool::from(point.is_identity()) {
69 use group::Curve;
70 let point = point.to_affine();
71 let xy = point.coordinates().unwrap();
72 let x = *xy.x();
73 let y = *xy.y();
74
75 self.other
76 .entry(x)
77 .and_modify(|(our_scalar, our_y)| {
78 if *our_y == y {
79 *our_scalar += scalar;
80 } else {
81 assert!(*our_y == -y);
82 *our_scalar -= scalar;
83 }
84 })
85 .or_insert((scalar, y));
86 }
87 }
88
89 fn add_msm(&mut self, other: &Self) {
91 for (x, (scalar, y)) in other.other.iter() {
92 self.other
93 .entry(*x)
94 .and_modify(|(our_scalar, our_y)| {
95 if our_y == y {
96 *our_scalar += *scalar;
97 } else {
98 assert!(*our_y == -*y);
99 *our_scalar -= *scalar;
100 }
101 })
102 .or_insert((*scalar, *y));
103 }
104
105 if let Some(g_scalars) = &other.g_scalars {
106 self.add_to_g_scalars(g_scalars);
107 }
108
109 if let Some(w_scalar) = &other.w_scalar {
110 self.add_to_w_scalar(*w_scalar);
111 }
112
113 if let Some(u_scalar) = &other.u_scalar {
114 self.add_to_u_scalar(*u_scalar);
115 }
116 }
117
118 fn scale(&mut self, factor: C::Scalar) {
119 if let Some(g_scalars) = &mut self.g_scalars {
120 for g_scalar in g_scalars {
121 *g_scalar *= &factor;
122 }
123 }
124
125 for other in self.other.values_mut() {
126 other.0 *= factor;
127 }
128
129 self.w_scalar = self.w_scalar.map(|a| a * &factor);
130 self.u_scalar = self.u_scalar.map(|a| a * &factor);
131 }
132
133 fn check(&self) -> bool {
134 bool::from(self.eval().is_identity())
135 }
136
137 fn eval(&self) -> C::Curve {
138 let len = self.g_scalars.as_ref().map(|v| v.len()).unwrap_or(0)
139 + self.w_scalar.map(|_| 1).unwrap_or(0)
140 + self.u_scalar.map(|_| 1).unwrap_or(0)
141 + self.other.len();
142 let mut scalars: Vec<C::Scalar> = Vec::with_capacity(len);
143 let mut bases: Vec<C> = Vec::with_capacity(len);
144
145 scalars.extend(self.other.values().map(|(scalar, _)| scalar));
146 bases.extend(
147 self.other
148 .iter()
149 .map(|(x, (_, y))| C::from_xy(*x, *y).unwrap()),
150 );
151
152 if let Some(w_scalar) = self.w_scalar {
153 scalars.push(w_scalar);
154 bases.push(self.params.w);
155 }
156
157 if let Some(u_scalar) = self.u_scalar {
158 scalars.push(u_scalar);
159 bases.push(self.params.u);
160 }
161
162 if let Some(g_scalars) = &self.g_scalars {
163 scalars.extend(g_scalars);
164 bases.extend(self.params.g.iter());
165 }
166
167 assert_eq!(scalars.len(), len);
168
169 best_multiexp(&scalars, &bases)
170 }
171
172 fn bases(&self) -> Vec<C::CurveExt> {
173 self.other
174 .iter()
175 .map(|(x, (_, y))| C::from_xy(*x, *y).unwrap().into())
176 .collect()
177 }
178
179 fn scalars(&self) -> Vec<C::Scalar> {
180 self.other.values().map(|(scalar, _)| *scalar).collect()
181 }
182}
183
184impl<'a, C: CurveAffine> MSMIPA<'a, C> {
185 pub fn add_constant_term(&mut self, constant: C::Scalar) {
187 if let Some(g_scalars) = self.g_scalars.as_mut() {
188 g_scalars[0] += &constant;
189 } else {
190 let mut g_scalars = vec![C::Scalar::ZERO; self.params.n as usize];
191 g_scalars[0] += &constant;
192 self.g_scalars = Some(g_scalars);
193 }
194 }
195
196 pub fn add_to_g_scalars(&mut self, scalars: &[C::Scalar]) {
199 assert_eq!(scalars.len(), self.params.n as usize);
200 if let Some(g_scalars) = &mut self.g_scalars {
201 for (g_scalar, scalar) in g_scalars.iter_mut().zip(scalars.iter()) {
202 *g_scalar += scalar;
203 }
204 } else {
205 self.g_scalars = Some(scalars.to_vec());
206 }
207 }
208 pub fn add_to_w_scalar(&mut self, scalar: C::Scalar) {
210 self.w_scalar = self.w_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
211 }
212
213 pub fn add_to_u_scalar(&mut self, scalar: C::Scalar) {
215 self.u_scalar = self.u_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
216 }
217}
218
#[cfg(test)]
mod tests {
    use crate::poly::{
        commitment::{ParamsProver, MSM},
        ipa::{commitment::ParamsIPA, msm::MSMIPA},
    };
    use halo2curves::{
        pasta::{Ep, EpAffine, Fp, Fq},
        CurveAffine,
    };

    /// Exercises accumulation, merging of a point with its negation, scaling
    /// and MSM addition. After each step it checks whether the MSM evaluates
    /// to the identity.
    #[test]
    fn msm_arithmetic() {
        // P = (-1, 2); `from_xy(..).unwrap()` panics unless this is a curve point.
        let base: Ep = EpAffine::from_xy(-Fp::one(), Fp::from(2)).unwrap().into();
        // 2P
        let base_viol = base + base;

        let params = ParamsIPA::new(4);
        let mut a: MSMIPA<EpAffine> = MSMIPA::new(&params);
        a.append_term(Fq::one(), base);
        // a = [1] P
        assert!(!a.check());
        a.append_term(Fq::one(), base);
        // a = [2] P
        assert!(!a.check());
        a.append_term(-Fq::one(), base_viol);
        // a = [2] P + [-1] 2P = 0
        assert!(a.check());
        let b = a.clone();

        // Append the negation of an existing point: merged into P's entry.
        a.append_term(Fq::from(4), -base);
        // a = [2 - 4] P + [-1] 2P = [-2] P + [-1] 2P
        assert!(!a.check());
        a.append_term(Fq::from(2), base_viol);
        // a = [-2] P + [-1 + 2] 2P = 0
        assert!(a.check());

        // Add two MSMs with common bases.
        a.scale(Fq::from(3));
        // a = [-6] P + [3] 2P
        a.add_msm(&b);
        // a = [-6 + 2] P + [3 - 1] 2P = [-4] P + [2] 2P = 0
        assert!(a.check());

        let mut c: MSMIPA<EpAffine> = MSMIPA::new(&params);
        c.append_term(Fq::from(2), base);
        c.append_term(Fq::one(), -base_viol);
        // c = [2] P + [1] (-2P) = 0
        assert!(c.check());

        // Add two MSMs whose bases differ only in sign.
        a.add_msm(&c);
        // a = [-4 + 2] P + [2 - 1] 2P = [-2] P + [1] 2P = 0
        assert!(a.check());
    }
}