halo2_axiom/plonk/permutation/keygen.rs

1use ff::{Field, PrimeField};
2use group::Curve;
3
4use super::{Argument, ProvingKey, VerifyingKey};
5use crate::{
6    arithmetic::{parallelize, CurveAffine},
7    plonk::{Any, Column, Error},
8    poly::{
9        commitment::{Blind, Params},
10        EvaluationDomain,
11    },
12};
13
14#[cfg(feature = "multicore")]
15use crate::multicore::{IndexedParallelIterator, ParallelIterator};
16
17#[cfg(feature = "thread-safe-region")]
18use std::collections::{BTreeSet, HashMap};
19
#[cfg(not(feature = "thread-safe-region"))]
/// Struct that accumulates all the necessary data in order to construct the permutation argument.
///
/// Copy constraints are tracked as a permutation over cells `(column index, row)`:
/// cells constrained to be equal end up in the same cycle of that permutation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Assembly {
    /// Columns that participate in the copy permutation argument; a cell's column
    /// index refers to a position in this list.
    columns: Vec<Column<Any>>,
    /// `mapping[i][j]` is the cell that cell `(i, j)` is sent to by the permutation,
    /// i.e. its successor within its cycle.
    mapping: Vec<Vec<(usize, usize)>>,
    /// `aux[i][j]` is the distinguished element of the cycle containing cell `(i, j)`;
    /// `copy` compares these entries to decide whether two cells already share a cycle.
    aux: Vec<Vec<(usize, usize)>>,
    /// `sizes[i][j]` is the size of the cycle whose distinguished element is `(i, j)`.
    /// Only entries at distinguished elements are read or updated by `copy`.
    sizes: Vec<Vec<usize>>,
}
33
34#[cfg(not(feature = "thread-safe-region"))]
35impl Assembly {
36    pub(crate) fn new(n: usize, p: &Argument) -> Self {
37        // Initialize the copy vector to keep track of copy constraints in all
38        // the permutation arguments.
39        let mut columns = vec![];
40        for i in 0..p.columns.len() {
41            // Computes [(i, 0), (i, 1), ..., (i, n - 1)]
42            columns.push((0..n).map(|j| (i, j)).collect());
43        }
44
45        // Before any equality constraints are applied, every cell in the permutation is
46        // in a 1-cycle; therefore mapping and aux are identical, because every cell is
47        // its own distinguished element.
48        Assembly {
49            columns: p.columns.clone(),
50            mapping: columns.clone(),
51            aux: columns,
52            sizes: vec![vec![1usize; n]; p.columns.len()],
53        }
54    }
55
56    pub(crate) fn copy(
57        &mut self,
58        left_column: Column<Any>,
59        left_row: usize,
60        right_column: Column<Any>,
61        right_row: usize,
62    ) -> Result<(), Error> {
63        let left_column = self
64            .columns
65            .iter()
66            .position(|c| c == &left_column)
67            .ok_or(Error::ColumnNotInPermutation(left_column))?;
68        let right_column = self
69            .columns
70            .iter()
71            .position(|c| c == &right_column)
72            .ok_or(Error::ColumnNotInPermutation(right_column))?;
73
74        // Check bounds
75        if left_row >= self.mapping[left_column].len()
76            || right_row >= self.mapping[right_column].len()
77        {
78            return Err(Error::BoundsFailure);
79        }
80
81        // See book/src/design/permutation.md for a description of this algorithm.
82
83        let mut left_cycle = self.aux[left_column][left_row];
84        let mut right_cycle = self.aux[right_column][right_row];
85
86        // If left and right are in the same cycle, do nothing.
87        if left_cycle == right_cycle {
88            return Ok(());
89        }
90
91        if self.sizes[left_cycle.0][left_cycle.1] < self.sizes[right_cycle.0][right_cycle.1] {
92            std::mem::swap(&mut left_cycle, &mut right_cycle);
93        }
94
95        // Merge the right cycle into the left one.
96        self.sizes[left_cycle.0][left_cycle.1] += self.sizes[right_cycle.0][right_cycle.1];
97        let mut i = right_cycle;
98        loop {
99            self.aux[i.0][i.1] = left_cycle;
100            i = self.mapping[i.0][i.1];
101            if i == right_cycle {
102                break;
103            }
104        }
105
106        let tmp = self.mapping[left_column][left_row];
107        self.mapping[left_column][left_row] = self.mapping[right_column][right_row];
108        self.mapping[right_column][right_row] = tmp;
109
110        Ok(())
111    }
112
113    pub(crate) fn build_vk<'params, C: CurveAffine, P: Params<'params, C>>(
114        self,
115        params: &P,
116        domain: &EvaluationDomain<C::Scalar>,
117        p: &Argument,
118    ) -> VerifyingKey<C> {
119        build_vk(params, domain, p, |i, j| self.mapping[i][j])
120    }
121
122    pub(crate) fn build_pk<'params, C: CurveAffine, P: Params<'params, C>>(
123        self,
124        params: &P,
125        domain: &EvaluationDomain<C::Scalar>,
126        p: &Argument,
127    ) -> ProvingKey<C> {
128        build_pk(params, domain, p, |i, j| self.mapping[i][j])
129    }
130
131    /// Returns columns that participate in the permutation argument.
132    pub fn columns(&self) -> &[Column<Any>] {
133        &self.columns
134    }
135
136    #[cfg(feature = "multicore")]
137    /// Returns mappings of the copies.
138    pub fn mapping(
139        &self,
140    ) -> impl Iterator<Item = impl IndexedParallelIterator<Item = (usize, usize)> + '_> {
141        use crate::multicore::IntoParallelRefIterator;
142
143        self.mapping.iter().map(|c| c.par_iter().copied())
144    }
145
146    #[cfg(not(feature = "multicore"))]
147    /// Returns mappings of the copies.
148    pub fn mapping(&self) -> impl Iterator<Item = impl Iterator<Item = (usize, usize)> + '_> {
149        self.mapping.iter().map(|c| c.iter().copied())
150    }
151}
152
#[cfg(feature = "thread-safe-region")]
/// Struct that accumulates all the necessary data in order to construct the permutation argument.
///
/// Cells `(column index, row)` are only recorded once they take part in a copy
/// constraint; any cell absent from `aux` is its own singleton cycle.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Assembly {
    /// Columns that participate in the copy permutation argument; a cell's column
    /// index refers to a position in this list.
    columns: Vec<Column<Any>>,
    /// Cells of each merged cycle, in insertion order, indexed by the values stored
    /// in `aux`. An entry may be emptied once it has been merged into another cycle
    /// or after `build_ordered_mapping` has run.
    cycles: Vec<Vec<(usize, usize)>>,
    /// Sorted copy of `cycles` (same indices), built by `build_ordered_mapping`.
    /// The permutation sends each cell to the next larger cell in its set, wrapping
    /// around to the smallest one.
    ordered_cycles: Vec<BTreeSet<(usize, usize)>>,
    /// Maps a recorded cell `(column index, row)` to the index of its cycle in
    /// `cycles` / `ordered_cycles`.
    aux: HashMap<(usize, usize), usize>,
    /// Number of rows in each column (the `n` passed to `new`).
    col_len: usize,
    /// Number of columns in the argument.
    num_cols: usize,
}
170
#[cfg(feature = "thread-safe-region")]
impl Assembly {
    /// Creates an assembly for the columns of `p`, each `n` rows long, with no copy
    /// constraints recorded yet (every cell is implicitly a singleton cycle).
    pub(crate) fn new(n: usize, p: &Argument) -> Self {
        Assembly {
            columns: p.columns.clone(),
            cycles: Vec::with_capacity(n),
            ordered_cycles: Vec::with_capacity(n),
            aux: HashMap::new(),
            col_len: n,
            num_cols: p.columns.len(),
        }
    }

    /// Records the copy constraint `left_column[left_row] == right_column[right_row]`
    /// by merging the cycles that contain the two cells.
    ///
    /// Cycles are merged by size: the smaller cycle is moved into the larger one,
    /// and a not-yet-recorded cell is simply appended to its partner's cycle. A cell
    /// is therefore only moved when the cycle it lands in is at least twice as big,
    /// which avoids quadratic behaviour when many cells are copied to one large cycle.
    /// The resulting permutation only depends on which cells end up together, since
    /// `build_ordered_mapping` sorts every cycle.
    ///
    /// Must not be called after `build_ordered_mapping` has run.
    ///
    /// # Errors
    /// - `Error::ColumnNotInPermutation` if a column is not part of this argument
    ///   (the left column is checked first).
    /// - `Error::BoundsFailure` if either row is `>= col_len`.
    pub(crate) fn copy(
        &mut self,
        left_column: Column<Any>,
        left_row: usize,
        right_column: Column<Any>,
        right_row: usize,
    ) -> Result<(), Error> {
        let left_column = self
            .columns
            .iter()
            .position(|c| c == &left_column)
            .ok_or(Error::ColumnNotInPermutation(left_column))?;
        let right_column = self
            .columns
            .iter()
            .position(|c| c == &right_column)
            .ok_or(Error::ColumnNotInPermutation(right_column))?;

        // Check bounds
        if left_row >= self.col_len || right_row >= self.col_len {
            return Err(Error::BoundsFailure);
        }

        let left = (left_column, left_row);
        let right = (right_column, right_row);

        // A cell is trivially equal to itself. Leaving it unrecorded keeps it a
        // singleton, which `mapping_at_idx` maps to itself.
        if left == right {
            return Ok(());
        }

        match (self.aux.get(&left).copied(), self.aux.get(&right).copied()) {
            // Both cells already belong to the same cycle: nothing to merge.
            (Some(l), Some(r)) if l == r => {}
            // Two distinct cycles: move the smaller one into the larger one.
            (Some(l), Some(r)) => {
                let (keep, absorb) = if self.cycles[l].len() >= self.cycles[r].len() {
                    (l, r)
                } else {
                    (r, l)
                };
                let moved = std::mem::take(&mut self.cycles[absorb]);
                self.aux.extend(moved.iter().map(|&cell| (cell, keep)));
                self.cycles[keep].extend(moved);
            }
            // Exactly one cell is recorded: attach the other one to its cycle.
            (Some(idx), None) => {
                self.cycles[idx].push(right);
                self.aux.insert(right, idx);
            }
            (None, Some(idx)) => {
                self.cycles[idx].push(left);
                self.aux.insert(left, idx);
            }
            // Two singletons: start a new cycle containing both.
            (None, None) => {
                let idx = self.cycles.len();
                self.cycles.push(vec![left, right]);
                self.aux.insert(left, idx);
                self.aux.insert(right, idx);
            }
        }

        Ok(())
    }

    /// Builds the ordered mapping of the cycles: converts each entry of `cycles`
    /// into a sorted set in `ordered_cycles` (same index) and frees the unordered
    /// copy.
    ///
    /// This only does work the first time it is called (while `ordered_cycles` is
    /// still empty). Because `cycles` is emptied here, `copy` must not be called
    /// afterwards.
    pub fn build_ordered_mapping(&mut self) {
        use crate::multicore::IntoParallelRefMutIterator;

        // will only get called once
        if self.ordered_cycles.is_empty() && !self.cycles.is_empty() {
            self.ordered_cycles = self
                .cycles
                .par_iter_mut()
                // `mem::take` frees the unordered cycle without cloning it first.
                .map(|cycle| std::mem::take(cycle).into_iter().collect::<BTreeSet<_>>())
                .collect();
        }
    }

    /// Returns the cell that cell `(col, row)` is sent to by the permutation: the
    /// next larger cell in its ordered cycle, wrapping around to the smallest one.
    /// Unrecorded cells map to themselves.
    ///
    /// # Panics
    /// If copies were recorded but `build_ordered_mapping` has not been called.
    fn mapping_at_idx(&self, col: usize, row: usize) -> (usize, usize) {
        assert!(
            !self.ordered_cycles.is_empty() || self.cycles.is_empty(),
            "cycles have not been ordered"
        );

        match self.aux.get(&(col, row)) {
            Some(&cycle_idx) => {
                let cycle = &self.ordered_cycles[cycle_idx];
                cycle
                    .range((
                        std::ops::Bound::Excluded((col, row)),
                        std::ops::Bound::Unbounded,
                    ))
                    .next()
                    // wrap back around to the first element of the cycle
                    .or_else(|| cycle.iter().next())
                    .copied()
                    .expect("a cell recorded in `aux` always belongs to a non-empty cycle")
            }
            // is a singleton
            None => (col, row),
        }
    }

    /// Orders the cycles (if not done yet) and builds the permutation verifying key.
    pub(crate) fn build_vk<'params, C: CurveAffine, P: Params<'params, C>>(
        &mut self,
        params: &P,
        domain: &EvaluationDomain<C::Scalar>,
        p: &Argument,
    ) -> VerifyingKey<C> {
        self.build_ordered_mapping();
        build_vk(params, domain, p, |i, j| self.mapping_at_idx(i, j))
    }

    /// Orders the cycles (if not done yet) and builds the permutation proving key.
    pub(crate) fn build_pk<'params, C: CurveAffine, P: Params<'params, C>>(
        &mut self,
        params: &P,
        domain: &EvaluationDomain<C::Scalar>,
        p: &Argument,
    ) -> ProvingKey<C> {
        self.build_ordered_mapping();
        build_pk(params, domain, p, |i, j| self.mapping_at_idx(i, j))
    }

    /// Returns columns that participate in the permutation argument.
    pub fn columns(&self) -> &[Column<Any>] {
        &self.columns
    }

    #[cfg(feature = "multicore")]
    /// Returns mappings of the copies: for each column, a parallel iterator over the
    /// successor cell of every row.
    ///
    /// # Panics
    /// If copies were recorded but `build_ordered_mapping` has not been called.
    pub fn mapping(
        &self,
    ) -> impl Iterator<Item = impl IndexedParallelIterator<Item = (usize, usize)> + '_> {
        use crate::multicore::IntoParallelIterator;

        (0..self.num_cols).map(move |i| {
            (0..self.col_len)
                .into_par_iter()
                .map(move |j| self.mapping_at_idx(i, j))
        })
    }

    #[cfg(not(feature = "multicore"))]
    /// Returns mappings of the copies: for each column, an iterator over the
    /// successor cell of every row.
    ///
    /// # Panics
    /// If copies were recorded but `build_ordered_mapping` has not been called.
    pub fn mapping(&self) -> impl Iterator<Item = impl Iterator<Item = (usize, usize)> + '_> {
        (0..self.num_cols).map(move |i| (0..self.col_len).map(move |j| self.mapping_at_idx(i, j)))
    }
}
338
339pub(crate) fn build_pk<'params, C: CurveAffine, P: Params<'params, C>>(
340    params: &P,
341    domain: &EvaluationDomain<C::Scalar>,
342    p: &Argument,
343    mapping: impl Fn(usize, usize) -> (usize, usize) + Sync,
344) -> ProvingKey<C> {
345    // Compute [omega^0, omega^1, ..., omega^{params.n - 1}]
346    let mut omega_powers = vec![C::Scalar::ZERO; params.n() as usize];
347    {
348        let omega = domain.get_omega();
349        parallelize(&mut omega_powers, |o, start| {
350            let mut cur = omega.pow_vartime([start as u64]);
351            for v in o.iter_mut() {
352                *v = cur;
353                cur *= &omega;
354            }
355        })
356    }
357
358    // Compute [omega_powers * \delta^0, omega_powers * \delta^1, ..., omega_powers * \delta^m]
359    let mut deltaomega = vec![omega_powers; p.columns.len()];
360    {
361        parallelize(&mut deltaomega, |o, start| {
362            let mut cur = C::Scalar::DELTA.pow_vartime([start as u64]);
363            for omega_powers in o.iter_mut() {
364                for v in omega_powers {
365                    *v *= &cur;
366                }
367                cur *= &C::Scalar::DELTA;
368            }
369        });
370    }
371
372    // Compute permutation polynomials, convert to coset form.
373    let mut permutations = vec![domain.empty_lagrange(); p.columns.len()];
374    {
375        parallelize(&mut permutations, |o, start| {
376            for (x, permutation_poly) in o.iter_mut().enumerate() {
377                let i = start + x;
378                for (j, p) in permutation_poly.iter_mut().enumerate() {
379                    let (permuted_i, permuted_j) = mapping(i, j);
380                    *p = deltaomega[permuted_i][permuted_j];
381                }
382            }
383        });
384    }
385
386    let mut polys = vec![domain.empty_coeff(); p.columns.len()];
387    {
388        parallelize(&mut polys, |o, start| {
389            for (x, poly) in o.iter_mut().enumerate() {
390                let i = start + x;
391                let permutation_poly = permutations[i].clone();
392                *poly = domain.lagrange_to_coeff(permutation_poly);
393            }
394        });
395    }
396
397    ProvingKey {
398        permutations,
399        polys,
400    }
401}
402
403pub(crate) fn build_vk<'params, C: CurveAffine, P: Params<'params, C>>(
404    params: &P,
405    domain: &EvaluationDomain<C::Scalar>,
406    p: &Argument,
407    mapping: impl Fn(usize, usize) -> (usize, usize) + Sync,
408) -> VerifyingKey<C> {
409    // Compute [omega^0, omega^1, ..., omega^{params.n - 1}]
410    let mut omega_powers = vec![C::Scalar::ZERO; params.n() as usize];
411    {
412        let omega = domain.get_omega();
413        parallelize(&mut omega_powers, |o, start| {
414            let mut cur = omega.pow_vartime([start as u64]);
415            for v in o.iter_mut() {
416                *v = cur;
417                cur *= &omega;
418            }
419        })
420    }
421
422    // Compute [omega_powers * \delta^0, omega_powers * \delta^1, ..., omega_powers * \delta^m]
423    let mut deltaomega = vec![omega_powers; p.columns.len()];
424    {
425        parallelize(&mut deltaomega, |o, start| {
426            let mut cur = C::Scalar::DELTA.pow_vartime([start as u64]);
427            for omega_powers in o.iter_mut() {
428                for v in omega_powers {
429                    *v *= &cur;
430                }
431                cur *= &<C::Scalar as PrimeField>::DELTA;
432            }
433        });
434    }
435
436    // Computes the permutation polynomial based on the permutation
437    // description in the assembly.
438    let mut permutations = vec![domain.empty_lagrange(); p.columns.len()];
439    {
440        parallelize(&mut permutations, |o, start| {
441            for (x, permutation_poly) in o.iter_mut().enumerate() {
442                let i = start + x;
443                for (j, p) in permutation_poly.iter_mut().enumerate() {
444                    let (permuted_i, permuted_j) = mapping(i, j);
445                    *p = deltaomega[permuted_i][permuted_j];
446                }
447            }
448        });
449    }
450
451    // Pre-compute commitments for the URS.
452    let mut commitments = Vec::with_capacity(p.columns.len());
453    for permutation in &permutations {
454        // Compute commitment to permutation polynomial
455        commitments.push(
456            params
457                .commit_lagrange(permutation, Blind::default())
458                .to_affine(),
459        );
460    }
461
462    VerifyingKey { commitments }
463}