openvm_stark_backend/utils.rs
1use p3_field::Field;
2use tracing::instrument;
3
4use crate::air_builders::debug::USE_DEBUG_BUILDER;
5
6// Copied from valida-util
7/// Calculates and returns the multiplicative inverses of each field element, with zero
8/// values remaining unchanged.
9#[instrument(name = "batch_multiplicative_inverse", level = "info", skip_all)]
10pub fn batch_multiplicative_inverse_allowing_zero<F: Field>(values: Vec<F>) -> Vec<F> {
11 // Check if values are zero, and construct a new vector with only nonzero values
12 let mut nonzero_values = Vec::with_capacity(values.len());
13 let mut indices = Vec::with_capacity(values.len());
14 for (i, value) in values.iter().cloned().enumerate() {
15 if value.is_zero() {
16 continue;
17 }
18 nonzero_values.push(value);
19 indices.push(i);
20 }
21
22 // Compute the multiplicative inverse of nonzero values
23 let inverse_nonzero_values = p3_field::batch_multiplicative_inverse(&nonzero_values);
24
25 // Reconstruct the original vector
26 let mut result = values.clone();
27 for (i, index) in indices.into_iter().enumerate() {
28 result[index] = inverse_nonzero_values[i];
29 }
30
31 result
32}
33
/// Applies `f` to `v`, splitting the slice across rayon tasks when the `parallel` feature is
/// enabled.
///
/// `f` is invoked as `f(sub_slice, start)`, where `sub_slice` is a contiguous piece of `v`
/// that begins at index `start`. The slice is only ever cut on `chunk_size` boundaries. As
/// long as `v.len()` is a multiple of `chunk_size` (checked by `debug_assert_eq!`), every
/// piece handed to `f` is therefore a whole number of chunks.
///
/// - Without `parallel`: `f(v, 0)` is called exactly once, on the whole slice.
/// - With `parallel`: `v` is cut into at most `rayon::current_num_threads()` pieces whose
///   lengths differ by at most `chunk_size`. Each piece is processed in its own task spawned
///   inside a `rayon::scope`, so this function returns only after every piece is done. An
///   empty `v` produces no calls to `f`.
///
/// # Panics
///
/// `chunk_size` must be nonzero. A value of `0` causes a division-by-zero panic in debug
/// builds and on the `parallel` path.
///
/// If the multiple-of-`chunk_size` precondition is violated in a release build, the parallel
/// path never passes the trailing `v.len() % chunk_size` elements to `f`.
// Copied and modified from https://github.com/axiom-crypto/halo2/blob/4e584896b62c981ec7c7dced4a9ca95b82306550/halo2_proofs/src/arithmetic.rs#L157
pub fn parallelize_chunks<T, F>(v: &mut [T], chunk_size: usize, f: F)
where
    T: Send,
    F: Fn(&mut [T], usize) + Send + Sync + Clone,
{
    debug_assert_eq!(v.len() % chunk_size, 0);

    #[cfg(not(feature = "parallel"))]
    {
        f(v, 0);
    }

    #[cfg(feature = "parallel")]
    {
        // Why not just `chunks_mut`? A fixed chunk length puts the whole remainder into one
        // short trailing chunk, which can badly unbalance the threads. Example: 40 items on 12
        // threads with `chunks_mut(40 / 12 = 3)` gives 13 chunks of 3 plus one chunk of 1.
        // That is 14 chunks for 12 threads, so one thread may handle 3 + 3 = 6 items, another
        // 3 + 1 = 4, and the remaining ten 3 each: a 2x imbalance.
        //
        // Instead, the slice is cut into at most one piece per thread, with piece lengths that
        // differ by at most one unit: four pieces of 4 and eight of 3, since 4*4 + 8*3 = 40.
        // The slowest thread then does 4 items instead of 6, a 1.5x improvement. This mirrors
        // OpenMP's default static schedule: "When no chunk_size is specified, the iteration
        // space is divided into chunks that are approximately equal in size, and at most one
        // chunk is distributed to each thread. The size of the chunks is unspecified in this
        // case." (OpenMP 4.5 spec, p. 60, http://www.openmp.org/mp-documents/openmp-4.5.pdf)
        //
        // Here one "unit" is a whole `chunk_size`-element chunk, so pieces stay chunk-aligned.
        let f = &f;
        let num_chunks = v.len() / chunk_size;
        let num_threads = rayon::current_num_threads();
        let (base_chunks, extra_chunks) = (num_chunks / num_threads, num_chunks % num_threads);

        // The first `extra_chunks` pieces ("long") carry one more chunk than the rest ("short").
        let short_len = base_chunks * chunk_size;
        let long_len = short_len + chunk_size;
        let boundary = extra_chunks * long_len;
        let (long_part, short_part) = v.split_at_mut(boundary);

        rayon::scope(|s| {
            // When the chunks divide evenly among the threads, `long_part` is empty.
            if extra_chunks > 0 {
                for (k, piece) in long_part.chunks_exact_mut(long_len).enumerate() {
                    let start = k * long_len;
                    s.spawn(move |_| f(piece, start));
                }
            }
            // With fewer chunks than threads, every piece is a long one and `short_part` is
            // empty. That case must be skipped, because `chunks_exact_mut(0)` would panic.
            if short_len > 0 {
                for (k, piece) in short_part.chunks_exact_mut(short_len).enumerate() {
                    let start = boundary + k * short_len;
                    s.spawn(move |_| f(piece, start));
                }
            }
        });
    }
}
105
106/// Disables the debug builder so there are not debug assert panics.
107/// Commonly used in negative tests to prevent panics.
108pub fn disable_debug_builder() {
109 USE_DEBUG_BUILDER.with(|debug| {
110 *debug.lock().unwrap() = false;
111 });
112}
113
/// Zips the given expressions into a rayon parallel iterator (used when the `parallel`
/// feature is enabled).
///
/// With a single argument, `into_par_iter()` is called on it directly. With two or more
/// arguments, they are first grouped into a tuple, and `into_par_iter()` is called on the
/// tuple. Trailing commas are accepted.
///
/// NOTE(review): the expansion names `rayon` directly rather than through `$crate`, so every
/// call site needs `rayon` available as a dependency.
#[macro_export]
#[cfg(feature = "parallel")]
macro_rules! parizip {
    ( $head:expr $( , $tail:expr )* $(,)* ) => {{
        use rayon::iter::*;
        ( $head $( , $tail )* ).into_par_iter()
    }};
}
/// Sequential counterpart of the parallel `parizip!` (used when the `parallel` feature is
/// disabled). It forwards its arguments to `itertools::izip!`. Trailing commas are accepted.
///
/// NOTE(review): the expansion names `itertools` directly rather than through `$crate`, so
/// every call site needs `itertools` available as a dependency.
#[macro_export]
#[cfg(not(feature = "parallel"))]
macro_rules! parizip {
    ( $head:expr $( , $tail:expr )* $(,)* ) => {
        itertools::izip!( $head $( , $tail )* )
    };
}