p3_monty_31/aarch64_neon/packing.rs

use alloc::vec::Vec;
use core::arch::aarch64::{self, int32x4_t, uint32x4_t};
use core::arch::asm;
use core::hint::unreachable_unchecked;
use core::iter::{Product, Sum};
use core::mem::transmute;
use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};

use p3_field::{Field, FieldAlgebra, PackedField, PackedFieldPow2, PackedValue};
use p3_util::convert_vec;
use rand::distributions::{Distribution, Standard};
use rand::Rng;

use crate::{FieldParameters, MontyField31, PackedMontyParameters};

const WIDTH: usize = 4;

pub trait MontyParametersNeon {
    const PACKED_P: uint32x4_t;
    const PACKED_MU: int32x4_t;
}

/// Vectorized NEON implementation of `MontyField31` arithmetic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(transparent)] // This is needed to make `transmute`s safe.
pub struct PackedMontyField31Neon<PMP: PackedMontyParameters>(pub [MontyField31<PMP>; WIDTH]);

impl<PMP: PackedMontyParameters> PackedMontyField31Neon<PMP> {
    #[inline]
    #[must_use]
    /// Get an arch-specific vector representing the packed values.
    fn to_vector(self) -> uint32x4_t {
        unsafe {
            // Safety: `MontyField31` is `repr(transparent)` so it can be transmuted to `u32`. It
            // follows that `[MontyField31; WIDTH]` can be transmuted to `[u32; WIDTH]`, which can be
            // transmuted to `uint32x4_t`, since arrays are guaranteed to be contiguous in memory.
            // Finally `PackedMontyField31Neon` is `repr(transparent)` so it can be transmuted to
            // `[MontyField31; WIDTH]`.
            transmute(self)
        }
    }

    #[inline]
    #[must_use]
    /// Make a packed field vector from an arch-specific vector.
    ///
    /// SAFETY: The caller must ensure that each element of `vector` represents a valid `MontyField31`.
    /// In particular, each element of `vector` must be in `0..P` (canonical form).
    unsafe fn from_vector(vector: uint32x4_t) -> Self {
        // Safety: It is up to the user to ensure that elements of `vector` represent valid
        // `MontyField31` values. We must only reason about memory representations. `uint32x4_t` can be
        // transmuted to `[u32; WIDTH]` (since array elements are contiguous in memory), which can
        // be transmuted to `[MontyField31; WIDTH]` (since `MontyField31` is `repr(transparent)`), which in
        // turn can be transmuted to `PackedMontyField31Neon` (since `PackedMontyField31Neon` is also
        // `repr(transparent)`).
        transmute(vector)
    }

    /// Copy `value` to all positions in a packed vector. This is the same as
    /// `From<MontyField31>::from`, but `const`.
    #[inline]
    #[must_use]
    const fn broadcast(value: MontyField31<PMP>) -> Self {
        Self([value; WIDTH])
    }
}

impl<PMP: PackedMontyParameters> Add for PackedMontyField31Neon<PMP> {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let lhs = self.to_vector();
        let rhs = rhs.to_vector();
        let res = add::<PMP>(lhs, rhs);
        unsafe {
            // Safety: `add` returns values in canonical form when given values in canonical form.
            Self::from_vector(res)
        }
    }
}

impl<PMP: PackedMontyParameters> Mul for PackedMontyField31Neon<PMP> {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let lhs = self.to_vector();
        let rhs = rhs.to_vector();
        let res = mul::<PMP>(lhs, rhs);
        unsafe {
            // Safety: `mul` returns values in canonical form when given values in canonical form.
            Self::from_vector(res)
        }
    }
}

impl<PMP: PackedMontyParameters> Neg for PackedMontyField31Neon<PMP> {
    type Output = Self;
    #[inline]
    fn neg(self) -> Self {
        let val = self.to_vector();
        let res = neg::<PMP>(val);
        unsafe {
            // Safety: `neg` returns values in canonical form when given values in canonical form.
            Self::from_vector(res)
        }
    }
}

impl<PMP: PackedMontyParameters> Sub for PackedMontyField31Neon<PMP> {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let lhs = self.to_vector();
        let rhs = rhs.to_vector();
        let res = sub::<PMP>(lhs, rhs);
        unsafe {
            // Safety: `sub` returns values in canonical form when given values in canonical form.
            Self::from_vector(res)
        }
    }
}

/// No-op. Prevents the compiler from deducing the value of the vector.
///
/// Similar to `std::hint::black_box`, it can be used to stop the compiler applying undesirable
/// "optimizations". Unlike the built-in `black_box`, it does not force the value to be written to
/// and then read from the stack.
#[inline]
#[must_use]
fn confuse_compiler(x: uint32x4_t) -> uint32x4_t {
    let y;
    unsafe {
        asm!(
            "/*{0:v}*/",
            inlateout(vreg) x => y,
            options(nomem, nostack, preserves_flags, pure),
        );
        // Below tells the compiler the semantics of this so it can still do constant folding, etc.
        // You may ask, doesn't it defeat the point of the inline asm block to tell the compiler
        // what it does? The answer is that we still inhibit the transform we want to avoid, so
        // apparently not. Idk, LLVM works in mysterious ways.
        if transmute::<uint32x4_t, [u32; 4]>(x) != transmute::<uint32x4_t, [u32; 4]>(y) {
            unreachable_unchecked();
        }
    }
    y
}

/// Add two vectors of Monty31 field elements in canonical form.
/// If the inputs are not in canonical form, the result is undefined.
#[inline]
#[must_use]
fn add<MPNeon: MontyParametersNeon>(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
    // We want this to compile to:
    //      add   t.4s, lhs.4s, rhs.4s
    //      sub   u.4s, t.4s, P.4s
    //      umin  res.4s, t.4s, u.4s
    // throughput: .75 cyc/vec (5.33 els/cyc)
    // latency: 6 cyc

    //   Let `t := lhs + rhs`. We want to return `t mod P`. Recall that `lhs` and `rhs` are in
    // `0, ..., P - 1`, so `t` is in `0, ..., 2 P - 2 (< 2^32)`. It suffices to return `t` if
    // `t < P` and `t - P` otherwise.
    //   Let `u := (t - P) mod 2^32` and `r := unsigned_min(t, u)`.
    //   If `t` is in `0, ..., P - 1`, then `u` is in `(P - 1 <) 2^32 - P, ..., 2^32 - 1` and
    // `r = t`. Otherwise `t` is in `P, ..., 2 P - 2`, `u` is in `0, ..., P - 2 (< P)` and `r = u`.
    // Hence, `r` is `t` if `t < P` and `t - P` otherwise, as desired.

    unsafe {
        // Safety: If this code got compiled then NEON intrinsics are available.
        let t = aarch64::vaddq_u32(lhs, rhs);
        let u = aarch64::vsubq_u32(t, MPNeon::PACKED_P);
        aarch64::vminq_u32(t, u)
    }
}
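
// For illustration only (nothing in this file uses it): the same min-trick reduction on a single
// `u32`, to make the comment above concrete. The helper name and the explicit `p` parameter are
// hypothetical; the vectorized `add` above is the real implementation.
#[allow(dead_code)]
fn add_scalar_sketch(lhs: u32, rhs: u32, p: u32) -> u32 {
    // `t = lhs + rhs` is in `0..=2P - 2`, which fits in a `u32` since `P < 2^31`.
    let t = lhs + rhs;
    // `u = (t - P) mod 2^32` wraps to a value above `P - 1` exactly when `t < P`.
    let u = t.wrapping_sub(p);
    // The unsigned minimum therefore picks `t` when `t < P` and `t - P` otherwise.
    t.min(u)
}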

// MONTGOMERY MULTIPLICATION
//   This implementation is based on [1] but with changes. The reduction is as follows:
//
// Constants: P < 2^31
//            B = 2^32
//            μ = P^-1 mod B
// Input: -P^2 <= C <= P^2
// Output: -P < D < P such that D = C B^-1 (mod P)
// Define:
//   smod_B(a) = r, where -B/2 <= r <= B/2 - 1 and r = a (mod B).
// Algorithm:
//   1. Q := smod_B(μ C)
//   2. D := (C - Q P) / B
//
// We first show that the division in step 2. is exact. It suffices to show that C = Q P (mod B). By
// definition of Q, smod_B, and μ, we have Q P = smod_B(μ C) P = μ C P = P^-1 C P = C (mod B).
//
// We also have C - Q P = C (mod P), so D = (C - Q P) B^-1 = C B^-1 (mod P).
//
// It remains to show that D is in the correct range. It suffices to show that -P B < C - Q P < P B.
// We know that -P^2 <= C <= P^2 and (-B / 2) P <= Q P <= (B/2 - 1) P. Then
// (1 - B/2) P - P^2 <= C - Q P <= (B/2) P + P^2. Now, P < B/2, so B/2 + P < B and
// (B/2) P + P^2 < P B; also B/2 - 1 + P < B, so -P B < (1 - B/2) P - P^2.
// Hence, -P B < C - Q P < P B as desired.
//
// [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann, Cambridge University Press,
//     2010, algorithm 2.7.
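
// For illustration only (hypothetical helper, not used by the vector code): a scalar sketch of the
// reduction above. `p` stands for P and `mu` for μ = P^-1 mod 2^32. Given -P^2 <= c <= P^2 it
// returns d with -P < d < P and d = c B^-1 (mod P), mirroring steps 1. and 2. above.
#[allow(dead_code)]
fn monty_reduce_sketch(c: i64, p: u32, mu: u32) -> i32 {
    // Step 1: Q := smod_B(μ C). Truncating the product to 32 bits and reinterpreting it as an
    // `i32` gives exactly the signed residue of μ C modulo B = 2^32.
    let q = mu.wrapping_mul(c as u32) as i32;
    // Step 2: D := (C - Q P) / B. Since C = Q P (mod B), the low 32 bits of the difference are
    // zero, so the arithmetic shift is an exact division by B.
    ((c - i64::from(q) * i64::from(p)) >> 32) as i32
}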

#[inline]
#[must_use]
fn mulby_mu<MPNeon: MontyParametersNeon>(val: int32x4_t) -> int32x4_t {
    // We want this to compile to:
    //      mul      res.4s, val.4s, MU.4s
    // throughput: .25 cyc/vec (16 els/cyc)
    // latency: 3 cyc

    unsafe { aarch64::vmulq_s32(val, MPNeon::PACKED_MU) }
}

#[inline]
#[must_use]
fn get_c_hi(lhs: int32x4_t, rhs: int32x4_t) -> int32x4_t {
    // We want this to compile to:
    //      sqdmulh  c_hi.4s, lhs.4s, rhs.4s
    // throughput: .25 cyc/vec (16 els/cyc)
    // latency: 3 cyc

    unsafe {
        // Get bits 31, ..., 62 of C. Note that `sqdmulh` saturates when the product doesn't fit in
        // an `i63`, but this cannot happen here due to our bounds on `lhs` and `rhs`.
        aarch64::vqdmulhq_s32(lhs, rhs)
    }
}

#[inline]
#[must_use]
fn get_qp_hi<MPNeon: MontyParametersNeon>(lhs: int32x4_t, mu_rhs: int32x4_t) -> int32x4_t {
    // We want this to compile to:
    //      mul      q.4s, lhs.4s, mu_rhs.4s
    //      sqdmulh  qp_hi.4s, q.4s, P.4s
    // throughput: .5 cyc/vec (8 els/cyc)
    // latency: 6 cyc

    unsafe {
        // Form `Q`.
        let q = aarch64::vmulq_s32(lhs, mu_rhs);

        // Gets bits 31, ..., 62 of Q P. Again, saturation is not an issue because `P` is not
        // -2**31.
        aarch64::vqdmulhq_s32(q, aarch64::vreinterpretq_s32_u32(MPNeon::PACKED_P))
    }
}

#[inline]
#[must_use]
fn get_d(c_hi: int32x4_t, qp_hi: int32x4_t) -> int32x4_t {
    // We want this to compile to:
    //      shsub    res.4s, c_hi.4s, qp_hi.4s
    // throughput: .25 cyc/vec (16 els/cyc)
    // latency: 2 cyc

    unsafe {
        // Form D. Note that `c_hi` is C >> 31 and `qp_hi` is (Q P) >> 31, whereas we want
        // (C - Q P) >> 32, so we need to subtract and divide by 2. Luckily NEON has an instruction
        // for that! The lowest bit of `c_hi` and `qp_hi` is the same, so the division is exact.
        aarch64::vhsubq_s32(c_hi, qp_hi)
    }
}

#[inline]
#[must_use]
fn get_reduced_d<MPNeon: MontyParametersNeon>(c_hi: int32x4_t, qp_hi: int32x4_t) -> uint32x4_t {
    // We want this to compile to:
    //      shsub    res.4s, c_hi.4s, qp_hi.4s
    //      cmgt     underflow.4s, qp_hi.4s, c_hi.4s
    //      mls      res.4s, underflow.4s, P.4s
    // throughput: .75 cyc/vec (5.33 els/cyc)
    // latency: 5 cyc

    unsafe {
        let d = aarch64::vreinterpretq_u32_s32(get_d(c_hi, qp_hi));

        // Finally, we reduce D to canonical form. D is negative iff `c_hi < qp_hi`, in which case
        // we add P. Note that if `c_hi < qp_hi` then `underflow` is -1, so we must _subtract_
        // `underflow` * P.
        let underflow = aarch64::vcltq_s32(c_hi, qp_hi);
        aarch64::vmlsq_u32(d, confuse_compiler(underflow), MPNeon::PACKED_P)
    }
}

#[inline]
#[must_use]
fn mul<MPNeon: MontyParametersNeon>(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
    // We want this to compile to:
    //      sqdmulh  c_hi.4s, lhs.4s, rhs.4s
    //      mul      mu_rhs.4s, rhs.4s, MU.4s
    //      mul      q.4s, lhs.4s, mu_rhs.4s
    //      sqdmulh  qp_hi.4s, q.4s, P.4s
    //      shsub    res.4s, c_hi.4s, qp_hi.4s
    //      cmgt     underflow.4s, qp_hi.4s, c_hi.4s
    //      mls      res.4s, underflow.4s, P.4s
    // throughput: 1.75 cyc/vec (2.29 els/cyc)
    // latency: (lhs->) 11 cyc, (rhs->) 14 cyc

    unsafe {
        // No-op. The inputs are non-negative so we're free to interpret them as signed numbers.
        let lhs = aarch64::vreinterpretq_s32_u32(lhs);
        let rhs = aarch64::vreinterpretq_s32_u32(rhs);

        let mu_rhs = mulby_mu::<MPNeon>(rhs);
        let c_hi = get_c_hi(lhs, rhs);
        let qp_hi = get_qp_hi::<MPNeon>(lhs, mu_rhs);
        get_reduced_d::<MPNeon>(c_hi, qp_hi)
    }
}

#[inline]
#[must_use]
fn cube<MPNeon: MontyParametersNeon>(val: uint32x4_t) -> uint32x4_t {
    // throughput: 2.75 cyc/vec (1.45 els/cyc)
    // latency: 22 cyc

    unsafe {
        let val = aarch64::vreinterpretq_s32_u32(val);
        let mu_val = mulby_mu::<MPNeon>(val);

        let c_hi_2 = get_c_hi(val, val);
        let qp_hi_2 = get_qp_hi::<MPNeon>(val, mu_val);
        let val_2 = get_d(c_hi_2, qp_hi_2);

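        // `val_2` is only partially reduced: `get_d` leaves it in (-P, P) rather than canonical
        // form. That is fine as an input to the second multiplication, since |val_2 * val| < P^2
        // still satisfies the bound required by the Montgomery reduction above.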
        let c_hi_3 = get_c_hi(val_2, val);
        let qp_hi_3 = get_qp_hi::<MPNeon>(val_2, mu_val);
        get_reduced_d::<MPNeon>(c_hi_3, qp_hi_3)
    }
}

/// Negate a vector of Monty31 field elements in canonical form.
/// If the inputs are not in canonical form, the result is undefined.
#[inline]
#[must_use]
fn neg<MPNeon: MontyParametersNeon>(val: uint32x4_t) -> uint32x4_t {
    // We want this to compile to:
    //      sub   t.4s, P.4s, val.4s
    //      cmeq  is_zero.4s, val.4s, #0
    //      bic   res.4s, t.4s, is_zero.4s
    // throughput: .75 cyc/vec (5.33 els/cyc)
    // latency: 4 cyc

    // This has the same throughput as `sub(0, val)` but slightly lower latency.

    //   We want to return (-val) mod P. This is equivalent to returning `0` if `val = 0` and
    // `P - val` otherwise, since `val` is in `0, ..., P - 1`.
    //   Let `t := P - val` and let `is_zero := (-1) mod 2^32` if `val = 0` and `0` otherwise.
    //   We return `r := t & ~is_zero`, which is `t` if `val > 0` and `0` otherwise, as desired.
    unsafe {
        // Safety: If this code got compiled then NEON intrinsics are available.
        let t = aarch64::vsubq_u32(MPNeon::PACKED_P, val);
        let is_zero = aarch64::vceqzq_u32(val);
        aarch64::vbicq_u32(t, is_zero)
    }
}
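
// For illustration only (hypothetical helper, not used by the vector code): the same zero-masking
// trick on a single `u32`, with the modulus passed explicitly as `p`.
#[allow(dead_code)]
fn neg_scalar_sketch(val: u32, p: u32) -> u32 {
    // `t = P - val` is the right answer except when `val = 0`, where it would be `P`.
    let t = p - val;
    // `is_zero` plays the role of the `cmeq` mask: all ones when `val = 0`, zero otherwise.
    let is_zero = if val == 0 { u32::MAX } else { 0 };
    // Clearing the masked bits (the `bic` step) maps the `val = 0` case back to `0`.
    t & !is_zero
}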

/// Subtract vectors of Monty31 field elements in canonical form.
/// If the inputs are not in canonical form, the result is undefined.
#[inline]
#[must_use]
fn sub<MPNeon: MontyParametersNeon>(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
    // We want this to compile to:
    //      sub   res.4s, lhs.4s, rhs.4s
    //      cmhi  underflow.4s, rhs.4s, lhs.4s
    //      mls   res.4s, underflow.4s, P.4s
    // throughput: .75 cyc/vec (5.33 els/cyc)
    // latency: 5 cyc

    //   Let `d := lhs - rhs`. We want to return `d mod P`.
    //   Since `lhs` and `rhs` are both in `0, ..., P - 1`, `d` is in `-P + 1, ..., P - 1`. It
    // suffices to return `d + P` if `d < 0` and `d` otherwise.
    //   Equivalently, we return `d + P` if `rhs > lhs` and `d` otherwise. Observe that this
    // permits us to perform all calculations `mod 2^32`, so define `diff := d mod 2^32`.
    //   Let `underflow` be `-1 mod 2^32` if `rhs > lhs` and `0` otherwise.
    //   Finally, let `r := (diff - underflow * P) mod 2^32` and observe that
    // `r = (diff + P) mod 2^32` if `rhs > lhs` and `diff` otherwise, as desired.
    unsafe {
        // Safety: If this code got compiled then NEON intrinsics are available.
        let diff = aarch64::vsubq_u32(lhs, rhs);
        let underflow = aarch64::vcltq_u32(lhs, rhs);
        // We really want to emit a `mls` instruction here. The compiler knows that `underflow` is
        // either 0 or -1 and will try to do an `and` and `add` instead, which is slower on the M1.
        // The `confuse_compiler` prevents this "optimization".
        aarch64::vmlsq_u32(diff, confuse_compiler(underflow), MPNeon::PACKED_P)
    }
}
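
// For illustration only (hypothetical helper, not used by the vector code): the same
// underflow-correction trick on single `u32` values, with the modulus passed explicitly as `p`.
#[allow(dead_code)]
fn sub_scalar_sketch(lhs: u32, rhs: u32, p: u32) -> u32 {
    // `diff = (lhs - rhs) mod 2^32`.
    let diff = lhs.wrapping_sub(rhs);
    // `underflow` mirrors the `cmhi` mask: all ones (i.e. -1 mod 2^32) when `rhs > lhs`.
    let underflow = if lhs < rhs { u32::MAX } else { 0 };
    // `diff - underflow * P (mod 2^32)` adds P back exactly when the subtraction wrapped.
    diff.wrapping_sub(underflow.wrapping_mul(p))
}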

impl<PMP: PackedMontyParameters> From<MontyField31<PMP>> for PackedMontyField31Neon<PMP> {
    #[inline]
    fn from(value: MontyField31<PMP>) -> Self {
        Self::broadcast(value)
    }
}

impl<PMP: PackedMontyParameters> Default for PackedMontyField31Neon<PMP> {
    #[inline]
    fn default() -> Self {
        MontyField31::<PMP>::default().into()
    }
}

impl<PMP: PackedMontyParameters> AddAssign for PackedMontyField31Neon<PMP> {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

impl<PMP: PackedMontyParameters> MulAssign for PackedMontyField31Neon<PMP> {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}

impl<PMP: PackedMontyParameters> SubAssign for PackedMontyField31Neon<PMP> {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}

impl<FP: FieldParameters> Sum for PackedMontyField31Neon<FP> {
    #[inline]
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        iter.reduce(|lhs, rhs| lhs + rhs).unwrap_or(Self::ZERO)
    }
}

impl<FP: FieldParameters> Product for PackedMontyField31Neon<FP> {
    #[inline]
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        iter.reduce(|lhs, rhs| lhs * rhs).unwrap_or(Self::ONE)
    }
}

impl<FP: FieldParameters> FieldAlgebra for PackedMontyField31Neon<FP> {
    type F = MontyField31<FP>;

    const ZERO: Self = Self::broadcast(MontyField31::ZERO);
    const ONE: Self = Self::broadcast(MontyField31::ONE);
    const TWO: Self = Self::broadcast(MontyField31::TWO);
    const NEG_ONE: Self = Self::broadcast(MontyField31::NEG_ONE);

    #[inline]
    fn from_f(f: Self::F) -> Self {
        f.into()
    }
    #[inline]
    fn from_canonical_u8(n: u8) -> Self {
        MontyField31::from_canonical_u8(n).into()
    }
    #[inline]
    fn from_canonical_u16(n: u16) -> Self {
        MontyField31::from_canonical_u16(n).into()
    }
    #[inline]
    fn from_canonical_u32(n: u32) -> Self {
        MontyField31::from_canonical_u32(n).into()
    }
    #[inline]
    fn from_canonical_u64(n: u64) -> Self {
        MontyField31::from_canonical_u64(n).into()
    }
    #[inline]
    fn from_canonical_usize(n: usize) -> Self {
        MontyField31::from_canonical_usize(n).into()
    }

    #[inline]
    fn from_wrapped_u32(n: u32) -> Self {
        MontyField31::from_wrapped_u32(n).into()
    }
    #[inline]
    fn from_wrapped_u64(n: u64) -> Self {
        MontyField31::from_wrapped_u64(n).into()
    }

    #[inline]
    fn cube(&self) -> Self {
        let val = self.to_vector();
        let res = cube::<FP>(val);
        unsafe {
            // Safety: `cube` returns values in canonical form when given values in canonical form.
            Self::from_vector(res)
        }
    }

    #[inline(always)]
    fn zero_vec(len: usize) -> Vec<Self> {
        // SAFETY: this is a repr(transparent) wrapper around an array.
        unsafe { convert_vec(Self::F::zero_vec(len * WIDTH)) }
    }
}

impl<PMP: PackedMontyParameters> Add<MontyField31<PMP>> for PackedMontyField31Neon<PMP> {
    type Output = Self;
    #[inline]
    fn add(self, rhs: MontyField31<PMP>) -> Self {
        self + Self::from(rhs)
    }
}

impl<PMP: PackedMontyParameters> Mul<MontyField31<PMP>> for PackedMontyField31Neon<PMP> {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: MontyField31<PMP>) -> Self {
        self * Self::from(rhs)
    }
}

impl<PMP: PackedMontyParameters> Sub<MontyField31<PMP>> for PackedMontyField31Neon<PMP> {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: MontyField31<PMP>) -> Self {
        self - Self::from(rhs)
    }
}

impl<PMP: PackedMontyParameters> AddAssign<MontyField31<PMP>> for PackedMontyField31Neon<PMP> {
    #[inline]
    fn add_assign(&mut self, rhs: MontyField31<PMP>) {
        *self += Self::from(rhs)
    }
}

impl<PMP: PackedMontyParameters> MulAssign<MontyField31<PMP>> for PackedMontyField31Neon<PMP> {
    #[inline]
    fn mul_assign(&mut self, rhs: MontyField31<PMP>) {
        *self *= Self::from(rhs)
    }
}

impl<PMP: PackedMontyParameters> SubAssign<MontyField31<PMP>> for PackedMontyField31Neon<PMP> {
    #[inline]
    fn sub_assign(&mut self, rhs: MontyField31<PMP>) {
        *self -= Self::from(rhs)
    }
}

impl<FP: FieldParameters> Sum<MontyField31<FP>> for PackedMontyField31Neon<FP> {
    #[inline]
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = MontyField31<FP>>,
    {
        iter.sum::<MontyField31<FP>>().into()
    }
}

impl<FP: FieldParameters> Product<MontyField31<FP>> for PackedMontyField31Neon<FP> {
    #[inline]
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = MontyField31<FP>>,
    {
        iter.product::<MontyField31<FP>>().into()
    }
}

impl<FP: FieldParameters> Div<MontyField31<FP>> for PackedMontyField31Neon<FP> {
    type Output = Self;
    #[allow(clippy::suspicious_arithmetic_impl)]
    #[inline]
    fn div(self, rhs: MontyField31<FP>) -> Self {
        self * rhs.inverse()
    }
}

impl<PMP: PackedMontyParameters> Add<PackedMontyField31Neon<PMP>> for MontyField31<PMP> {
    type Output = PackedMontyField31Neon<PMP>;
    #[inline]
    fn add(self, rhs: PackedMontyField31Neon<PMP>) -> PackedMontyField31Neon<PMP> {
        PackedMontyField31Neon::<PMP>::from(self) + rhs
    }
}

impl<PMP: PackedMontyParameters> Mul<PackedMontyField31Neon<PMP>> for MontyField31<PMP> {
    type Output = PackedMontyField31Neon<PMP>;
    #[inline]
    fn mul(self, rhs: PackedMontyField31Neon<PMP>) -> PackedMontyField31Neon<PMP> {
        PackedMontyField31Neon::<PMP>::from(self) * rhs
    }
}

impl<PMP: PackedMontyParameters> Sub<PackedMontyField31Neon<PMP>> for MontyField31<PMP> {
    type Output = PackedMontyField31Neon<PMP>;
    #[inline]
    fn sub(self, rhs: PackedMontyField31Neon<PMP>) -> PackedMontyField31Neon<PMP> {
        PackedMontyField31Neon::<PMP>::from(self) - rhs
    }
}

impl<PMP: PackedMontyParameters> Distribution<PackedMontyField31Neon<PMP>> for Standard {
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedMontyField31Neon<PMP> {
        PackedMontyField31Neon::<PMP>(rng.gen())
    }
}

#[inline]
#[must_use]
fn interleave1(v0: uint32x4_t, v1: uint32x4_t) -> (uint32x4_t, uint32x4_t) {
    // We want this to compile to:
    //      trn1  res0.4s, v0.4s, v1.4s
    //      trn2  res1.4s, v0.4s, v1.4s
    // throughput: .5 cyc/2 vec (16 els/cyc)
    // latency: 2 cyc
    unsafe {
        // Safety: If this code got compiled then NEON intrinsics are available.
        (aarch64::vtrn1q_u32(v0, v1), aarch64::vtrn2q_u32(v0, v1))
    }
}

#[inline]
#[must_use]
fn interleave2(v0: uint32x4_t, v1: uint32x4_t) -> (uint32x4_t, uint32x4_t) {
    // We want this to compile to:
    //      trn1  res0.2d, v0.2d, v1.2d
    //      trn2  res1.2d, v0.2d, v1.2d
    // throughput: .5 cyc/2 vec (16 els/cyc)
    // latency: 2 cyc

    // To transpose 64-bit blocks, cast the [u32; 4] vectors to [u64; 2], transpose, and cast back.
    unsafe {
        // Safety: If this code got compiled then NEON intrinsics are available.
        let v0 = aarch64::vreinterpretq_u64_u32(v0);
        let v1 = aarch64::vreinterpretq_u64_u32(v1);
        (
            aarch64::vreinterpretq_u32_u64(aarch64::vtrn1q_u64(v0, v1)),
            aarch64::vreinterpretq_u32_u64(aarch64::vtrn2q_u64(v0, v1)),
        )
    }
}
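
// For illustration only (hypothetical helper on plain arrays, not used by the vector code): the
// lane shuffles performed by `interleave1`, `interleave2` and the `interleave` method below. With
// `block_len = 1` the two vectors exchange every other lane, with `block_len = 2` they exchange
// 64-bit halves, and with `block_len = 4` they are returned unchanged.
#[allow(dead_code)]
fn interleave_sketch(a: [u32; 4], b: [u32; 4], block_len: usize) -> ([u32; 4], [u32; 4]) {
    match block_len {
        // `trn1`/`trn2` on 32-bit lanes: ([a0, b0, a2, b2], [a1, b1, a3, b3]).
        1 => ([a[0], b[0], a[2], b[2]], [a[1], b[1], a[3], b[3]]),
        // `trn1`/`trn2` on 64-bit lanes: ([a0, a1, b0, b1], [a2, a3, b2, b3]).
        2 => ([a[0], a[1], b[0], b[1]], [a[2], a[3], b[2], b[3]]),
        4 => (a, b),
        _ => panic!("unsupported block_len"),
    }
}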

unsafe impl<FP: FieldParameters> PackedValue for PackedMontyField31Neon<FP> {
    type Value = MontyField31<FP>;
    const WIDTH: usize = WIDTH;

    #[inline]
    fn from_slice(slice: &[MontyField31<FP>]) -> &Self {
        assert_eq!(slice.len(), Self::WIDTH);
        unsafe {
            // Safety: `[MontyField31; WIDTH]` can be transmuted to `PackedMontyField31Neon` since the
            // latter is `repr(transparent)`. They have the same alignment, so the reference cast is
            // safe too.
            &*slice.as_ptr().cast()
        }
    }
    #[inline]
    fn from_slice_mut(slice: &mut [MontyField31<FP>]) -> &mut Self {
        assert_eq!(slice.len(), Self::WIDTH);
        unsafe {
            // Safety: `[MontyField31; WIDTH]` can be transmuted to `PackedMontyField31Neon` since the
            // latter is `repr(transparent)`. They have the same alignment, so the reference cast is
            // safe too.
            &mut *slice.as_mut_ptr().cast()
        }
    }

    /// Similar to `core::array::from_fn`.
    #[inline]
    fn from_fn<F: FnMut(usize) -> MontyField31<FP>>(f: F) -> Self {
        let vals_arr: [_; WIDTH] = core::array::from_fn(f);
        Self(vals_arr)
    }

    #[inline]
    fn as_slice(&self) -> &[MontyField31<FP>] {
        &self.0[..]
    }
    #[inline]
    fn as_slice_mut(&mut self) -> &mut [MontyField31<FP>] {
        &mut self.0[..]
    }
}

unsafe impl<FP: FieldParameters> PackedField for PackedMontyField31Neon<FP> {
    type Scalar = MontyField31<FP>;
}

unsafe impl<FP: FieldParameters> PackedFieldPow2 for PackedMontyField31Neon<FP> {
    #[inline]
    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
        let (v0, v1) = (self.to_vector(), other.to_vector());
        let (res0, res1) = match block_len {
            1 => interleave1(v0, v1),
            2 => interleave2(v0, v1),
            4 => (v0, v1),
            _ => panic!("unsupported block_len"),
        };
        unsafe {
            // Safety: all values are in canonical form (we haven't changed them).
            (Self::from_vector(res0), Self::from_vector(res1))
        }
    }
}