1use crate::arithmetic::{best_multiexp, g_to_lagrange, parallelize};
2use crate::helpers::SerdeCurveAffine;
3use crate::poly::commitment::{Blind, CommitmentScheme, Params, ParamsProver, ParamsVerifier};
4use crate::poly::{Coeff, LagrangeCoeff, Polynomial};
5use crate::SerdeFormat;
6
7use ff::{Field, PrimeField};
8use group::{prime::PrimeCurveAffine, Curve, Group};
9use pairing::Engine;
10use rand_core::{OsRng, RngCore};
11use std::fmt::Debug;
12use std::marker::PhantomData;
13
14use std::io;
15
16use super::msm::MSMKZG;
17
/// Public parameters for the KZG polynomial commitment scheme over the pairing engine `E`.
///
/// The relations documented on the fields are the ones `setup` establishes from its
/// secret scalar `s`. Values passed to `from_parts` or loaded by `read_custom` are
/// not checked against these relations.
#[derive(Debug, Clone)]
pub struct ParamsKZG<E: Engine> {
    /// Log2 of the number of G1 bases.
    pub(crate) k: u32,
    /// Number of G1 bases, `1 << k`.
    pub(crate) n: u64,
    /// Monomial-basis bases `[s^i] G1` for `i in 0..n`.
    pub(crate) g: Vec<E::G1Affine>,
    /// Lagrange-basis bases `[L_i(s)] G1` over the size-`n` evaluation domain.
    pub(crate) g_lagrange: Vec<E::G1Affine>,
    /// The G2 generator.
    pub(crate) g2: E::G2Affine,
    /// `[s] G2`.
    pub(crate) s_g2: E::G2Affine,
}
28
/// Zero-sized marker type selecting the KZG commitment scheme over the pairing engine `E`.
#[derive(Debug)]
pub struct KZGCommitmentScheme<E: Engine> {
    // Only ties the type to `E`; no data is stored.
    _marker: PhantomData<E>,
}
34
35impl<E: Engine + Debug> CommitmentScheme for KZGCommitmentScheme<E>
36where
37 E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
38 E::G2Affine: SerdeCurveAffine,
39{
40 type Scalar = E::Fr;
41 type Curve = E::G1Affine;
42
43 type ParamsProver = ParamsKZG<E>;
44 type ParamsVerifier = ParamsVerifierKZG<E>;
45
46 fn new_params(k: u32) -> Self::ParamsProver {
47 ParamsKZG::new(k)
48 }
49
50 fn read_params<R: io::Read>(reader: &mut R) -> io::Result<Self::ParamsProver> {
51 ParamsKZG::read(reader)
52 }
53}
54
55impl<E: Engine + Debug> ParamsKZG<E> {
56 pub fn setup<R: RngCore>(k: u32, rng: R) -> Self {
59 assert!(k <= E::Fr::S);
62 let n: u64 = 1 << k;
63
64 let g1 = E::G1Affine::generator();
66 let s = <E::Fr>::random(rng);
67
68 let mut g_projective = vec![E::G1::identity(); n as usize];
69 parallelize(&mut g_projective, |g, start| {
70 let mut current_g: E::G1 = g1.into();
71 current_g *= s.pow_vartime([start as u64]);
72 for g in g.iter_mut() {
73 *g = current_g;
74 current_g *= s;
75 }
76 });
77
78 let g = {
79 let mut g = vec![E::G1Affine::identity(); n as usize];
80 parallelize(&mut g, |g, starts| {
81 E::G1::batch_normalize(&g_projective[starts..(starts + g.len())], g);
82 });
83 g
84 };
85
86 let mut g_lagrange_projective = vec![E::G1::identity(); n as usize];
87 let mut root = E::Fr::ROOT_OF_UNITY_INV.invert().unwrap();
88 for _ in k..E::Fr::S {
89 root = root.square();
90 }
91 let n_inv = Option::<E::Fr>::from(E::Fr::from(n).invert())
92 .expect("inversion should be ok for n = 1<<k");
93 let multiplier = (s.pow_vartime([n]) - E::Fr::ONE) * n_inv;
94 parallelize(&mut g_lagrange_projective, |g, start| {
95 for (idx, g) in g.iter_mut().enumerate() {
96 let offset = start + idx;
97 let root_pow = root.pow_vartime([offset as u64]);
98 let scalar = multiplier * root_pow * (s - root_pow).invert().unwrap();
99 *g = g1 * scalar;
100 }
101 });
102
103 let g_lagrange = {
104 let mut g_lagrange = vec![E::G1Affine::identity(); n as usize];
105 parallelize(&mut g_lagrange, |g_lagrange, starts| {
106 E::G1::batch_normalize(
107 &g_lagrange_projective[starts..(starts + g_lagrange.len())],
108 g_lagrange,
109 );
110 });
111 drop(g_lagrange_projective);
112 g_lagrange
113 };
114
115 let g2 = <E::G2Affine as PrimeCurveAffine>::generator();
116 let s_g2 = (g2 * s).into();
117
118 Self {
119 k,
120 n,
121 g,
122 g_lagrange,
123 g2,
124 s_g2,
125 }
126 }
127
128 pub fn from_parts(
131 &self,
132 k: u32,
133 g: Vec<E::G1Affine>,
134 g_lagrange: Option<Vec<E::G1Affine>>,
135 g2: E::G2Affine,
136 s_g2: E::G2Affine,
137 ) -> Self {
138 Self {
139 k,
140 n: 1 << k,
141 g_lagrange: if let Some(g_l) = g_lagrange {
142 g_l
143 } else {
144 g_to_lagrange(g.iter().map(PrimeCurveAffine::to_curve).collect(), k)
145 },
146 g,
147 g2,
148 s_g2,
149 }
150 }
151
152 pub fn g2(&self) -> E::G2Affine {
154 self.g2
155 }
156
157 pub fn s_g2(&self) -> E::G2Affine {
159 self.s_g2
160 }
161
162 pub fn write_custom<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()>
164 where
165 E::G1Affine: SerdeCurveAffine,
166 E::G2Affine: SerdeCurveAffine,
167 {
168 writer.write_all(&self.k.to_le_bytes())?;
169 for el in self.g.iter() {
170 el.write(writer, format)?;
171 }
172 for el in self.g_lagrange.iter() {
173 el.write(writer, format)?;
174 }
175 self.g2.write(writer, format)?;
176 self.s_g2.write(writer, format)?;
177 Ok(())
178 }
179
180 pub fn read_custom<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self>
182 where
183 E::G1Affine: SerdeCurveAffine,
184 E::G2Affine: SerdeCurveAffine,
185 {
186 let mut k = [0u8; 4];
187 reader.read_exact(&mut k[..])?;
188 let k = u32::from_le_bytes(k);
189 let n = 1 << k;
190
191 let (g, g_lagrange) = match format {
192 SerdeFormat::Processed => {
193 use group::GroupEncoding;
194 let load_points_from_file_parallelly =
195 |reader: &mut R| -> io::Result<Vec<Option<E::G1Affine>>> {
196 let mut points_compressed =
197 vec![<<E as Engine>::G1Affine as GroupEncoding>::Repr::default(); n];
198 for points_compressed in points_compressed.iter_mut() {
199 reader.read_exact((*points_compressed).as_mut())?;
200 }
201
202 let mut points = vec![Option::<E::G1Affine>::None; n];
203 parallelize(&mut points, |points, chunks| {
204 for (i, point) in points.iter_mut().enumerate() {
205 *point = Option::from(E::G1Affine::from_bytes(
206 &points_compressed[chunks + i],
207 ));
208 }
209 });
210 Ok(points)
211 };
212
213 let g = load_points_from_file_parallelly(reader)?;
214 let g: Vec<<E as Engine>::G1Affine> = g
215 .iter()
216 .map(|point| {
217 point.ok_or_else(|| {
218 io::Error::new(io::ErrorKind::Other, "invalid point encoding")
219 })
220 })
221 .collect::<io::Result<_>>()?;
222 let g_lagrange = load_points_from_file_parallelly(reader)?;
223 let g_lagrange: Vec<<E as Engine>::G1Affine> = g_lagrange
224 .iter()
225 .map(|point| {
226 point.ok_or_else(|| {
227 io::Error::new(io::ErrorKind::Other, "invalid point encoding")
228 })
229 })
230 .collect::<io::Result<_>>()?;
231 (g, g_lagrange)
232 }
233 SerdeFormat::RawBytes => {
234 let g = (0..n)
235 .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format))
236 .collect::<io::Result<_>>()?;
237 let g_lagrange = (0..n)
238 .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format))
239 .collect::<io::Result<_>>()?;
240 (g, g_lagrange)
241 }
242 SerdeFormat::RawBytesUnchecked => {
243 let g = (0..n)
245 .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format).unwrap())
246 .collect::<Vec<_>>();
247 let g_lagrange = (0..n)
248 .map(|_| <E::G1Affine as SerdeCurveAffine>::read(reader, format).unwrap())
249 .collect::<Vec<_>>();
250 (g, g_lagrange)
251 }
252 };
253
254 let g2 = E::G2Affine::read(reader, format)?;
255 let s_g2 = E::G2Affine::read(reader, format)?;
256
257 Ok(Self {
258 k,
259 n: n as u64,
260 g,
261 g_lagrange,
262 g2,
263 s_g2,
264 })
265 }
266}
267
268pub type ParamsVerifierKZG<C> = ParamsKZG<C>;
272
273impl<'params, E: Engine + Debug> Params<'params, E::G1Affine> for ParamsKZG<E>
274where
275 E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
276 E::G2Affine: SerdeCurveAffine,
277{
278 type MSM = MSMKZG<E>;
279
280 fn k(&self) -> u32 {
281 self.k
282 }
283
284 fn n(&self) -> u64 {
285 self.n
286 }
287
288 fn downsize(&mut self, k: u32) {
289 assert!(k <= self.k);
290
291 self.k = k;
292 self.n = 1 << k;
293
294 self.g.truncate(self.n as usize);
295 self.g_lagrange = g_to_lagrange(self.g.iter().map(|g| g.to_curve()).collect(), k);
296 }
297
298 fn empty_msm(&'params self) -> MSMKZG<E> {
299 MSMKZG::new()
300 }
301
302 fn commit_lagrange(&self, poly: &Polynomial<E::Fr, LagrangeCoeff>, _: Blind<E::Fr>) -> E::G1 {
303 let size = poly.len();
304 assert!(self.n() >= size as u64);
305 best_multiexp(poly, &self.g_lagrange[0..size])
306 }
307
308 fn write<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
310 self.write_custom(writer, SerdeFormat::RawBytesUnchecked)
311 }
312
313 fn read<R: io::Read>(reader: &mut R) -> io::Result<Self> {
315 Self::read_custom(reader, SerdeFormat::RawBytesUnchecked)
316 }
317}
318
319impl<'params, E: Engine + Debug> ParamsVerifier<'params, E::G1Affine> for ParamsKZG<E>
320where
321 E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
322 E::G2Affine: SerdeCurveAffine,
323{
324}
325
326impl<'params, E: Engine + Debug> ParamsProver<'params, E::G1Affine> for ParamsKZG<E>
327where
328 E::G1Affine: SerdeCurveAffine<ScalarExt = E::Fr, CurveExt = E::G1>,
329 E::G2Affine: SerdeCurveAffine,
330{
331 type ParamsVerifier = ParamsVerifierKZG<E>;
332
333 fn verifier_params(&'params self) -> &'params Self::ParamsVerifier {
334 self
335 }
336
337 fn new(k: u32) -> Self {
338 Self::setup(k, OsRng)
339 }
340
341 fn commit(&self, poly: &Polynomial<E::Fr, Coeff>, _: Blind<E::Fr>) -> E::G1 {
342 let size = poly.len();
343 assert!(self.n() >= size as u64);
344 best_multiexp(poly, &self.g[0..size])
345 }
346
347 fn get_g(&self) -> &[E::G1Affine] {
348 &self.g
349 }
350}
351
#[cfg(test)]
mod test {
    use crate::poly::commitment::ParamsProver;
    use crate::poly::commitment::{Blind, Params};
    use crate::poly::kzg::commitment::ParamsKZG;
    use ff::Field;

    /// Committing in coefficient form and in Lagrange form must give the same point.
    #[test]
    fn test_commit_lagrange() {
        const K: u32 = 6;

        use rand_core::OsRng;

        use crate::poly::EvaluationDomain;
        use halo2curves::bn256::{Bn256, Fr};

        let params = ParamsKZG::<Bn256>::new(K);
        let domain = EvaluationDomain::new(1, K);

        let mut a = domain.empty_lagrange();

        for (i, a) in a.iter_mut().enumerate() {
            *a = Fr::from(i as u64);
        }

        let b = domain.lagrange_to_coeff(a.clone());

        let alpha = Blind(Fr::random(OsRng));

        assert_eq!(params.commit(&b, alpha), params.commit_lagrange(&a, alpha));
    }

    /// Writing then reading the default serialization must reproduce every field.
    #[test]
    fn test_parameter_serialisation_roundtrip() {
        const K: u32 = 4;

        use crate::halo2curves::bn256::Bn256;

        let params0 = ParamsKZG::<Bn256>::new(K);
        let mut data = vec![];
        <ParamsKZG<_> as Params<_>>::write(&params0, &mut data).unwrap();
        let params1: ParamsKZG<Bn256> = Params::read::<_>(&mut &data[..]).unwrap();

        assert_eq!(params0.k, params1.k);
        assert_eq!(params0.n, params1.n);
        assert_eq!(params0.g.len(), params1.g.len());
        assert_eq!(params0.g_lagrange.len(), params1.g_lagrange.len());

        assert_eq!(params0.g, params1.g);
        assert_eq!(params0.g_lagrange, params1.g_lagrange);
        assert_eq!(params0.g2, params1.g2);
        assert_eq!(params0.s_g2, params1.s_g2);
    }
}