// snark_verifier/pcs/kzg/multiopen/bdfg21.rs

use crate::{
    cost::{Cost, CostEstimation},
    loader::{LoadedScalar, Loader, ScalarLoader},
    pcs::{
        kzg::{KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey},
        PolynomialCommitmentScheme, Query,
    },
    util::{
        arithmetic::{CurveAffine, Fraction, MultiMillerLoop, PrimeField, Rotation},
        msm::Msm,
        transcript::TranscriptRead,
        Itertools,
    },
    Error,
};
use std::{
    collections::{BTreeMap, BTreeSet},
    marker::PhantomData,
};

/// Multi-open KZG verifier implementing the SHPLONK scheme used by
/// [`halo2_proofs`](crate::halo2_proofs).
///
/// Variable names follow the notation of <https://eprint.iacr.org/2020/081>.
#[derive(Clone, Debug)]
pub struct Bdfg21;

27impl<M, L> PolynomialCommitmentScheme<M::G1Affine, L> for KzgAs<M, Bdfg21>
28where
29    M: MultiMillerLoop,
30    M::G1Affine: CurveAffine<ScalarExt = M::Fr, CurveExt = M::G1>,
31    M::Fr: Ord,
32    L: Loader<M::G1Affine>,
33{
34    type VerifyingKey = KzgSuccinctVerifyingKey<M::G1Affine>;
35    type Proof = Bdfg21Proof<M::G1Affine, L>;
36    type Output = KzgAccumulator<M::G1Affine, L>;
37
38    fn read_proof<T>(
39        _: &KzgSuccinctVerifyingKey<M::G1Affine>,
40        _: &[Query<Rotation>],
41        transcript: &mut T,
42    ) -> Result<Bdfg21Proof<M::G1Affine, L>, Error>
43    where
44        T: TranscriptRead<M::G1Affine, L>,
45    {
46        Bdfg21Proof::read(transcript)
47    }
48
49    fn verify(
50        svk: &KzgSuccinctVerifyingKey<M::G1Affine>,
51        commitments: &[Msm<M::G1Affine, L>],
52        z: &L::LoadedScalar,
53        queries: &[Query<Rotation, L::LoadedScalar>],
54        proof: &Bdfg21Proof<M::G1Affine, L>,
55    ) -> Result<Self::Output, Error> {
56        let sets = query_sets(queries);
57        let f = {
58            let coeffs = query_set_coeffs(&sets, z, &proof.z_prime);
59
60            let powers_of_mu =
61                proof.mu.powers(sets.iter().map(|set| set.polys.len()).max().unwrap());
62            let msms = sets
63                .iter()
64                .zip(coeffs.iter())
65                .map(|(set, coeff)| set.msm(coeff, commitments, &powers_of_mu));
66
67            msms.zip(proof.gamma.powers(sets.len()))
68                .map(|(msm, power_of_gamma)| msm * &power_of_gamma)
69                .sum::<Msm<_, _>>()
70                - Msm::base(&proof.w) * &coeffs[0].z_s
71        };
72
73        let rhs = Msm::base(&proof.w_prime);
74        let lhs = f + rhs.clone() * &proof.z_prime;
75
76        Ok(KzgAccumulator::new(lhs.evaluate(Some(svk.g)), rhs.evaluate(Some(svk.g))))
77    }
78}
79
80/// Structured proof of [`Bdfg21`].
81#[derive(Clone, Debug)]
82pub struct Bdfg21Proof<C, L>
83where
84    C: CurveAffine,
85    L: Loader<C>,
86{
87    mu: L::LoadedScalar,
88    gamma: L::LoadedScalar,
89    w: L::LoadedEcPoint,
90    z_prime: L::LoadedScalar,
91    w_prime: L::LoadedEcPoint,
92}
93
94impl<C, L> Bdfg21Proof<C, L>
95where
96    C: CurveAffine,
97    L: Loader<C>,
98{
99    fn read<T: TranscriptRead<C, L>>(transcript: &mut T) -> Result<Self, Error> {
100        let mu = transcript.squeeze_challenge();
101        let gamma = transcript.squeeze_challenge();
102        let w = transcript.read_ec_point()?;
103        let z_prime = transcript.squeeze_challenge();
104        let w_prime = transcript.read_ec_point()?;
105        Ok(Bdfg21Proof { mu, gamma, w, z_prime, w_prime })
106    }
107}
108
109fn query_sets<S: PartialEq + Ord + Copy, T: Clone>(queries: &[Query<S, T>]) -> Vec<QuerySet<S, T>> {
110    let poly_shifts =
111        queries.iter().fold(Vec::<(usize, Vec<_>, Vec<&T>)>::new(), |mut poly_shifts, query| {
112            if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) {
113                let (_, shifts, evals) = &mut poly_shifts[pos];
114                if !shifts.iter().map(|(shift, _)| shift).contains(&query.shift) {
115                    shifts.push((query.shift, query.loaded_shift.clone()));
116                    evals.push(&query.eval);
117                }
118            } else {
119                poly_shifts.push((
120                    query.poly,
121                    vec![(query.shift, query.loaded_shift.clone())],
122                    vec![&query.eval],
123                ));
124            }
125            poly_shifts
126        });
127
128    poly_shifts.into_iter().fold(Vec::<QuerySet<_, T>>::new(), |mut sets, (poly, shifts, evals)| {
129        if let Some(pos) = sets.iter().position(|set| {
130            BTreeSet::from_iter(set.shifts.iter().map(|(shift, _)| shift))
131                == BTreeSet::from_iter(shifts.iter().map(|(shift, _)| shift))
132        }) {
133            let set = &mut sets[pos];
134            if !set.polys.contains(&poly) {
135                set.polys.push(poly);
136                set.evals.push(
137                    set.shifts
138                        .iter()
139                        .map(|lhs| {
140                            let idx = shifts.iter().position(|rhs| lhs.0 == rhs.0).unwrap();
141                            evals[idx]
142                        })
143                        .collect(),
144                );
145            }
146        } else {
147            let set = QuerySet { shifts, polys: vec![poly], evals: vec![evals] };
148            sets.push(set);
149        }
150        sets
151    })
152}
153
154fn query_set_coeffs<F: PrimeField + Ord, T: LoadedScalar<F>>(
155    sets: &[QuerySet<Rotation, T>],
156    z: &T,
157    z_prime: &T,
158) -> Vec<QuerySetCoeff<F, T>> {
159    // map of shift => loaded_shift, removing duplicate `shift` values
160    // shift is the rotation, not omega^rotation, to ensure BTreeMap does not depend on omega (otherwise ordering can change)
161    let superset = BTreeMap::from_iter(sets.iter().flat_map(|set| set.shifts.clone()));
162
163    let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap();
164    let powers_of_z = z.powers(size);
165    let z_prime_minus_z_shift_i = BTreeMap::from_iter(
166        superset
167            .into_iter()
168            .map(|(shift, loaded_shift)| (shift, z_prime.clone() - z.clone() * loaded_shift)),
169    );
170
171    let mut z_s_1 = None;
172    let mut coeffs = sets
173        .iter()
174        .map(|set| {
175            let coeff = QuerySetCoeff::new(
176                &set.shifts,
177                &powers_of_z,
178                z_prime,
179                &z_prime_minus_z_shift_i,
180                &z_s_1,
181            );
182            if z_s_1.is_none() {
183                z_s_1 = Some(coeff.z_s.clone());
184            };
185            coeff
186        })
187        .collect_vec();
188
189    T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms));
190    T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms));
191    coeffs.iter_mut().for_each(QuerySetCoeff::evaluate);
192
193    coeffs
194}
195
/// Polynomials that are all opened at the same set of shifts.
#[derive(Clone, Debug)]
struct QuerySet<'a, S, T> {
    /// `(shift, loaded_shift)` pairs shared by every polynomial of the set.
    shifts: Vec<(S, T)>,
    /// Member polynomials, as indices into the commitments slice.
    polys: Vec<usize>,
    /// `evals[i][j]` is the evaluation of `polys[i]` at `shifts[j]`.
    evals: Vec<Vec<&'a T>>,
}

203impl<'a, S, T> QuerySet<'a, S, T> {
204    fn msm<C: CurveAffine, L: Loader<C, LoadedScalar = T>>(
205        &self,
206        coeff: &QuerySetCoeff<C::Scalar, T>,
207        commitments: &[Msm<'a, C, L>],
208        powers_of_mu: &[T],
209    ) -> Msm<C, L>
210    where
211        T: LoadedScalar<C::Scalar>,
212    {
213        self.polys
214            .iter()
215            .zip(self.evals.iter())
216            .zip(powers_of_mu.iter())
217            .map(|((poly, evals), power_of_mu)| {
218                let loader = power_of_mu.loader();
219                let commitment = coeff
220                    .commitment_coeff
221                    .as_ref()
222                    .map(|commitment_coeff| {
223                        commitments[*poly].clone() * commitment_coeff.evaluated()
224                    })
225                    .unwrap_or_else(|| commitments[*poly].clone());
226                let r_eval = loader.sum_products(
227                    &coeff
228                        .eval_coeffs
229                        .iter()
230                        .zip(evals.iter().cloned())
231                        .map(|(coeff, eval)| (coeff.evaluated(), eval))
232                        .collect_vec(),
233                ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated();
234                (commitment - Msm::constant(r_eval)) * power_of_mu
235            })
236            .sum()
237    }
238}
239
240#[derive(Clone, Debug)]
241struct QuerySetCoeff<F, T> {
242    z_s: T,
243    eval_coeffs: Vec<Fraction<T>>,
244    commitment_coeff: Option<Fraction<T>>,
245    r_eval_coeff: Option<Fraction<T>>,
246    _marker: PhantomData<F>,
247}
248
249impl<F, T> QuerySetCoeff<F, T>
250where
251    F: PrimeField + Ord,
252    T: LoadedScalar<F>,
253{
254    fn new(
255        shifts: &[(Rotation, T)],
256        powers_of_z: &[T],
257        z_prime: &T,
258        z_prime_minus_z_shift_i: &BTreeMap<Rotation, T>,
259        z_s_1: &Option<T>,
260    ) -> Self {
261        let loader = z_prime.loader();
262
263        let normalized_ell_primes = shifts
264            .iter()
265            .enumerate()
266            .map(|(j, shift_j)| {
267                shifts
268                    .iter()
269                    .enumerate()
270                    .filter(|&(i, _)| i != j)
271                    .map(|(_, shift_i)| (shift_j.1.clone() - &shift_i.1))
272                    .reduce(|acc, value| acc * value)
273                    .unwrap_or_else(|| loader.load_const(&F::ONE))
274            })
275            .collect_vec();
276
277        let z = &powers_of_z[1];
278        let z_pow_k_minus_one = &powers_of_z[shifts.len() - 1];
279
280        let barycentric_weights = shifts
281            .iter()
282            .zip(normalized_ell_primes.iter())
283            .map(|((_, loaded_shift), normalized_ell_prime)| {
284                let tmp = normalized_ell_prime.clone() * z_pow_k_minus_one;
285                loader.sum_products(&[(&tmp, z_prime), (&-(tmp.clone() * loaded_shift), z)])
286            })
287            .map(Fraction::one_over)
288            .collect_vec();
289
290        let z_s = loader.product(
291            &shifts
292                .iter()
293                .map(|(shift, _)| z_prime_minus_z_shift_i.get(shift).unwrap())
294                .collect_vec(),
295        );
296        let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone()));
297
298        Self {
299            z_s,
300            eval_coeffs: barycentric_weights,
301            commitment_coeff: z_s_1_over_z_s,
302            r_eval_coeff: None,
303            _marker: PhantomData,
304        }
305    }
306
307    fn denoms(&mut self) -> impl IntoIterator<Item = &'_ mut T> {
308        if self.eval_coeffs.first().unwrap().denom().is_some() {
309            return self
310                .eval_coeffs
311                .iter_mut()
312                .chain(self.commitment_coeff.as_mut())
313                .filter_map(Fraction::denom_mut)
314                .collect_vec();
315        }
316
317        if self.r_eval_coeff.is_none() {
318            let loader = self.z_s.loader();
319            self.eval_coeffs
320                .iter_mut()
321                .chain(self.commitment_coeff.as_mut())
322                .for_each(Fraction::evaluate);
323            let barycentric_weights_sum =
324                loader.sum(&self.eval_coeffs.iter().map(Fraction::evaluated).collect_vec());
325            self.r_eval_coeff = Some(match self.commitment_coeff.as_ref() {
326                Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum),
327                None => Fraction::one_over(barycentric_weights_sum),
328            });
329            return vec![self.r_eval_coeff.as_mut().unwrap().denom_mut().unwrap()];
330        }
331
332        unreachable!()
333    }
334
335    fn evaluate(&mut self) {
336        self.r_eval_coeff.as_mut().unwrap().evaluate();
337    }
338}
339
340impl<M> CostEstimation<M::G1Affine> for KzgAs<M, Bdfg21>
341where
342    M: MultiMillerLoop,
343{
344    type Input = Vec<Query<Rotation>>;
345
346    fn estimate_cost(_: &Vec<Query<Rotation>>) -> Cost {
347        Cost { num_commitment: 2, num_msm: 2, ..Default::default() }
348    }
349}