halo2_proofs/
arithmetic.rs

1//! This module provides common utilities, traits and structures for group,
2//! field and polynomial arithmetic.
3
4use super::multicore;
5pub use ff::Field;
6use group::{
7    ff::{BatchInvert, PrimeField},
8    Group as _,
9};
10
11pub use pasta_curves::arithmetic::*;
12
13fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
14    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
15
16    let c = if bases.len() < 4 {
17        1
18    } else if bases.len() < 32 {
19        3
20    } else {
21        (f64::from(bases.len() as u32)).ln().ceil() as usize
22    };
23
24    fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
25        let skip_bits = segment * c;
26        let skip_bytes = skip_bits / 8;
27
28        if skip_bytes >= 32 {
29            return 0;
30        }
31
32        let mut v = [0; 8];
33        for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
34            *v = *o;
35        }
36
37        let mut tmp = u64::from_le_bytes(v);
38        tmp >>= skip_bits - (skip_bytes * 8);
39        tmp = tmp % (1 << c);
40
41        tmp as usize
42    }
43
44    let segments = (256 / c) + 1;
45
46    for current_segment in (0..segments).rev() {
47        for _ in 0..c {
48            *acc = acc.double();
49        }
50
51        #[derive(Clone, Copy)]
52        enum Bucket<C: CurveAffine> {
53            None,
54            Affine(C),
55            Projective(C::Curve),
56        }
57
58        impl<C: CurveAffine> Bucket<C> {
59            fn add_assign(&mut self, other: &C) {
60                *self = match *self {
61                    Bucket::None => Bucket::Affine(*other),
62                    Bucket::Affine(a) => Bucket::Projective(a + *other),
63                    Bucket::Projective(mut a) => {
64                        a += *other;
65                        Bucket::Projective(a)
66                    }
67                }
68            }
69
70            fn add(self, mut other: C::Curve) -> C::Curve {
71                match self {
72                    Bucket::None => other,
73                    Bucket::Affine(a) => {
74                        other += a;
75                        other
76                    }
77                    Bucket::Projective(a) => other + &a,
78                }
79            }
80        }
81
82        let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
83
84        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
85            let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
86            if coeff != 0 {
87                buckets[coeff - 1].add_assign(base);
88            }
89        }
90
91        // Summation by parts
92        // e.g. 3a + 2b + 1c = a +
93        //                    (a) + b +
94        //                    ((a) + b) + c
95        let mut running_sum = C::Curve::identity();
96        for exp in buckets.into_iter().rev() {
97            running_sum = exp.add(running_sum);
98            *acc = *acc + &running_sum;
99        }
100    }
101}
102
103/// Performs a small multi-exponentiation operation.
104/// Uses the double-and-add algorithm with doublings shared across points.
105pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
106    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
107    let mut acc = C::Curve::identity();
108
109    // for byte idx
110    for byte_idx in (0..32).rev() {
111        // for bit idx
112        for bit_idx in (0..8).rev() {
113            acc = acc.double();
114            // for each coeff
115            for coeff_idx in 0..coeffs.len() {
116                let byte = coeffs[coeff_idx].as_ref()[byte_idx];
117                if ((byte >> bit_idx) & 1) != 0 {
118                    acc += bases[coeff_idx];
119                }
120            }
121        }
122    }
123
124    acc
125}
126
127/// Performs a multi-exponentiation operation.
128///
129/// This function will panic if coeffs and bases have a different length.
130///
131/// This will use multithreading if beneficial.
132pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
133    assert_eq!(coeffs.len(), bases.len());
134
135    let num_threads = multicore::current_num_threads();
136    if coeffs.len() > num_threads {
137        let chunk = coeffs.len() / num_threads;
138        let num_chunks = coeffs.chunks(chunk).len();
139        let mut results = vec![C::Curve::identity(); num_chunks];
140        multicore::scope(|scope| {
141            let chunk = coeffs.len() / num_threads;
142
143            for ((coeffs, bases), acc) in coeffs
144                .chunks(chunk)
145                .zip(bases.chunks(chunk))
146                .zip(results.iter_mut())
147            {
148                scope.spawn(move |_| {
149                    multiexp_serial(coeffs, bases, acc);
150                });
151            }
152        });
153        results.iter().fold(C::Curve::identity(), |a, b| a + b)
154    } else {
155        let mut acc = C::Curve::identity();
156        multiexp_serial(coeffs, bases, &mut acc);
157        acc
158    }
159}
160
161/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
162/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
163/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
164/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
165/// transformed into the evaluations of this polynomial at each of the $n$
166/// distinct powers of $\omega$. This transformation is invertible by providing
167/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
168/// by $n$.
169///
170/// This will use multithreading if beneficial.
171pub fn best_fft<G: Group>(a: &mut [G], omega: G::Scalar, log_n: u32) {
172    fn bitreverse(mut n: usize, l: usize) -> usize {
173        let mut r = 0;
174        for _ in 0..l {
175            r = (r << 1) | (n & 1);
176            n >>= 1;
177        }
178        r
179    }
180
181    let threads = multicore::current_num_threads();
182    let log_threads = log2_floor(threads);
183    let n = a.len() as usize;
184    assert_eq!(n, 1 << log_n);
185
186    for k in 0..n {
187        let rk = bitreverse(k, log_n as usize);
188        if k < rk {
189            a.swap(rk, k);
190        }
191    }
192
193    // precompute twiddle factors
194    let twiddles: Vec<_> = (0..(n / 2) as usize)
195        .scan(G::Scalar::one(), |w, _| {
196            let tw = *w;
197            w.group_scale(&omega);
198            Some(tw)
199        })
200        .collect();
201
202    if log_n <= log_threads {
203        let mut chunk = 2_usize;
204        let mut twiddle_chunk = (n / 2) as usize;
205        for _ in 0..log_n {
206            a.chunks_mut(chunk).for_each(|coeffs| {
207                let (left, right) = coeffs.split_at_mut(chunk / 2);
208
209                // case when twiddle factor is one
210                let (a, left) = left.split_at_mut(1);
211                let (b, right) = right.split_at_mut(1);
212                let t = b[0];
213                b[0] = a[0];
214                a[0].group_add(&t);
215                b[0].group_sub(&t);
216
217                left.iter_mut()
218                    .zip(right.iter_mut())
219                    .enumerate()
220                    .for_each(|(i, (a, b))| {
221                        let mut t = *b;
222                        t.group_scale(&twiddles[(i + 1) * twiddle_chunk]);
223                        *b = *a;
224                        a.group_add(&t);
225                        b.group_sub(&t);
226                    });
227            });
228            chunk *= 2;
229            twiddle_chunk /= 2;
230        }
231    } else {
232        recursive_butterfly_arithmetic(a, n, 1, &twiddles)
233    }
234}
235
236/// This perform recursive butterfly arithmetic
237pub fn recursive_butterfly_arithmetic<G: Group>(
238    a: &mut [G],
239    n: usize,
240    twiddle_chunk: usize,
241    twiddles: &[G::Scalar],
242) {
243    if n == 2 {
244        let t = a[1];
245        a[1] = a[0];
246        a[0].group_add(&t);
247        a[1].group_sub(&t);
248    } else {
249        let (left, right) = a.split_at_mut(n / 2);
250        rayon::join(
251            || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
252            || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
253        );
254
255        // case when twiddle factor is one
256        let (a, left) = left.split_at_mut(1);
257        let (b, right) = right.split_at_mut(1);
258        let t = b[0];
259        b[0] = a[0];
260        a[0].group_add(&t);
261        b[0].group_sub(&t);
262
263        left.iter_mut()
264            .zip(right.iter_mut())
265            .enumerate()
266            .for_each(|(i, (a, b))| {
267                let mut t = *b;
268                t.group_scale(&twiddles[(i + 1) * twiddle_chunk]);
269                *b = *a;
270                a.group_add(&t);
271                b.group_sub(&t);
272            });
273    }
274}
275
276/// This evaluates a provided polynomial (in coefficient form) at `point`.
277pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
278    // TODO: parallelize?
279    poly.iter()
280        .rev()
281        .fold(F::zero(), |acc, coeff| acc * point + coeff)
282}
283
284/// This computes the inner product of two vectors `a` and `b`.
285///
286/// This function will panic if the two vectors are not the same size.
287pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
288    // TODO: parallelize?
289    assert_eq!(a.len(), b.len());
290
291    let mut acc = F::zero();
292    for (a, b) in a.iter().zip(b.iter()) {
293        acc += (*a) * (*b);
294    }
295
296    acc
297}
298
299/// Divides polynomial `a` in `X` by `X - b` with
300/// no remainder.
301pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
302where
303    I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
304{
305    b = -b;
306    let a = a.into_iter();
307
308    let mut q = vec![F::zero(); a.len() - 1];
309
310    let mut tmp = F::zero();
311    for (q, r) in q.iter_mut().rev().zip(a.rev()) {
312        let mut lead_coeff = *r;
313        lead_coeff.sub_assign(&tmp);
314        *q = lead_coeff;
315        tmp = lead_coeff;
316        tmp.mul_assign(&b);
317    }
318
319    q
320}
321
322/// This simple utility function will parallelize an operation that is to be
323/// performed over a mutable slice.
324pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
325    let n = v.len();
326    let num_threads = multicore::current_num_threads();
327    let mut chunk = (n as usize) / num_threads;
328    if chunk < num_threads {
329        chunk = n as usize;
330    }
331
332    multicore::scope(|scope| {
333        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
334            let f = f.clone();
335            scope.spawn(move |_| {
336                let start = chunk_num * chunk;
337                f(v, start);
338            });
339        }
340    });
341}
342
/// Returns `floor(log2(num))`.
///
/// Computed from the leading-zero count, so it is correct for every non-zero
/// `usize`, including values at or above `2^(BITS - 1)` where a shift-based
/// loop would overflow.
///
/// # Panics
///
/// Panics if `num` is zero.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    let bits = (std::mem::size_of::<usize>() * 8) as u32;
    bits - 1 - num.leading_zeros()
}
354
355/// Returns coefficients of an n - 1 degree polynomial given a set of n points
356/// and their evaluations. This function will panic if two values in `points`
357/// are the same.
358pub fn lagrange_interpolate<F: FieldExt>(points: &[F], evals: &[F]) -> Vec<F> {
359    assert_eq!(points.len(), evals.len());
360    if points.len() == 1 {
361        // Constant polynomial
362        return vec![evals[0]];
363    } else {
364        let mut denoms = Vec::with_capacity(points.len());
365        for (j, x_j) in points.iter().enumerate() {
366            let mut denom = Vec::with_capacity(points.len() - 1);
367            for x_k in points
368                .iter()
369                .enumerate()
370                .filter(|&(k, _)| k != j)
371                .map(|a| a.1)
372            {
373                denom.push(*x_j - x_k);
374            }
375            denoms.push(denom);
376        }
377        // Compute (x_j - x_k)^(-1) for each j != i
378        denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();
379
380        let mut final_poly = vec![F::zero(); points.len()];
381        for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
382            let mut tmp: Vec<F> = Vec::with_capacity(points.len());
383            let mut product = Vec::with_capacity(points.len() - 1);
384            tmp.push(F::one());
385            for (x_k, denom) in points
386                .iter()
387                .enumerate()
388                .filter(|&(k, _)| k != j)
389                .map(|a| a.1)
390                .zip(denoms.into_iter())
391            {
392                product.resize(tmp.len() + 1, F::zero());
393                for ((a, b), product) in tmp
394                    .iter()
395                    .chain(std::iter::once(&F::zero()))
396                    .zip(std::iter::once(&F::zero()).chain(tmp.iter()))
397                    .zip(product.iter_mut())
398                {
399                    *product = *a * (-denom * x_k) + *b * denom;
400                }
401                std::mem::swap(&mut tmp, &mut product);
402            }
403            assert_eq!(tmp.len(), points.len());
404            assert_eq!(product.len(), points.len() - 1);
405            for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
406                *final_coeff += interpolation_coeff * eval;
407            }
408        }
409        final_poly
410    }
411}
412
413#[cfg(test)]
414use rand_core::OsRng;
415
416#[cfg(test)]
417use crate::pasta::Fp;
418
/// Interpolates every prefix of five random (point, eval) pairs, from zero
/// points up to and including all five. Each resulting polynomial must have
/// one coefficient per point and reproduce every eval at its point.
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let points = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();
    let evals = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();

    // `0..=len` so the full five-point case is exercised too.
    for num_points in 0..=points.len() {
        let points = &points[..num_points];
        let evals = &evals[..num_points];

        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());

        for (point, eval) in points.iter().zip(evals) {
            assert_eq!(eval_polynomial(&poly, *point), *eval);
        }
    }
}