// p3_poseidon2_air/vectorized.rs

1use core::borrow::{Borrow, BorrowMut};
2
3use p3_air::{Air, AirBuilder, BaseAir};
4use p3_field::{PrimeCharacteristicRing, PrimeField};
5use p3_matrix::Matrix;
6use p3_matrix::dense::RowMajorMatrix;
7use p3_poseidon2::GenericPoseidon2LinearLayers;
8use rand::distr::StandardUniform;
9use rand::prelude::Distribution;
10use rand::rngs::SmallRng;
11use rand::{Rng, SeedableRng};
12
13use crate::air::eval;
14use crate::constants::RoundConstants;
15use crate::{Poseidon2Air, Poseidon2Cols, generate_vectorized_trace_rows};
16
17/// A "vectorized" version of Poseidon2Cols, for computing multiple Poseidon2 permutations per row.
18#[repr(C)]
19pub struct VectorizedPoseidon2Cols<
20    T,
21    const WIDTH: usize,
22    const SBOX_DEGREE: u64,
23    const SBOX_REGISTERS: usize,
24    const HALF_FULL_ROUNDS: usize,
25    const PARTIAL_ROUNDS: usize,
26    const VECTOR_LEN: usize,
27> {
28    pub(crate) cols:
29        [Poseidon2Cols<T, WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>;
30            VECTOR_LEN],
31}
32
33impl<
34    T,
35    const WIDTH: usize,
36    const SBOX_DEGREE: u64,
37    const SBOX_REGISTERS: usize,
38    const HALF_FULL_ROUNDS: usize,
39    const PARTIAL_ROUNDS: usize,
40    const VECTOR_LEN: usize,
41>
42    Borrow<
43        VectorizedPoseidon2Cols<
44            T,
45            WIDTH,
46            SBOX_DEGREE,
47            SBOX_REGISTERS,
48            HALF_FULL_ROUNDS,
49            PARTIAL_ROUNDS,
50            VECTOR_LEN,
51        >,
52    > for [T]
53{
54    fn borrow(
55        &self,
56    ) -> &VectorizedPoseidon2Cols<
57        T,
58        WIDTH,
59        SBOX_DEGREE,
60        SBOX_REGISTERS,
61        HALF_FULL_ROUNDS,
62        PARTIAL_ROUNDS,
63        VECTOR_LEN,
64    > {
65        // debug_assert_eq!(self.len(), NUM_COLS);
66        let (prefix, shorts, suffix) = unsafe {
67            self.align_to::<VectorizedPoseidon2Cols<
68                T,
69                WIDTH,
70                SBOX_DEGREE,
71                SBOX_REGISTERS,
72                HALF_FULL_ROUNDS,
73                PARTIAL_ROUNDS,
74                VECTOR_LEN,
75            >>()
76        };
77        debug_assert!(prefix.is_empty(), "Alignment should match");
78        debug_assert!(suffix.is_empty(), "Alignment should match");
79        debug_assert_eq!(shorts.len(), 1);
80        &shorts[0]
81    }
82}
83
84impl<
85    T,
86    const WIDTH: usize,
87    const SBOX_DEGREE: u64,
88    const SBOX_REGISTERS: usize,
89    const HALF_FULL_ROUNDS: usize,
90    const PARTIAL_ROUNDS: usize,
91    const VECTOR_LEN: usize,
92>
93    BorrowMut<
94        VectorizedPoseidon2Cols<
95            T,
96            WIDTH,
97            SBOX_DEGREE,
98            SBOX_REGISTERS,
99            HALF_FULL_ROUNDS,
100            PARTIAL_ROUNDS,
101            VECTOR_LEN,
102        >,
103    > for [T]
104{
105    fn borrow_mut(
106        &mut self,
107    ) -> &mut VectorizedPoseidon2Cols<
108        T,
109        WIDTH,
110        SBOX_DEGREE,
111        SBOX_REGISTERS,
112        HALF_FULL_ROUNDS,
113        PARTIAL_ROUNDS,
114        VECTOR_LEN,
115    > {
116        // debug_assert_eq!(self.len(), NUM_COLS);
117        let (prefix, shorts, suffix) = unsafe {
118            self.align_to_mut::<VectorizedPoseidon2Cols<
119                T,
120                WIDTH,
121                SBOX_DEGREE,
122                SBOX_REGISTERS,
123                HALF_FULL_ROUNDS,
124                PARTIAL_ROUNDS,
125                VECTOR_LEN,
126            >>()
127        };
128        debug_assert!(prefix.is_empty(), "Alignment should match");
129        debug_assert!(suffix.is_empty(), "Alignment should match");
130        debug_assert_eq!(shorts.len(), 1);
131        &mut shorts[0]
132    }
133}
134
135/// A "vectorized" version of Poseidon2Air, for computing multiple Poseidon2 permutations per row.
136pub struct VectorizedPoseidon2Air<
137    F: PrimeCharacteristicRing,
138    LinearLayers,
139    const WIDTH: usize,
140    const SBOX_DEGREE: u64,
141    const SBOX_REGISTERS: usize,
142    const HALF_FULL_ROUNDS: usize,
143    const PARTIAL_ROUNDS: usize,
144    const VECTOR_LEN: usize,
145> {
146    pub(crate) air: Poseidon2Air<
147        F,
148        LinearLayers,
149        WIDTH,
150        SBOX_DEGREE,
151        SBOX_REGISTERS,
152        HALF_FULL_ROUNDS,
153        PARTIAL_ROUNDS,
154    >,
155}
156
157impl<
158    F: PrimeCharacteristicRing,
159    LinearLayers,
160    const WIDTH: usize,
161    const SBOX_DEGREE: u64,
162    const SBOX_REGISTERS: usize,
163    const HALF_FULL_ROUNDS: usize,
164    const PARTIAL_ROUNDS: usize,
165    const VECTOR_LEN: usize,
166>
167    VectorizedPoseidon2Air<
168        F,
169        LinearLayers,
170        WIDTH,
171        SBOX_DEGREE,
172        SBOX_REGISTERS,
173        HALF_FULL_ROUNDS,
174        PARTIAL_ROUNDS,
175        VECTOR_LEN,
176    >
177{
178    pub const fn new(
179        constants: RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
180    ) -> Self {
181        Self {
182            air: Poseidon2Air::new(constants),
183        }
184    }
185
186    pub fn generate_vectorized_trace_rows(
187        &self,
188        num_hashes: usize,
189        extra_capacity_bits: usize,
190    ) -> RowMajorMatrix<F>
191    where
192        F: PrimeField,
193        LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
194        StandardUniform: Distribution<[F; WIDTH]>,
195    {
196        let mut rng = SmallRng::seed_from_u64(1);
197        let inputs = (0..num_hashes).map(|_| rng.random()).collect();
198        generate_vectorized_trace_rows::<
199            _,
200            LinearLayers,
201            WIDTH,
202            SBOX_DEGREE,
203            SBOX_REGISTERS,
204            HALF_FULL_ROUNDS,
205            PARTIAL_ROUNDS,
206            VECTOR_LEN,
207        >(inputs, &self.air.constants, extra_capacity_bits)
208    }
209}
210
211impl<
212    F: PrimeCharacteristicRing + Sync,
213    LinearLayers: Sync,
214    const WIDTH: usize,
215    const SBOX_DEGREE: u64,
216    const SBOX_REGISTERS: usize,
217    const HALF_FULL_ROUNDS: usize,
218    const PARTIAL_ROUNDS: usize,
219    const VECTOR_LEN: usize,
220> BaseAir<F>
221    for VectorizedPoseidon2Air<
222        F,
223        LinearLayers,
224        WIDTH,
225        SBOX_DEGREE,
226        SBOX_REGISTERS,
227        HALF_FULL_ROUNDS,
228        PARTIAL_ROUNDS,
229        VECTOR_LEN,
230    >
231{
232    fn width(&self) -> usize {
233        self.air.width() * VECTOR_LEN
234    }
235}
236
237impl<
238    AB: AirBuilder,
239    LinearLayers: GenericPoseidon2LinearLayers<WIDTH>,
240    const WIDTH: usize,
241    const SBOX_DEGREE: u64,
242    const SBOX_REGISTERS: usize,
243    const HALF_FULL_ROUNDS: usize,
244    const PARTIAL_ROUNDS: usize,
245    const VECTOR_LEN: usize,
246> Air<AB>
247    for VectorizedPoseidon2Air<
248        AB::F,
249        LinearLayers,
250        WIDTH,
251        SBOX_DEGREE,
252        SBOX_REGISTERS,
253        HALF_FULL_ROUNDS,
254        PARTIAL_ROUNDS,
255        VECTOR_LEN,
256    >
257{
258    #[inline]
259    fn eval(&self, builder: &mut AB) {
260        let main = builder.main();
261        let local = main.row_slice(0).expect("The matrix is empty?");
262        let local: &VectorizedPoseidon2Cols<
263            _,
264            WIDTH,
265            SBOX_DEGREE,
266            SBOX_REGISTERS,
267            HALF_FULL_ROUNDS,
268            PARTIAL_ROUNDS,
269            VECTOR_LEN,
270        > = (*local).borrow();
271        for perm in &local.cols {
272            eval(&self.air, builder, perm);
273        }
274    }
275}