// snark_verifier/pcs/ipa/multiopen/bgh19.rs

use crate::{
    loader::{LoadedScalar, Loader, ScalarLoader},
    pcs::{
        ipa::{Ipa, IpaAccumulator, IpaAs, IpaProof, IpaSuccinctVerifyingKey, Round},
        PolynomialCommitmentScheme, Query,
    },
    util::{
        arithmetic::{CurveAffine, Fraction, PrimeField, Rotation},
        msm::Msm,
        transcript::TranscriptRead,
        Itertools,
    },
    Error,
};
use std::{
    collections::{BTreeMap, BTreeSet},
    iter,
    marker::PhantomData,
};

/// Verifier of the multi-open inner product argument, matching the
/// implementation in [`halo2_proofs`](crate::halo2_proofs).
///
/// The scheme was originally described in <https://eprint.iacr.org/2019/1021>.
/// Only the multi-open reduction lives here; the final inner product argument
/// check is delegated to [`Ipa`].
#[derive(Clone, Debug)]
pub struct Bgh19;

27impl<C, L> PolynomialCommitmentScheme<C, L> for IpaAs<C, Bgh19>
28where
29    C: CurveAffine,
30    L: Loader<C>,
31{
32    type VerifyingKey = IpaSuccinctVerifyingKey<C>;
33    type Proof = Bgh19Proof<C, L>;
34    type Output = IpaAccumulator<C, L>;
35
36    fn read_proof<T>(
37        svk: &Self::VerifyingKey,
38        queries: &[Query<Rotation>],
39        transcript: &mut T,
40    ) -> Result<Self::Proof, Error>
41    where
42        T: TranscriptRead<C, L>,
43    {
44        Bgh19Proof::read(svk, queries, transcript)
45    }
46
47    fn verify(
48        svk: &Self::VerifyingKey,
49        commitments: &[Msm<C, L>],
50        x: &L::LoadedScalar,
51        queries: &[Query<Rotation, L::LoadedScalar>],
52        proof: &Self::Proof,
53    ) -> Result<Self::Output, Error> {
54        let loader = x.loader();
55        let g = loader.ec_point_load_const(&svk.g);
56
57        // Multiopen
58        let sets = query_sets(queries);
59        let p = {
60            let coeffs = query_set_coeffs(&sets, x, &proof.x_3);
61
62            let powers_of_x_1 =
63                proof.x_1.powers(sets.iter().map(|set| set.polys.len()).max().unwrap());
64            let f_eval = {
65                let powers_of_x_2 = proof.x_2.powers(sets.len());
66                let f_evals = sets
67                    .iter()
68                    .zip(coeffs.iter())
69                    .zip(proof.q_evals.iter())
70                    .map(|((set, coeff), q_eval)| set.f_eval(coeff, q_eval, &powers_of_x_1))
71                    .collect_vec();
72                x.loader()
73                    .sum_products(&powers_of_x_2.iter().zip(f_evals.iter().rev()).collect_vec())
74            };
75            let msms = sets
76                .iter()
77                .zip(proof.q_evals.iter())
78                .map(|(set, q_eval)| set.msm(commitments, q_eval, &powers_of_x_1));
79
80            let (mut msm, constant) = iter::once(Msm::base(&proof.f) - Msm::constant(f_eval))
81                .chain(msms)
82                .zip(proof.x_4.powers(sets.len() + 1).into_iter().rev())
83                .map(|(msm, power_of_x_4)| msm * &power_of_x_4)
84                .sum::<Msm<_, _>>()
85                .split();
86            if let Some(constant) = constant {
87                msm += Msm::base(&g) * &constant;
88            }
89            msm
90        };
91
92        // IPA
93        Ipa::succinct_verify(svk, &p, &proof.x_3, &loader.load_zero(), &proof.ipa)
94    }
95}
96
97/// Structured proof of [`Bgh19`].
98#[derive(Clone, Debug)]
99pub struct Bgh19Proof<C, L>
100where
101    C: CurveAffine,
102    L: Loader<C>,
103{
104    // Multiopen
105    x_1: L::LoadedScalar,
106    x_2: L::LoadedScalar,
107    f: L::LoadedEcPoint,
108    x_3: L::LoadedScalar,
109    q_evals: Vec<L::LoadedScalar>,
110    x_4: L::LoadedScalar,
111    // IPA
112    ipa: IpaProof<C, L>,
113}
114
115impl<C, L> Bgh19Proof<C, L>
116where
117    C: CurveAffine,
118    L: Loader<C>,
119{
120    fn read<T: TranscriptRead<C, L>>(
121        svk: &IpaSuccinctVerifyingKey<C>,
122        queries: &[Query<Rotation>],
123        transcript: &mut T,
124    ) -> Result<Self, Error> {
125        // Multiopen
126        let x_1 = transcript.squeeze_challenge();
127        let x_2 = transcript.squeeze_challenge();
128        let f = transcript.read_ec_point()?;
129        let x_3 = transcript.squeeze_challenge();
130        let q_evals = transcript.read_n_scalars(query_sets(queries).len())?;
131        let x_4 = transcript.squeeze_challenge();
132        // IPA
133        let s = transcript.read_ec_point()?;
134        let xi = transcript.squeeze_challenge();
135        let z = transcript.squeeze_challenge();
136        let rounds = iter::repeat_with(|| {
137            Ok(Round::new(
138                transcript.read_ec_point()?,
139                transcript.read_ec_point()?,
140                transcript.squeeze_challenge(),
141            ))
142        })
143        .take(svk.domain.k)
144        .collect::<Result<Vec<_>, _>>()?;
145        let c = transcript.read_scalar()?;
146        let blind = transcript.read_scalar()?;
147        let g = transcript.read_ec_point()?;
148        Ok(Bgh19Proof {
149            x_1,
150            x_2,
151            f,
152            x_3,
153            q_evals,
154            x_4,
155            ipa: IpaProof::new(Some((s, xi)), Some(blind), z, rounds, g, c),
156        })
157    }
158}
159
160fn query_sets<S, T>(queries: &[Query<S, T>]) -> Vec<QuerySet<S, T>>
161where
162    S: PartialEq + Ord + Copy,
163    T: Clone,
164{
165    let poly_shifts =
166        queries.iter().fold(Vec::<(usize, Vec<_>, Vec<&T>)>::new(), |mut poly_shifts, query| {
167            if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) {
168                let (_, shifts, evals) = &mut poly_shifts[pos];
169                if !shifts.iter().map(|(shift, _)| shift).contains(&query.shift) {
170                    shifts.push((query.shift, query.loaded_shift.clone()));
171                    evals.push(&query.eval);
172                }
173            } else {
174                poly_shifts.push((
175                    query.poly,
176                    vec![(query.shift, query.loaded_shift.clone())],
177                    vec![&query.eval],
178                ));
179            }
180            poly_shifts
181        });
182
183    poly_shifts.into_iter().fold(Vec::<QuerySet<_, T>>::new(), |mut sets, (poly, shifts, evals)| {
184        if let Some(pos) = sets.iter().position(|set| {
185            BTreeSet::from_iter(set.shifts.iter().map(|(shift, _)| shift))
186                == BTreeSet::from_iter(shifts.iter().map(|(shift, _)| shift))
187        }) {
188            let set = &mut sets[pos];
189            if !set.polys.contains(&poly) {
190                set.polys.push(poly);
191                set.evals.push(
192                    set.shifts
193                        .iter()
194                        .map(|lhs| {
195                            let idx = shifts.iter().position(|rhs| lhs.0 == rhs.0).unwrap();
196                            evals[idx]
197                        })
198                        .collect(),
199                );
200            }
201        } else {
202            let set = QuerySet { shifts, polys: vec![poly], evals: vec![evals] };
203            sets.push(set);
204        }
205        sets
206    })
207}
208
209fn query_set_coeffs<F, T>(
210    sets: &[QuerySet<Rotation, T>],
211    x: &T,
212    x_3: &T,
213) -> Vec<QuerySetCoeff<F, T>>
214where
215    F: PrimeField + Ord,
216    T: LoadedScalar<F>,
217{
218    let superset = BTreeMap::from_iter(sets.iter().flat_map(|set| set.shifts.clone()));
219
220    let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap();
221    let powers_of_x = x.powers(size);
222    let x_3_minus_x_shift_i = BTreeMap::from_iter(
223        superset
224            .into_iter()
225            .map(|(shift, loaded_shift)| (shift, x_3.clone() - x.clone() * loaded_shift)),
226    );
227
228    let mut coeffs = sets
229        .iter()
230        .map(|set| QuerySetCoeff::new(&set.shifts, &powers_of_x, x_3, &x_3_minus_x_shift_i))
231        .collect_vec();
232
233    T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms));
234    T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms));
235    coeffs.iter_mut().for_each(QuerySetCoeff::evaluate);
236
237    coeffs
238}
239
/// Polynomials that are all queried at the same set of shifts.
#[derive(Clone, Debug)]
struct QuerySet<'a, S, T> {
    /// Shared shifts as `(shift, loaded_shift)` pairs, in the order seen for the
    /// set's first polynomial.
    shifts: Vec<(S, T)>,
    /// Indices of the set's polynomials into the commitment list.
    polys: Vec<usize>,
    /// `evals[i][j]` is the evaluation of `polys[i]` at `shifts[j]`.
    evals: Vec<Vec<&'a T>>,
}

247impl<'a, S, T> QuerySet<'a, S, T> {
248    fn msm<C: CurveAffine, L: Loader<C, LoadedScalar = T>>(
249        &self,
250        commitments: &[Msm<'a, C, L>],
251        q_eval: &T,
252        powers_of_x_1: &[T],
253    ) -> Msm<C, L>
254    where
255        T: LoadedScalar<C::Scalar>,
256    {
257        self.polys
258            .iter()
259            .rev()
260            .zip(powers_of_x_1)
261            .map(|(poly, power_of_x_1)| commitments[*poly].clone() * power_of_x_1)
262            .sum::<Msm<_, _>>()
263            - Msm::constant(q_eval.clone())
264    }
265
266    fn f_eval<F: PrimeField>(
267        &self,
268        coeff: &QuerySetCoeff<F, T>,
269        q_eval: &T,
270        powers_of_x_1: &[T],
271    ) -> T
272    where
273        T: LoadedScalar<F>,
274    {
275        let loader = q_eval.loader();
276        let r_eval = {
277            let r_evals = self
278                .evals
279                .iter()
280                .map(|evals| {
281                    loader.sum_products(
282                        &coeff
283                            .eval_coeffs
284                            .iter()
285                            .zip(evals.iter())
286                            .map(|(coeff, eval)| (coeff.evaluated(), *eval))
287                            .collect_vec(),
288                    ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated()
289                })
290                .collect_vec();
291            loader.sum_products(&r_evals.iter().rev().zip(powers_of_x_1).collect_vec())
292        };
293
294        (q_eval.clone() - r_eval) * coeff.f_eval_coeff.evaluated()
295    }
296}
297
298#[derive(Clone, Debug)]
299struct QuerySetCoeff<F, T> {
300    eval_coeffs: Vec<Fraction<T>>,
301    r_eval_coeff: Option<Fraction<T>>,
302    f_eval_coeff: Fraction<T>,
303    _marker: PhantomData<F>,
304}
305
306impl<F, T> QuerySetCoeff<F, T>
307where
308    F: PrimeField + Ord,
309    T: LoadedScalar<F>,
310{
311    fn new(
312        shifts: &[(Rotation, T)],
313        powers_of_x: &[T],
314        x_3: &T,
315        x_3_minus_x_shift_i: &BTreeMap<Rotation, T>,
316    ) -> Self {
317        let loader = x_3.loader();
318        let normalized_ell_primes = shifts
319            .iter()
320            .enumerate()
321            .map(|(j, shift_j)| {
322                shifts
323                    .iter()
324                    .enumerate()
325                    .filter(|&(i, _)| i != j)
326                    .map(|(_, shift_i)| (shift_j.1.clone() - &shift_i.1))
327                    .reduce(|acc, value| acc * value)
328                    .unwrap_or_else(|| loader.load_const(&F::ONE))
329            })
330            .collect_vec();
331
332        let x = &powers_of_x[1].clone();
333        let x_pow_k_minus_one = &powers_of_x[shifts.len() - 1];
334
335        let barycentric_weights = shifts
336            .iter()
337            .zip(normalized_ell_primes.iter())
338            .map(|((_, loaded_shift), normalized_ell_prime)| {
339                let tmp = normalized_ell_prime.clone() * x_pow_k_minus_one;
340                loader.sum_products(&[(&tmp, x_3), (&-(tmp.clone() * loaded_shift), x)])
341            })
342            .map(Fraction::one_over)
343            .collect_vec();
344
345        let f_eval_coeff = Fraction::one_over(loader.product(
346            &shifts.iter().map(|(shift, _)| x_3_minus_x_shift_i.get(shift).unwrap()).collect_vec(),
347        ));
348
349        Self {
350            eval_coeffs: barycentric_weights,
351            r_eval_coeff: None,
352            f_eval_coeff,
353            _marker: PhantomData,
354        }
355    }
356
357    fn denoms(&mut self) -> impl IntoIterator<Item = &'_ mut T> {
358        if self.eval_coeffs.first().unwrap().denom().is_some() {
359            return self
360                .eval_coeffs
361                .iter_mut()
362                .chain(Some(&mut self.f_eval_coeff))
363                .filter_map(Fraction::denom_mut)
364                .collect_vec();
365        }
366
367        if self.r_eval_coeff.is_none() {
368            self.eval_coeffs
369                .iter_mut()
370                .chain(Some(&mut self.f_eval_coeff))
371                .for_each(Fraction::evaluate);
372
373            let loader = self.f_eval_coeff.evaluated().loader();
374            let barycentric_weights_sum =
375                loader.sum(&self.eval_coeffs.iter().map(Fraction::evaluated).collect_vec());
376            self.r_eval_coeff = Some(Fraction::one_over(barycentric_weights_sum));
377
378            return vec![self.r_eval_coeff.as_mut().unwrap().denom_mut().unwrap()];
379        }
380
381        unreachable!()
382    }
383
384    fn evaluate(&mut self) {
385        self.r_eval_coeff.as_mut().unwrap().evaluate();
386    }
387}