halo2_axiom/poly/ipa/msm.rs

1use crate::arithmetic::{best_multiexp, CurveAffine};
2use crate::poly::{commitment::MSM, ipa::commitment::ParamsVerifierIPA};
3use ff::Field;
4use group::Group;
5use std::collections::BTreeMap;
6
7/// A multiscalar multiplication in the polynomial commitment scheme
8#[derive(Debug, Clone)]
9pub struct MSMIPA<'params, C: CurveAffine> {
10    pub(crate) params: &'params ParamsVerifierIPA<C>,
11    g_scalars: Option<Vec<C::Scalar>>,
12    w_scalar: Option<C::Scalar>,
13    u_scalar: Option<C::Scalar>,
14    // x-coordinate -> (scalar, y-coordinate)
15    other: BTreeMap<C::Base, (C::Scalar, C::Base)>,
16}
17
18impl<'a, C: CurveAffine> MSMIPA<'a, C> {
19    /// Given verifier parameters Creates an empty multi scalar engine
20    pub fn new(params: &'a ParamsVerifierIPA<C>) -> Self {
21        let g_scalars = None;
22        let w_scalar = None;
23        let u_scalar = None;
24        let other = BTreeMap::new();
25
26        Self {
27            g_scalars,
28            w_scalar,
29            u_scalar,
30            other,
31
32            params,
33        }
34    }
35
36    /// Add another multiexp into this one
37    pub fn add_msm(&mut self, other: &Self) {
38        for (x, (scalar, y)) in other.other.iter() {
39            self.other
40                .entry(*x)
41                .and_modify(|(our_scalar, our_y)| {
42                    if our_y == y {
43                        *our_scalar += *scalar;
44                    } else {
45                        assert!(*our_y == -*y);
46                        *our_scalar -= *scalar;
47                    }
48                })
49                .or_insert((*scalar, *y));
50        }
51
52        if let Some(g_scalars) = &other.g_scalars {
53            self.add_to_g_scalars(g_scalars);
54        }
55
56        if let Some(w_scalar) = &other.w_scalar {
57            self.add_to_w_scalar(*w_scalar);
58        }
59
60        if let Some(u_scalar) = &other.u_scalar {
61            self.add_to_u_scalar(*u_scalar);
62        }
63    }
64}
65
66impl<'a, C: CurveAffine> MSM<C> for MSMIPA<'a, C> {
67    fn append_term(&mut self, scalar: C::Scalar, point: C::Curve) {
68        if !bool::from(point.is_identity()) {
69            use group::Curve;
70            let point = point.to_affine();
71            let xy = point.coordinates().unwrap();
72            let x = *xy.x();
73            let y = *xy.y();
74
75            self.other
76                .entry(x)
77                .and_modify(|(our_scalar, our_y)| {
78                    if *our_y == y {
79                        *our_scalar += scalar;
80                    } else {
81                        assert!(*our_y == -y);
82                        *our_scalar -= scalar;
83                    }
84                })
85                .or_insert((scalar, y));
86        }
87    }
88
89    /// Add another multiexp into this one
90    fn add_msm(&mut self, other: &Self) {
91        for (x, (scalar, y)) in other.other.iter() {
92            self.other
93                .entry(*x)
94                .and_modify(|(our_scalar, our_y)| {
95                    if our_y == y {
96                        *our_scalar += *scalar;
97                    } else {
98                        assert!(*our_y == -*y);
99                        *our_scalar -= *scalar;
100                    }
101                })
102                .or_insert((*scalar, *y));
103        }
104
105        if let Some(g_scalars) = &other.g_scalars {
106            self.add_to_g_scalars(g_scalars);
107        }
108
109        if let Some(w_scalar) = &other.w_scalar {
110            self.add_to_w_scalar(*w_scalar);
111        }
112
113        if let Some(u_scalar) = &other.u_scalar {
114            self.add_to_u_scalar(*u_scalar);
115        }
116    }
117
118    fn scale(&mut self, factor: C::Scalar) {
119        if let Some(g_scalars) = &mut self.g_scalars {
120            for g_scalar in g_scalars {
121                *g_scalar *= &factor;
122            }
123        }
124
125        for other in self.other.values_mut() {
126            other.0 *= factor;
127        }
128
129        self.w_scalar = self.w_scalar.map(|a| a * &factor);
130        self.u_scalar = self.u_scalar.map(|a| a * &factor);
131    }
132
133    fn check(&self) -> bool {
134        bool::from(self.eval().is_identity())
135    }
136
137    fn eval(&self) -> C::Curve {
138        let len = self.g_scalars.as_ref().map(|v| v.len()).unwrap_or(0)
139            + self.w_scalar.map(|_| 1).unwrap_or(0)
140            + self.u_scalar.map(|_| 1).unwrap_or(0)
141            + self.other.len();
142        let mut scalars: Vec<C::Scalar> = Vec::with_capacity(len);
143        let mut bases: Vec<C> = Vec::with_capacity(len);
144
145        scalars.extend(self.other.values().map(|(scalar, _)| scalar));
146        bases.extend(
147            self.other
148                .iter()
149                .map(|(x, (_, y))| C::from_xy(*x, *y).unwrap()),
150        );
151
152        if let Some(w_scalar) = self.w_scalar {
153            scalars.push(w_scalar);
154            bases.push(self.params.w);
155        }
156
157        if let Some(u_scalar) = self.u_scalar {
158            scalars.push(u_scalar);
159            bases.push(self.params.u);
160        }
161
162        if let Some(g_scalars) = &self.g_scalars {
163            scalars.extend(g_scalars);
164            bases.extend(self.params.g.iter());
165        }
166
167        assert_eq!(scalars.len(), len);
168
169        best_multiexp(&scalars, &bases)
170    }
171
172    fn bases(&self) -> Vec<C::CurveExt> {
173        self.other
174            .iter()
175            .map(|(x, (_, y))| C::from_xy(*x, *y).unwrap().into())
176            .collect()
177    }
178
179    fn scalars(&self) -> Vec<C::Scalar> {
180        self.other.values().map(|(scalar, _)| *scalar).collect()
181    }
182}
183
184impl<'a, C: CurveAffine> MSMIPA<'a, C> {
185    /// Add a value to the first entry of `g_scalars`.
186    pub fn add_constant_term(&mut self, constant: C::Scalar) {
187        if let Some(g_scalars) = self.g_scalars.as_mut() {
188            g_scalars[0] += &constant;
189        } else {
190            let mut g_scalars = vec![C::Scalar::ZERO; self.params.n as usize];
191            g_scalars[0] += &constant;
192            self.g_scalars = Some(g_scalars);
193        }
194    }
195
196    /// Add a vector of scalars to `g_scalars`. This function will panic if the
197    /// caller provides a slice of scalars that is not of length `params.n`.
198    pub fn add_to_g_scalars(&mut self, scalars: &[C::Scalar]) {
199        assert_eq!(scalars.len(), self.params.n as usize);
200        if let Some(g_scalars) = &mut self.g_scalars {
201            for (g_scalar, scalar) in g_scalars.iter_mut().zip(scalars.iter()) {
202                *g_scalar += scalar;
203            }
204        } else {
205            self.g_scalars = Some(scalars.to_vec());
206        }
207    }
208    /// Add to `w_scalar`
209    pub fn add_to_w_scalar(&mut self, scalar: C::Scalar) {
210        self.w_scalar = self.w_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
211    }
212
213    /// Add to `u_scalar`
214    pub fn add_to_u_scalar(&mut self, scalar: C::Scalar) {
215        self.u_scalar = self.u_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
216    }
217}
218
#[cfg(test)]
mod tests {
    use crate::poly::{
        commitment::{ParamsProver, MSM},
        ipa::{commitment::ParamsIPA, msm::MSMIPA},
    };
    use halo2curves::{
        pasta::{Ep, EpAffine, Fp, Fq},
        CurveAffine,
    };

    /// Exercises `append_term`, `scale` and `add_msm`. This includes terms
    /// whose points share an x-coordinate because one is the negation of the
    /// other.
    #[test]
    fn msm_arithmetic() {
        // P = (-1, 2); 2P serves as a second, distinct base.
        let base: Ep = EpAffine::from_xy(-Fp::one(), Fp::from(2)).unwrap().into();
        let base_viol = base + base;

        let params = ParamsIPA::new(4);
        let mut a: MSMIPA<EpAffine> = MSMIPA::new(&params);
        a.append_term(Fq::one(), base);
        // a = [1] P
        assert!(!a.clone().check());
        a.append_term(Fq::one(), base);
        // a = [1+1] P
        assert!(!a.clone().check());
        a.append_term(-Fq::one(), base_viol);
        // a = [1+1] P + [-1] 2P
        assert!(a.clone().check());
        let b = a.clone();

        // Append a point that is the negation of an existing one.
        a.append_term(Fq::from(4), -base);
        // a = [1+1-4] P + [-1] 2P
        assert!(!a.clone().check());
        a.append_term(Fq::from(2), base_viol);
        // a = [1+1-4] P + [-1+2] 2P
        assert!(a.clone().check());

        // Add two MSMs with common bases.
        a.scale(Fq::from(3));
        a.add_msm(&b);
        // a = [3*(1+1-4)+(1+1)] P + [3*(-1+2)+(-1)] 2P
        assert!(a.clone().check());

        let mut c: MSMIPA<EpAffine> = MSMIPA::new(&params);
        c.append_term(Fq::from(2), base);
        c.append_term(Fq::one(), -base_viol);
        // c = [2] P + [1] (-2P)
        assert!(c.clone().check());
        // Add two MSMs with bases that differ only in sign.
        a.add_msm(&c);
        // a = [-4+2] P + [2-1] 2P
        assert!(a.check());

        // Negated points must be merged into the existing entries rather than
        // stored separately.
        let scalars = a.scalars();
        assert_eq!(scalars.len(), 2);
        assert!(scalars.contains(&-Fq::from(2)));
        assert!(scalars.contains(&Fq::one()));
    }
}
271}