halo2_axiom/
arithmetic.rs

1//! This module provides common utilities, traits and structures for group,
2//! field and polynomial arithmetic.
3
4use super::multicore;
5pub use ff::Field;
6use group::{
7    ff::{BatchInvert, PrimeField},
8    prime::PrimeCurveAffine,
9    Curve, GroupOpsOwned, ScalarMulOwned,
10};
11use rayon::prelude::*;
12
13use halo2curves::msm::msm_best;
14pub use halo2curves::{CurveAffine, CurveExt};
15
/// This represents an element of a group with basic operations that can be
/// performed. This allows an FFT implementation (for example) to operate
/// generically over either a field or elliptic curve group.
///
/// The trait declares no methods of its own; it only bundles the bounds
/// listed below (copyable, thread-safe, `'static`, supporting the group
/// operations and multiplication by `Scalar`) under a single name.
pub trait FftGroup<Scalar: Field>:
    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
{
}
23
// Blanket implementation: every type that already satisfies the bounds
// required by `FftGroup` automatically implements it, so callers never need
// to write an explicit `impl`.
impl<T, Scalar> FftGroup<Scalar> for T
where
    Scalar: Field,
    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
{
}
30
// [JPW] Keep this adapter to halo2curves to minimize code changes.
/// Performs a multi-exponentiation operation.
///
/// This is a thin adapter: both slices are forwarded unchanged to
/// `halo2curves::msm::msm_best`, and its result is returned as-is.
///
/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    msm_best(coeffs, bases)
}
40
/// Dispatcher for the FFT: forwards every argument unchanged to `fft::fft`,
/// which operates on `a` through the mutable slice.
///
/// The parameters are not interpreted here. Judging by their names and by the
/// call in `g_to_lagrange`, `omega` is the root of unity to use, `log_n` is
/// the base-2 logarithm of the transform size, `data` holds precomputed FFT
/// data and `inverse` selects the inverse transform — confirm against
/// `fft::fft`.
pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(
    a: &mut [G],
    omega: Scalar,
    log_n: u32,
    data: &FFTData<Scalar>,
    inverse: bool,
) {
    fft::fft(a, omega, log_n, data, inverse);
}
51
52/// Convert coefficient bases group elements to lagrange basis by inverse FFT.
53pub fn g_to_lagrange<C: PrimeCurveAffine>(g_projective: Vec<C::Curve>, k: u32) -> Vec<C> {
54    let n_inv = C::Scalar::TWO_INV.pow_vartime([k as u64, 0, 0, 0]);
55    let omega = C::Scalar::ROOT_OF_UNITY;
56    let mut omega_inv = C::Scalar::ROOT_OF_UNITY_INV;
57    for _ in k..C::Scalar::S {
58        omega_inv = omega_inv.square();
59    }
60
61    let mut g_lagrange_projective = g_projective;
62    let n = g_lagrange_projective.len();
63    let fft_data = FFTData::new(n, omega, omega_inv);
64
65    best_fft(&mut g_lagrange_projective, omega_inv, k, &fft_data, true);
66    parallelize(&mut g_lagrange_projective, |g, _| {
67        for g in g.iter_mut() {
68            *g *= n_inv;
69        }
70    });
71
72    let mut g_lagrange = vec![C::identity(); 1 << k];
73    parallelize(&mut g_lagrange, |g_lagrange, starts| {
74        C::Curve::batch_normalize(
75            &g_lagrange_projective[starts..(starts + g_lagrange.len())],
76            g_lagrange,
77        );
78    });
79
80    g_lagrange
81}
82
83/// This evaluates a provided polynomial (in coefficient form) at `point`.
84pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
85    fn evaluate<F: Field>(poly: &[F], point: F) -> F {
86        poly.iter()
87            .rev()
88            .fold(F::ZERO, |acc, coeff| acc * point + coeff)
89    }
90    let n = poly.len();
91    let num_threads = multicore::current_num_threads();
92    if n * 2 < num_threads {
93        evaluate(poly, point)
94    } else {
95        let chunk_size = (n + num_threads - 1) / num_threads;
96        let mut parts = vec![F::ZERO; num_threads];
97        multicore::scope(|scope| {
98            for (chunk_idx, (out, poly)) in
99                parts.chunks_mut(1).zip(poly.chunks(chunk_size)).enumerate()
100            {
101                scope.spawn(move |_| {
102                    let start = chunk_idx * chunk_size;
103                    out[0] = evaluate(poly, point) * point.pow_vartime([start as u64, 0, 0, 0]);
104                });
105            }
106        });
107        parts.iter().fold(F::ZERO, |acc, coeff| acc + coeff)
108    }
109}
110
111/// This computes the inner product of two vectors `a` and `b`.
112///
113/// This function will panic if the two vectors are not the same size.
114/// For vectors smaller than 32 elements, it uses sequential computation for better performance.
115/// For larger vectors, it switches to parallel computation.
116pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
117    assert_eq!(a.len(), b.len());
118
119    if a.len() < 32 {
120        // Use sequential computation for small vectors
121        let mut acc = F::ZERO;
122        for (a, b) in a.iter().zip(b.iter()) {
123            acc += (*a) * (*b);
124        }
125        return acc;
126    }
127
128    // Use parallel computation
129    a.par_iter().zip(b.par_iter()).map(|(a, b)| (*a) * b).sum()
130}
131
132/// Divides polynomial `a` in `X` by `X - b` with
133/// no remainder.
134pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
135where
136    I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
137{
138    b = -b;
139    let a = a.into_iter();
140
141    let mut q = vec![F::ZERO; a.len() - 1];
142
143    let mut tmp = F::ZERO;
144    for (q, r) in q.iter_mut().rev().zip(a.rev()) {
145        let mut lead_coeff = *r;
146        lead_coeff.sub_assign(&tmp);
147        *q = lead_coeff;
148        tmp = lead_coeff;
149        tmp.mul_assign(&b);
150    }
151
152    q
153}
154
155/// This utility function will parallelize an operation that is to be
156/// performed over a mutable slice.
157pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
158    // Algorithm rationale:
159    //
160    // Using the stdlib `chunks_mut` will lead to severe load imbalance.
161    // From https://github.com/rust-lang/rust/blob/e94bda3/library/core/src/slice/iter.rs#L1607-L1637
162    // if the division is not exact, the last chunk will be the remainder.
163    //
164    // Dividing 40 items on 12 threads will lead to a chunk size of 40/12 = 3,
165    // There will be a 13 chunks of size 3 and 1 of size 1 distributed on 12 threads.
166    // This leads to 1 thread working on 6 iterations, 1 on 4 iterations and 10 on 3 iterations,
167    // a load imbalance of 2x.
168    //
169    // Instead we can divide work into chunks of size
170    // 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3 = 4*4 + 3*8 = 40
171    //
172    // This would lead to a 6/4 = 1.5x speedup compared to naive chunks_mut
173    //
174    // See also OpenMP spec (page 60)
175    // http://www.openmp.org/mp-documents/openmp-4.5.pdf
176    // "When no chunk_size is specified, the iteration space is divided into chunks
177    // that are approximately equal in size, and at most one chunk is distributed to
178    // each thread. The size of the chunks is unspecified in this case."
179    // This implies chunks are the same size ±1
180
181    let f = &f;
182    let total_iters = v.len();
183    let num_threads = multicore::current_num_threads();
184    let base_chunk_size = total_iters / num_threads;
185    let cutoff_chunk_id = total_iters % num_threads;
186    let split_pos = cutoff_chunk_id * (base_chunk_size + 1);
187    let (v_hi, v_lo) = v.split_at_mut(split_pos);
188
189    multicore::scope(|scope| {
190        // Skip special-case: number of iterations is cleanly divided by number of threads.
191        if cutoff_chunk_id != 0 {
192            for (chunk_id, chunk) in v_hi.chunks_exact_mut(base_chunk_size + 1).enumerate() {
193                let offset = chunk_id * (base_chunk_size + 1);
194                scope.spawn(move |_| f(chunk, offset));
195            }
196        }
197        // Skip special-case: less iterations than number of threads.
198        if base_chunk_size != 0 {
199            for (chunk_id, chunk) in v_lo.chunks_exact_mut(base_chunk_size).enumerate() {
200                let offset = split_pos + (chunk_id * base_chunk_size);
201                scope.spawn(move |_| f(chunk, offset));
202            }
203        }
204    });
205}
206
/// Returns the floor of the base-2 logarithm of `num`, i.e. the index of its
/// highest set bit.
///
/// # Panics
///
/// Panics if `num` is zero.
pub fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    // Derived from the leading-zero count rather than a shift loop, so values
    // with the top bit set cannot overflow the shift amount.
    usize::BITS - 1 - num.leading_zeros()
}
218
/// Returns coefficients of an n - 1 degree polynomial given a set of n points
/// and their evaluations.
///
/// The coefficients are returned lowest degree first. An empty input produces
/// an empty coefficient vector.
///
/// # Panics
///
/// Panics if `points` and `evals` have different lengths.
///
/// NOTE(review): earlier documentation stated that this function panics if
/// two values in `points` are the same. Whether that holds depends on how
/// `batch_invert` treats a zero denominator — confirm against the `ff` crate.
pub fn lagrange_interpolate<F: Field>(points: &[F], evals: &[F]) -> Vec<F> {
    assert_eq!(points.len(), evals.len());
    if points.len() == 1 {
        // Constant polynomial
        vec![evals[0]]
    } else {
        // denoms[j] collects the differences (x_j - x_k) for every k != j,
        // in the order the points appear.
        let mut denoms = Vec::with_capacity(points.len());
        for (j, x_j) in points.iter().enumerate() {
            let mut denom = Vec::with_capacity(points.len() - 1);
            for x_k in points
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != j)
                .map(|a| a.1)
            {
                denom.push(*x_j - x_k);
            }
            denoms.push(denom);
        }
        // Compute (x_j - x_k)^(-1) for each k != j, inverting all of them in
        // a single batch.
        denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();

        let mut final_poly = vec![F::ZERO; points.len()];
        for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
            // `tmp` holds the coefficients (lowest degree first) of the
            // product built so far, starting from the constant 1; `product`
            // is the scratch buffer for the next multiplication.
            let mut tmp: Vec<F> = Vec::with_capacity(points.len());
            let mut product = Vec::with_capacity(points.len() - 1);
            tmp.push(F::ONE);
            for (x_k, denom) in points
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != j)
                .map(|a| a.1)
                .zip(denoms)
            {
                // Multiply `tmp` by (X - x_k) * denom. Output coefficient i is
                // tmp[i] * (-denom * x_k) + tmp[i - 1] * denom; the chained
                // zeros stand in for the out-of-range tmp[len] and tmp[-1].
                product.resize(tmp.len() + 1, F::ZERO);
                for ((a, b), product) in tmp
                    .iter()
                    .chain(std::iter::once(&F::ZERO))
                    .zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
                    .zip(product.iter_mut())
                {
                    *product = *a * (-denom * x_k) + *b * denom;
                }
                // The freshly computed product becomes the running value; the
                // old buffer is kept for reuse on the next iteration.
                std::mem::swap(&mut tmp, &mut product);
            }
            assert_eq!(tmp.len(), points.len());
            assert_eq!(product.len(), points.len() - 1);
            // Accumulate this basis polynomial scaled by its evaluation.
            for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp) {
                *final_coeff += interpolation_coeff * eval;
            }
        }
        final_poly
    }
}
276
277pub(crate) fn evaluate_vanishing_polynomial<F: Field>(roots: &[F], z: F) -> F {
278    fn evaluate<F: Field>(roots: &[F], z: F) -> F {
279        roots.iter().fold(F::ONE, |acc, point| (z - point) * acc)
280    }
281    let n = roots.len();
282    let num_threads = multicore::current_num_threads();
283    if n * 2 < num_threads {
284        evaluate(roots, z)
285    } else {
286        let chunk_size = (n + num_threads - 1) / num_threads;
287        let mut parts = vec![F::ONE; num_threads];
288        multicore::scope(|scope| {
289            for (out, roots) in parts.chunks_mut(1).zip(roots.chunks(chunk_size)) {
290                scope.spawn(move |_| out[0] = evaluate(roots, z));
291            }
292        });
293        parts.iter().fold(F::ONE, |acc, part| acc * part)
294    }
295}
296
/// Returns an unbounded iterator over the successive powers of `base`:
/// `1, base, base^2, ...`.
///
/// The first item is `F::ONE`; every later item is `base` times the previous
/// one. The iterator never ends, so callers must bound it themselves (for
/// example with `take`).
pub(crate) fn powers<F: Field>(base: F) -> impl Iterator<Item = F> {
    std::iter::successors(Some(F::ONE), move |power| Some(base * power))
}
300
/// Reverses the `l` least-significant bits of `n`.
///
/// Bits of `n` above position `l` are ignored, and the result is zero when
/// `l` is zero.
pub fn bitreverse(mut n: usize, l: usize) -> usize {
    let mut reversed = 0usize;
    let mut remaining = l;
    while remaining > 0 {
        // Move the current lowest bit of `n` onto the low end of the result.
        reversed = (reversed << 1) | (n & 1);
        n >>= 1;
        remaining -= 1;
    }
    reversed
}
310
311#[cfg(test)]
312use rand_core::OsRng;
313
314use crate::fft::{self, recursive::FFTData};
315#[cfg(test)]
316use crate::halo2curves::pasta::Fp;
317// use crate::plonk::{get_duration, get_time, start_measure, stop_measure};
318
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let all_points: Vec<Fp> = (0..5).map(|_| Fp::random(rng)).collect();
    let all_evals: Vec<Fp> = (0..5).map(|_| Fp::random(rng)).collect();

    // Interpolate through the first 0, 1, ..., 4 samples in turn and check
    // that the resulting polynomial reproduces every sample it was built from.
    for count in 0..5 {
        let (xs, ys) = (&all_points[0..count], &all_evals[0..count]);

        let poly = lagrange_interpolate(xs, ys);
        assert_eq!(poly.len(), xs.len());

        xs.iter()
            .zip(ys)
            .for_each(|(x, y)| assert_eq!(eval_polynomial(&poly, *x), *y));
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::OsRng;

    /// Compares `compute_inner_product` with a plain sequential sum on two
    /// random vectors of length `len`.
    fn check_inner_product(len: usize) {
        let rng = OsRng;

        let a: Vec<Fp> = (0..len).map(|_| Fp::random(rng)).collect();
        let b: Vec<Fp> = (0..len).map(|_| Fp::random(rng)).collect();

        let actual = compute_inner_product(&a, &b);
        let mut expected = Fp::ZERO;
        for (x, y) in a.iter().zip(b.iter()) {
            expected = expected + (*x) * (*y);
        }
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_compute_inner_product() {
        // 16 elements: below the threshold of 32, so the sequential path runs.
        check_inner_product(16);
        // 64 elements: at or above the threshold, so the parallel path runs.
        check_inner_product(64);
    }
}