p3_field/
batch_inverse.rs1use alloc::vec::Vec;
2
3use p3_maybe_rayon::prelude::*;
4use tracing::instrument;
5
6use crate::field::Field;
7use crate::{FieldAlgebra, FieldArray, PackedValue};
8
9#[instrument(level = "debug", skip_all)]
20pub fn batch_multiplicative_inverse<F: Field>(x: &[F]) -> Vec<F> {
21 const CHUNK_SIZE: usize = 1024;
23
24 let n = x.len();
25 let mut result = F::zero_vec(n);
26
27 x.par_chunks(CHUNK_SIZE)
28 .zip(result.par_chunks_mut(CHUNK_SIZE))
29 .for_each(|(x, result)| {
30 batch_multiplicative_inverse_helper(x, result);
31 });
32
33 result
34}
35
36fn batch_multiplicative_inverse_helper<F: Field>(x: &[F], result: &mut [F]) {
38 const WIDTH: usize = 4;
41
42 let n = x.len();
43 assert_eq!(result.len(), n);
44 if n % WIDTH != 0 {
45 return batch_multiplicative_inverse_general(x, result, |x| x.inverse());
49 }
50
51 let x_packed = FieldArray::<F, 4>::pack_slice(x);
52 let result_packed = FieldArray::<F, 4>::pack_slice_mut(result);
53
54 batch_multiplicative_inverse_general(x_packed, result_packed, |x_packed| x_packed.inverse());
55}
56
57pub(crate) fn batch_multiplicative_inverse_general<F, Inv>(x: &[F], result: &mut [F], inv: Inv)
60where
61 F: FieldAlgebra + Copy,
62 Inv: Fn(F) -> F,
63{
64 let n = x.len();
65 assert_eq!(result.len(), n);
66 if n == 0 {
67 return;
68 }
69
70 result[0] = F::ONE;
71 for i in 1..n {
72 result[i] = result[i - 1] * x[i - 1];
73 }
74
75 let product = result[n - 1] * x[n - 1];
76 let mut inv = inv(product);
77
78 for i in (0..n).rev() {
79 result[i] *= inv;
80 inv *= x[i];
81 }
82}