openvm_stark_backend/
utils.rs

1use p3_field::Field;
2use tracing::instrument;
3
4use crate::air_builders::debug::USE_DEBUG_BUILDER;
5
6// Copied from valida-util
7/// Calculates and returns the multiplicative inverses of each field element, with zero
8/// values remaining unchanged.
9#[instrument(name = "batch_multiplicative_inverse", level = "info", skip_all)]
10pub fn batch_multiplicative_inverse_allowing_zero<F: Field>(values: Vec<F>) -> Vec<F> {
11    // Check if values are zero, and construct a new vector with only nonzero values
12    let mut nonzero_values = Vec::with_capacity(values.len());
13    let mut indices = Vec::with_capacity(values.len());
14    for (i, value) in values.iter().cloned().enumerate() {
15        if value.is_zero() {
16            continue;
17        }
18        nonzero_values.push(value);
19        indices.push(i);
20    }
21
22    // Compute the multiplicative inverse of nonzero values
23    let inverse_nonzero_values = p3_field::batch_multiplicative_inverse(&nonzero_values);
24
25    // Reconstruct the original vector
26    let mut result = values.clone();
27    for (i, index) in indices.into_iter().enumerate() {
28        result[index] = inverse_nonzero_values[i];
29    }
30
31    result
32}
33
/// Applies `f` to a mutable slice, distributing the work over rayon tasks when the `parallel`
/// feature is enabled.
///
/// `v.len()` is expected to be a multiple of `chunk_size`; this is checked with a debug
/// assertion. The slice is only ever divided at multiples of `chunk_size`, so every sub-slice
/// passed to `f` also has a length that is a multiple of `chunk_size`.
///
/// `f` is invoked as `f(thread_slice, idx)`, where `thread_slice` is the sub-slice of `v` that
/// starts at `v[idx]`. Without the `parallel` feature, `f` is invoked once as `f(v, 0)`.
// Copied and modified from https://github.com/axiom-crypto/halo2/blob/4e584896b62c981ec7c7dced4a9ca95b82306550/halo2_proofs/src/arithmetic.rs#L157
pub fn parallelize_chunks<T, F>(v: &mut [T], chunk_size: usize, f: F)
where
    T: Send,
    F: Fn(&mut [T], usize) + Send + Sync + Clone,
{
    debug_assert_eq!(v.len() % chunk_size, 0);

    // Sequential build: the whole slice is a single unit of work.
    #[cfg(not(feature = "parallel"))]
    {
        f(v, 0)
    }

    // Parallel build. Why the split is computed by hand instead of using `chunks_mut`:
    //
    // `chunks_mut` takes a fixed chunk length and puts whatever is left over into a final,
    // shorter chunk, which balances poorly. For example, 40 iterations on 12 threads with a
    // chunk length of 40 / 12 = 3 produces more chunks than there are threads, so some
    // threads end up with about twice the work of the others.
    //
    // Here the iterations are instead dealt out so that per-task sizes differ by at most one
    // iteration. For 40 iterations on 12 threads:
    //   4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3  (4 * 4 + 3 * 8 = 40)
    //
    // This mirrors the OpenMP 4.5 description of scheduling without an explicit chunk size
    // (spec page 60, http://www.openmp.org/mp-documents/openmp-4.5.pdf): the iteration space
    // is divided into approximately equal chunks, with at most one chunk per thread.
    #[cfg(feature = "parallel")]
    {
        let worker = &f;
        let iterations = v.len() / chunk_size;
        let threads = rayon::current_num_threads();

        // The first `large_count` tasks receive one more chunk (`large_len` elements) than
        // the remaining tasks (`small_len` elements).
        let small_len = (iterations / threads) * chunk_size;
        let large_len = small_len + chunk_size;
        let large_count = iterations % threads;
        let boundary = large_count * large_len;
        let (large_part, small_part) = v.split_at_mut(boundary);

        rayon::scope(|s| {
            // When the iterations divide evenly among the threads there are no larger slices.
            if large_count != 0 {
                for (k, piece) in large_part.chunks_exact_mut(large_len).enumerate() {
                    let start = k * large_len;
                    s.spawn(move |_| worker(piece, start));
                }
            }
            // When there are fewer iterations than threads the smaller slices are empty.
            if small_len != 0 {
                for (k, piece) in small_part.chunks_exact_mut(small_len).enumerate() {
                    let start = boundary + (k * small_len);
                    s.spawn(move |_| worker(piece, start));
                }
            }
        });
    }
}
105
106/// Disables the debug builder so there are not debug assert panics.
107/// Commonly used in negative tests to prevent panics.
108pub fn disable_debug_builder() {
109    USE_DEBUG_BUILDER.with(|debug| {
110        *debug.lock().unwrap() = false;
111    });
112}
113
/// Zips the given expressions into one parallel iterator.
///
/// This variant is compiled when the `parallel` feature is enabled. The arguments are wrapped
/// in a parenthesized, comma-separated list and `into_par_iter()` is called on the result,
/// with `rayon::iter::*` brought into scope inside the expansion. With a single argument the
/// parentheses are plain grouping, so `into_par_iter()` is called on that expression directly.
/// Trailing commas are accepted.
#[macro_export]
#[cfg(feature = "parallel")]
macro_rules! parizip {
    ( $head:expr $( , $tail:expr )* $(,)* ) => {{
        use rayon::iter::*;
        ( ( $head $( , $tail )* ) ).into_par_iter()
    }};
}
/// Sequential counterpart of `parizip!`, compiled when the `parallel` feature is disabled.
///
/// Accepts the same argument list (one or more expressions, trailing commas allowed) and
/// forwards the expressions to `itertools::izip!`.
#[macro_export]
#[cfg(not(feature = "parallel"))]
macro_rules! parizip {
    ( $head:expr $( , $tail:expr )* $(,)* ) => {
        itertools::izip!( $head $( , $tail )* )
    };
}