p3_field/
packed.rs

1use core::mem::MaybeUninit;
2use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Sub, SubAssign};
3use core::slice;
4
5use crate::field::Field;
6use crate::FieldAlgebra;
7
/// Marker trait for element types that may occupy the lanes of a packed value.
///
/// Opting in through an explicit `Packable` impl (rather than a blanket bound) is what lets us
/// write implementations for types that would otherwise conflict with one another.
pub trait Packable: Default + Copy + PartialEq + Eq + Send + Sync + 'static {}
12
13/// # Safety
14/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
15///   without UB.
16pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
17    type Value: Packable;
18
19    const WIDTH: usize;
20
21    fn from_slice(slice: &[Self::Value]) -> &Self;
22    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;
23
24    /// Similar to `core:array::from_fn`.
25    fn from_fn<F>(f: F) -> Self
26    where
27        F: FnMut(usize) -> Self::Value;
28
29    fn as_slice(&self) -> &[Self::Value];
30    fn as_slice_mut(&mut self) -> &mut [Self::Value];
31
32    fn pack_slice(buf: &[Self::Value]) -> &[Self] {
33        // Sources vary, but this should be true on all platforms we care about.
34        // This should be a const assert, but trait methods can't access `Self` in a const context,
35        // even with inner struct instantiation. So we will trust LLVM to optimize this out.
36        assert!(align_of::<Self>() <= align_of::<Self::Value>());
37        assert!(
38            buf.len() % Self::WIDTH == 0,
39            "Slice length (got {}) must be a multiple of packed field width ({}).",
40            buf.len(),
41            Self::WIDTH
42        );
43        let buf_ptr = buf.as_ptr().cast::<Self>();
44        let n = buf.len() / Self::WIDTH;
45        unsafe { slice::from_raw_parts(buf_ptr, n) }
46    }
47
48    fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
49        let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
50        (Self::pack_slice(packed), suffix)
51    }
52
53    fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
54        assert!(align_of::<Self>() <= align_of::<Self::Value>());
55        assert!(
56            buf.len() % Self::WIDTH == 0,
57            "Slice length (got {}) must be a multiple of packed field width ({}).",
58            buf.len(),
59            Self::WIDTH
60        );
61        let buf_ptr = buf.as_mut_ptr().cast::<Self>();
62        let n = buf.len() / Self::WIDTH;
63        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
64    }
65
66    fn pack_maybe_uninit_slice_mut(
67        buf: &mut [MaybeUninit<Self::Value>],
68    ) -> &mut [MaybeUninit<Self>] {
69        assert!(align_of::<Self>() <= align_of::<Self::Value>());
70        assert!(
71            buf.len() % Self::WIDTH == 0,
72            "Slice length (got {}) must be a multiple of packed field width ({}).",
73            buf.len(),
74            Self::WIDTH
75        );
76        let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
77        let n = buf.len() / Self::WIDTH;
78        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
79    }
80
81    fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
82        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
83        (Self::pack_slice_mut(packed), suffix)
84    }
85
86    fn pack_maybe_uninit_slice_with_suffix_mut(
87        buf: &mut [MaybeUninit<Self::Value>],
88    ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
89        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
90        (Self::pack_maybe_uninit_slice_mut(packed), suffix)
91    }
92
93    fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
94        assert!(align_of::<Self>() >= align_of::<Self::Value>());
95        let buf_ptr = buf.as_ptr().cast::<Self::Value>();
96        let n = buf.len() * Self::WIDTH;
97        unsafe { slice::from_raw_parts(buf_ptr, n) }
98    }
99}
100
101unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
102    type Value = T;
103    const WIDTH: usize = WIDTH;
104
105    fn from_slice(slice: &[Self::Value]) -> &Self {
106        assert_eq!(slice.len(), Self::WIDTH);
107        slice.try_into().unwrap()
108    }
109
110    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
111        assert_eq!(slice.len(), Self::WIDTH);
112        slice.try_into().unwrap()
113    }
114
115    fn from_fn<F>(f: F) -> Self
116    where
117        F: FnMut(usize) -> Self::Value,
118    {
119        core::array::from_fn(f)
120    }
121
122    fn as_slice(&self) -> &[Self::Value] {
123        self
124    }
125
126    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
127        self
128    }
129}
130
131/// # Safety
132/// - See `PackedValue` above.
133pub unsafe trait PackedField: FieldAlgebra<F = Self::Scalar>
134    + PackedValue<Value = Self::Scalar>
135    + From<Self::Scalar>
136    + Add<Self::Scalar, Output = Self>
137    + AddAssign<Self::Scalar>
138    + Sub<Self::Scalar, Output = Self>
139    + SubAssign<Self::Scalar>
140    + Mul<Self::Scalar, Output = Self>
141    + MulAssign<Self::Scalar>
142    // TODO: Implement packed / packed division
143    + Div<Self::Scalar, Output = Self>
144{
145    type Scalar: Field;
146}
147
148/// # Safety
149/// - `WIDTH` is assumed to be a power of 2.
150pub unsafe trait PackedFieldPow2: PackedField {
151    /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
152    /// chunks. This is best seen with an example. If we have:
153    /// ```text
154    /// A = [x0, y0, x1, y1]
155    /// B = [x2, y2, x3, y3]
156    /// ```
157    ///
158    /// then
159    ///
160    /// ```text
161    /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
162    /// ```
163    ///
164    /// Pairs that were adjacent in the input are at corresponding positions in the output.
165    ///
166    /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
167    ///
168    /// ```text
169    /// A = [x0, x1, y0, y1]
170    /// B = [x2, x3, y2, y3]
171    /// ```
172    ///
173    /// we obtain
174    ///
175    /// ```text
176    /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
177    /// ```
178    ///
179    /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
180    /// transposing those matrices.
181    ///
182    /// When `block_len = WIDTH`, this operation is a no-op. `block_len` must divide `WIDTH`. Since
183    /// `WIDTH` is specified to be a power of 2, `block_len` must also be a power of 2. It cannot be
184    /// 0 and it cannot exceed `WIDTH`.
185    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
186}
187
188unsafe impl<T: Packable> PackedValue for T {
189    type Value = Self;
190
191    const WIDTH: usize = 1;
192
193    fn from_slice(slice: &[Self::Value]) -> &Self {
194        &slice[0]
195    }
196
197    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
198        &mut slice[0]
199    }
200
201    fn from_fn<Fn>(mut f: Fn) -> Self
202    where
203        Fn: FnMut(usize) -> Self::Value,
204    {
205        f(0)
206    }
207
208    fn as_slice(&self) -> &[Self::Value] {
209        slice::from_ref(self)
210    }
211
212    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
213        slice::from_mut(self)
214    }
215}
216
217unsafe impl<F: Field> PackedField for F {
218    type Scalar = Self;
219}
220
221unsafe impl<F: Field> PackedFieldPow2 for F {
222    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
223        match block_len {
224            1 => (*self, other),
225            _ => panic!("unsupported block length"),
226        }
227    }
228}
229
230impl Packable for u8 {}
231
232impl Packable for u16 {}
233
234impl Packable for u32 {}
235
236impl Packable for u64 {}
237
238impl Packable for u128 {}