halo2_axiom/poly/kzg/commitment.rs

1use crate::arithmetic::{best_multiexp, g_to_lagrange, parallelize};
2use crate::helpers::SerdeCurveAffine;
3use crate::poly::commitment::{Blind, CommitmentScheme, Params, ParamsProver, ParamsVerifier};
4use crate::poly::{Coeff, LagrangeCoeff, Polynomial};
5use crate::SerdeFormat;
6
7use ff::{Field, PrimeField};
8use group::{prime::PrimeCurveAffine, Curve, Group};
9use pairing::Engine;
10use rand_core::{OsRng, RngCore};
11use std::fmt::Debug;
12use std::marker::PhantomData;
13
14use std::io;
15
16use super::msm::MSMKZG;
17
/// These are the public parameters for the polynomial commitment scheme.
///
/// Produced by [`ParamsKZG::setup`], [`ParamsKZG::from_parts`] or [`ParamsKZG::read_custom`];
/// in each of those paths `n` is set to `1 << k`.
#[derive(Debug, Clone)]
pub struct ParamsKZG<E: Engine> {
    /// Base-2 logarithm of the number of G1 bases.
    pub(crate) k: u32,
    /// Number of G1 bases, `2^k`.
    pub(crate) n: u64,
    /// Monomial-basis bases `[G1, [s] G1, [s^2] G1, ..., [s^(n-1)] G1]`, used by `commit`.
    pub(crate) g: Vec<E::G1Affine>,
    /// Lagrange-basis bases `[L_i(s)] G1` over the size-`n` evaluation domain,
    /// used by `commit_lagrange`.
    pub(crate) g_lagrange: Vec<E::G1Affine>,
    /// The G2 generator.
    pub(crate) g2: E::G2Affine,
    /// `[s] G2`, the first power of the secret in G2.
    pub(crate) s_g2: E::G2Affine,
}
28
/// Umbrella commitment scheme construction for all KZG variants
///
/// Holds no data: the pairing engine `E` is carried only at the type level via `PhantomData`.
#[derive(Debug)]
pub struct KZGCommitmentScheme<E: Engine> {
    _marker: PhantomData<E>,
}
34
35impl<E: Engine + Debug> CommitmentScheme for KZGCommitmentScheme<E>
36where
37    E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
38    E::G2Affine: SerdeCurveAffine,
39{
40    type Scalar = E::Fr;
41    type Curve = E::G1Affine;
42
43    type ParamsProver = ParamsKZG<E>;
44    type ParamsVerifier = ParamsVerifierKZG<E>;
45
46    fn new_params(k: u32) -> Self::ParamsProver {
47        ParamsKZG::new(k)
48    }
49
50    fn read_params<R: io::Read>(reader: &mut R) -> io::Result<Self::ParamsProver> {
51        ParamsKZG::read(reader)
52    }
53}
54
55impl<E: Engine + Debug> ParamsKZG<E> {
56    /// Initializes parameters for the curve, draws toxic secret from given rng.
57    /// MUST NOT be used in production.
58    pub fn setup<R: RngCore>(k: u32, rng: R) -> Self {
59        // Largest root of unity exponent of the Engine is `2^E::Scalar::S`, so we can
60        // only support FFTs of polynomials below degree `2^E::Scalar::S`.
61        assert!(k <= E::Fr::S);
62        let n: u64 = 1 << k;
63
64        // Calculate g = [G1, [s] G1, [s^2] G1, ..., [s^(n-1)] G1] in parallel.
65        let g1 = E::G1Affine::generator();
66        let s = <E::Fr>::random(rng);
67
68        let mut g_projective = vec![E::G1::identity(); n as usize];
69        parallelize(&mut g_projective, |g, start| {
70            let mut current_g: E::G1 = g1.into();
71            current_g *= s.pow_vartime([start as u64]);
72            for g in g.iter_mut() {
73                *g = current_g;
74                current_g *= s;
75            }
76        });
77
78        let g = {
79            let mut g = vec![E::G1Affine::identity(); n as usize];
80            parallelize(&mut g, |g, starts| {
81                E::G1::batch_normalize(&g_projective[starts..(starts + g.len())], g);
82            });
83            g
84        };
85
86        let mut g_lagrange_projective = vec![E::G1::identity(); n as usize];
87        let mut root = E::Fr::ROOT_OF_UNITY_INV.invert().unwrap();
88        for _ in k..E::Fr::S {
89            root = root.square();
90        }
91        let n_inv = Option::<E::Fr>::from(E::Fr::from(n).invert())
92            .expect("inversion should be ok for n = 1<<k");
93        let multiplier = (s.pow_vartime([n]) - E::Fr::ONE) * n_inv;
94        parallelize(&mut g_lagrange_projective, |g, start| {
95            for (idx, g) in g.iter_mut().enumerate() {
96                let offset = start + idx;
97                let root_pow = root.pow_vartime([offset as u64]);
98                let scalar = multiplier * root_pow * (s - root_pow).invert().unwrap();
99                *g = g1 * scalar;
100            }
101        });
102
103        let g_lagrange = {
104            let mut g_lagrange = vec![E::G1Affine::identity(); n as usize];
105            parallelize(&mut g_lagrange, |g_lagrange, starts| {
106                E::G1::batch_normalize(
107                    &g_lagrange_projective[starts..(starts + g_lagrange.len())],
108                    g_lagrange,
109                );
110            });
111            drop(g_lagrange_projective);
112            g_lagrange
113        };
114
115        let g2 = <E::G2Affine as PrimeCurveAffine>::generator();
116        let s_g2 = (g2 * s).into();
117
118        Self {
119            k,
120            n,
121            g,
122            g_lagrange,
123            g2,
124            s_g2,
125        }
126    }
127
128    /// Initializes parameters for the curve through existing parameters
129    /// k, g, g_lagrange (optional), g2, s_g2
130    pub fn from_parts(
131        &self,
132        k: u32,
133        g: Vec<E::G1Affine>,
134        g_lagrange: Option<Vec<E::G1Affine>>,
135        g2: E::G2Affine,
136        s_g2: E::G2Affine,
137    ) -> Self {
138        Self {
139            k,
140            n: 1 << k,
141            g_lagrange: if let Some(g_l) = g_lagrange {
142                g_l
143            } else {
144                g_to_lagrange(g.iter().map(PrimeCurveAffine::to_curve).collect(), k)
145            },
146            g,
147            g2,
148            s_g2,
149        }
150    }
151
152    /// Returns gernerator on G2
153    pub fn g2(&self) -> E::G2Affine {
154        self.g2
155    }
156
157    /// Returns first power of secret on G2
158    pub fn s_g2(&self) -> E::G2Affine {
159        self.s_g2
160    }
161
162    /// Writes parameters to buffer
163    pub fn write_custom<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()>
164    where
165        E::G1Affine: SerdeCurveAffine,
166        E::G2Affine: SerdeCurveAffine,
167    {
168        writer.write_all(&self.k.to_le_bytes())?;
169        for el in self.g.iter() {
170            el.write(writer, format)?;
171        }
172        for el in self.g_lagrange.iter() {
173            el.write(writer, format)?;
174        }
175        self.g2.write(writer, format)?;
176        self.s_g2.write(writer, format)?;
177        Ok(())
178    }
179
180    /// Reads params from a buffer.
181    pub fn read_custom<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self>
182    where
183        E::G1Affine: SerdeCurveAffine,
184        E::G2Affine: SerdeCurveAffine,
185    {
186        let mut k = [0u8; 4];
187        reader.read_exact(&mut k[..])?;
188        let k = u32::from_le_bytes(k);
189        let n = 1 << k;
190
191        let (g, g_lagrange) = match format {
192            SerdeFormat::Processed => {
193                use group::GroupEncoding;
194                let load_points_from_file_parallelly =
195                    |reader: &mut R| -> io::Result<Vec<Option<E::G1Affine>>> {
196                        let mut points_compressed =
197                            vec![<<E as Engine>::G1Affine as GroupEncoding>::Repr::default(); n];
198                        for points_compressed in points_compressed.iter_mut() {
199                            reader.read_exact((*points_compressed).as_mut())?;
200                        }
201
202                        let mut points = vec![Option::<E::G1Affine>::None; n];
203                        parallelize(&mut points, |points, chunks| {
204                            for (i, point) in points.iter_mut().enumerate() {
205                                *point = Option::from(E::G1Affine::from_bytes(
206                                    &points_compressed[chunks + i],
207                                ));
208                            }
209                        });
210                        Ok(points)
211                    };
212
213                let g = load_points_from_file_parallelly(reader)?;
214                let g: Vec<<E as Engine>::G1Affine> = g
215                    .iter()
216                    .map(|point| {
217                        point.ok_or_else(|| {
218                            io::Error::new(io::ErrorKind::Other, "invalid point encoding")
219                        })
220                    })
221                    .collect::<io::Result<_>>()?;
222                let g_lagrange = load_points_from_file_parallelly(reader)?;
223                let g_lagrange: Vec<<E as Engine>::G1Affine> = g_lagrange
224                    .iter()
225                    .map(|point| {
226                        point.ok_or_else(|| {
227                            io::Error::new(io::ErrorKind::Other, "invalid point encoding")
228                        })
229                    })
230                    .collect::<io::Result<_>>()?;
231                (g, g_lagrange)
232            }
233            SerdeFormat::RawBytes => {
234                let g = (0..n)
235                    .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format))
236                    .collect::<io::Result<_>>()?;
237                let g_lagrange = (0..n)
238                    .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format))
239                    .collect::<io::Result<_>>()?;
240                (g, g_lagrange)
241            }
242            SerdeFormat::RawBytesUnchecked => {
243                // avoid try branching for performance
244                let g = (0..n)
245                    .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format).unwrap())
246                    .collect::<Vec<_>>();
247                let g_lagrange = (0..n)
248                    .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format).unwrap())
249                    .collect::<Vec<_>>();
250                (g, g_lagrange)
251            }
252        };
253
254        let g2 = E::G2Affine::read(reader, format)?;
255        let s_g2 = E::G2Affine::read(reader, format)?;
256
257        Ok(Self {
258            k,
259            n: n as u64,
260            g,
261            g_lagrange,
262            g2,
263            s_g2,
264        })
265    }
266}
267
268// TODO: see the issue at https://github.com/appliedzkp/halo2/issues/45
269// So we probably need much smaller verifier key. However for new bases in g1 should be in verifier keys.
270/// KZG multi-open verification parameters
271pub type ParamsVerifierKZG<C> = ParamsKZG<C>;
272
273impl<'params, E: Engine + Debug> Params<'params, E::G1Affine> for ParamsKZG<E>
274where
275    E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
276    E::G2Affine: SerdeCurveAffine,
277{
278    type MSM = MSMKZG<E>;
279
280    fn k(&self) -> u32 {
281        self.k
282    }
283
284    fn n(&self) -> u64 {
285        self.n
286    }
287
288    fn downsize(&mut self, k: u32) {
289        assert!(k <= self.k);
290
291        self.k = k;
292        self.n = 1 << k;
293
294        self.g.truncate(self.n as usize);
295        self.g_lagrange = g_to_lagrange(self.g.iter().map(|g| g.to_curve()).collect(), k);
296    }
297
298    fn empty_msm(&'params self) -> MSMKZG<E> {
299        MSMKZG::new()
300    }
301
302    fn commit_lagrange(&self, poly: &Polynomial<E::Fr, LagrangeCoeff>, _: Blind<E::Fr>) -> E::G1 {
303        let size = poly.len();
304        assert!(self.n() >= size as u64);
305        best_multiexp(poly, &self.g_lagrange[0..size])
306    }
307
308    /// Writes params to a buffer.
309    fn write<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
310        self.write_custom(writer, SerdeFormat::RawBytesUnchecked)
311    }
312
313    /// Reads params from a buffer.
314    fn read<R: io::Read>(reader: &mut R) -> io::Result<Self> {
315        Self::read_custom(reader, SerdeFormat::RawBytesUnchecked)
316    }
317}
318
319impl<'params, E: Engine + Debug> ParamsVerifier<'params, E::G1Affine> for ParamsKZG<E>
320where
321    E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
322    E::G2Affine: SerdeCurveAffine,
323{
324}
325
326impl<'params, E: Engine + Debug> ParamsProver<'params, E::G1Affine> for ParamsKZG<E>
327where
328    E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
329    E::G2Affine: SerdeCurveAffine,
330{
331    type ParamsVerifier = ParamsVerifierKZG<E>;
332
333    fn verifier_params(&'params self) -> &'params Self::ParamsVerifier {
334        self
335    }
336
337    fn new(k: u32) -> Self {
338        Self::setup(k, OsRng)
339    }
340
341    fn commit(&self, poly: &Polynomial<E::Fr, Coeff>, _: Blind<E::Fr>) -> E::G1 {
342        let size = poly.len();
343        assert!(self.n() >= size as u64);
344        best_multiexp(poly, &self.g[0..size])
345    }
346
347    fn get_g(&self) -> &[E::G1Affine] {
348        &self.g
349    }
350}
351
#[cfg(test)]
mod test {
    use crate::halo2curves::bn256::{Bn256, Fr};
    use crate::poly::commitment::{Blind, Params, ParamsProver};
    use crate::poly::kzg::commitment::ParamsKZG;
    use crate::poly::EvaluationDomain;
    use ff::Field;
    use rand_core::OsRng;

    /// Committing to coefficients and committing to the matching Lagrange evaluations
    /// must yield the same group element.
    #[test]
    fn test_commit_lagrange() {
        const K: u32 = 6;

        let params = ParamsKZG::<Bn256>::new(K);
        let domain = EvaluationDomain::new(1, K);

        let mut evals = domain.empty_lagrange();
        for (idx, eval) in evals.iter_mut().enumerate() {
            *eval = Fr::from(idx as u64);
        }
        let coeffs = domain.lagrange_to_coeff(evals.clone());

        let blind = Blind(Fr::random(OsRng));
        let from_coeffs = params.commit(&coeffs, blind);
        let from_evals = params.commit_lagrange(&evals, blind);
        assert_eq!(from_coeffs, from_evals);
    }

    /// Writing parameters and reading them back must reproduce every field.
    #[test]
    fn test_parameter_serialisation_roundtrip() {
        const K: u32 = 4;

        let written = ParamsKZG::<Bn256>::new(K);
        let mut buffer = vec![];
        <ParamsKZG<_> as Params<_>>::write(&written, &mut buffer).unwrap();
        let read_back: ParamsKZG<Bn256> = Params::read::<_>(&mut &buffer[..]).unwrap();

        assert_eq!(written.k, read_back.k);
        assert_eq!(written.n, read_back.n);
        assert_eq!(written.g.len(), read_back.g.len());
        assert_eq!(written.g_lagrange.len(), read_back.g_lagrange.len());

        assert_eq!(written.g, read_back.g);
        assert_eq!(written.g_lagrange, read_back.g_lagrange);
        assert_eq!(written.g2, read_back.g2);
        assert_eq!(written.s_g2, read_back.s_g2);
    }
}